[mlpack-svn] r15001 - in mlpack/tags: . mlpack-1.0.1/src/mlpack/methods/kmeans mlpack-1.0.5 mlpack-1.0.5/src/mlpack mlpack-1.0.5/src/mlpack/core/data mlpack-1.0.5/src/mlpack/core/dists mlpack-1.0.5/src/mlpack/core/kernels mlpack-1.0.5/src/mlpack/core/math mlpack-1.0.5/src/mlpack/core/metrics mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp mlpack-1.0.5/src/mlpack/core/optimizers/sgd mlpack-1.0.5/src/mlpack/core/tree mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree mlpack-1.0.5/src/mlpack/core/tree/cover_tree mlpack-1.0.5/src/mlpack/core/util mlpack-1.0.5/src/mlpack/methods/det mlpack-1.0.5/src/mlpack/methods/emst mlpack-1.0.5/src/mlpack/methods/fastmks mlpack-1.0.5/src/mlpack/methods/gmm mlpack-1.0.5/src/mlpack/methods/hmm mlpack-1.0.5/src/mlpack/methods/kernel_pca mlpack-1.0.5/src/mlpack/methods/kmeans mlpack-1.0.5/src/mlpack/methods/lars mlpack-1.0.5/src/mlpack/methods/linear_regression mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding mlpack-1.0.5/src/mlpack/methods/lsh mlpack-1.0.5/src/mlpack/methods/mvu mlpack-1.0.5/src/mlpack/methods/naive_bayes mlpack-1.0.5/src/mlpack/methods/nca mlpack-1.0.5/src/mlpack/methods/neighbor_search mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies mlpack-1.0.5/src/mlpack/methods/nmf mlpack-1.0.5/src/mlpack/methods/pca mlpack-1.0.5/src/mlpack/methods/radical mlpack-1.0.5/src/mlpack/methods/range_search mlpack-1.0.5/src/mlpack/methods/rann mlpack-1.0.5/src/mlpack/methods/sparse_coding mlpack-1.0.5/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu May 2 00:20:15 EDT 2013
Author: rcurtin
Date: 2013-05-02 00:20:13 -0400 (Thu, 02 May 2013)
New Revision: 15001
Added:
mlpack/tags/mlpack-1.0.5/
mlpack/tags/mlpack-1.0.5/Doxyfile
mlpack/tags/mlpack-1.0.5/HISTORY.txt
mlpack/tags/mlpack-1.0.5/src/mlpack/core.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/example_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/gaussian_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/hyperbolic_tangent_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/laplacian_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/linear_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/polynomial_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/spherical_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/triangular_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/clamp.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/round.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/traits.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/bounds.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/traits.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/statistic.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/tree_traits.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/nulloutstream.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/sfinae_utility.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/det_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/edge_pair.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/emst_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/union_find.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_stat.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/phi.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_generate_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_loglik_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_train_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/random_partition.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/nbc_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allkfn_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allknn_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/typedef.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/als_update_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_dist_update_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_div_update_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_acol_init.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_init.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/allkrann_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_typedef.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/nothing_initializer.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/random_initializer.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkfn_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allknn_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkrann_search_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/arma_extend_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/aug_lagrangian_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/cli_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/det_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/distribution_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/emst_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/fastmks_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/gmm_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/hmm_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_pca_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kmeans_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lars_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lbfgs_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lin_alg_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/linear_regression_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/load_save_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/local_coordinate_coding_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lrsdp_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lsh_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/math_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/mlpack_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nbc_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nca_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nmf_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/old_boost_test_definitions.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/pca_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/radical_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/range_search_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/save_restore_utility_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sgd_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sort_policy_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sparse_coding_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_traits_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/union_find_test.cpp
Removed:
mlpack/tags/mlpack-1.0.5/Doxyfile
mlpack/tags/mlpack-1.0.5/HISTORY.txt
mlpack/tags/mlpack-1.0.5/src/mlpack/core.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/example_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/gaussian_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/hyperbolic_tangent_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/laplacian_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/linear_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/polynomial_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/spherical_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/triangular_kernel.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/clamp.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/round.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/traits.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/bounds.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/traits.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/statistic.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/tree_traits.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/nulloutstream.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/sfinae_utility.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/det_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/edge_pair.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/emst_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/union_find.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_stat.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/phi.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_generate_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_loglik_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_train_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/random_partition.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/nbc_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allkfn_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allknn_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/typedef.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/als_update_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_dist_update_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_div_update_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_acol_init.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_init.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/allkrann_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_typedef.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/nothing_initializer.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/random_initializer.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkfn_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allknn_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkrann_search_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/arma_extend_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/aug_lagrangian_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/cli_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/det_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/distribution_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/emst_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/fastmks_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/gmm_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/hmm_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_pca_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kmeans_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lars_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lbfgs_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lin_alg_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/linear_regression_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/load_save_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/local_coordinate_coding_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lrsdp_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lsh_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/math_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/mlpack_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nbc_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nca_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nmf_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/old_boost_test_definitions.hpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/pca_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/radical_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/range_search_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/save_restore_utility_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sgd_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sort_policy_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sparse_coding_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_traits_test.cpp
mlpack/tags/mlpack-1.0.5/src/mlpack/tests/union_find_test.cpp
Modified:
mlpack/tags/mlpack-1.0.1/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
Log:
Branch for 1.0.5 release.
Modified: mlpack/tags/mlpack-1.0.1/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
===================================================================
--- mlpack/tags/mlpack-1.0.1/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp 2013-05-02 04:14:28 UTC (rev 15000)
+++ mlpack/tags/mlpack-1.0.1/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -32,7 +32,7 @@
// this is the sensible thing to do.
for (size_t i = 0; i < data.n_cols; i++)
{
- variances[assignments[i]] += as_scalar(
+ variances[assignments[i]] += arma::as_scalar(
var(data.col(i) - centroids.col(assignments[i])));
}
@@ -47,7 +47,7 @@
{
if (assignments[i] == maxVarCluster)
{
- double distance = as_scalar(
+ double distance = arma::as_scalar(
var(data.col(i) - centroids.col(maxVarCluster)));
if (distance > maxDistance)
Deleted: mlpack/tags/mlpack-1.0.5/Doxyfile
===================================================================
--- mlpack/branches/mlpack-1.x/Doxyfile 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/Doxyfile 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,253 +0,0 @@
-# Doxyfile 1.4.7
-
-#---------------------------------------------------------------------------
-# Project related configuration options
-#---------------------------------------------------------------------------
-PROJECT_NAME = MLPACK
-PROJECT_NUMBER = 1.0.4
-OUTPUT_DIRECTORY = ./doc
-CREATE_SUBDIRS = NO
-OUTPUT_LANGUAGE = English
-USE_WINDOWS_ENCODING = NO
-BRIEF_MEMBER_DESC = YES
-REPEAT_BRIEF = YES
-ABBREVIATE_BRIEF = "The $name class" \
- "The $name widget" \
- "The $name file" \
- is \
- provides \
- specifies \
- contains \
- represents \
- a \
- an \
- the
-ALWAYS_DETAILED_SEC = YES
-INLINE_INHERITED_MEMB = NO
-FULL_PATH_NAMES = YES
-STRIP_FROM_PATH = ./
-STRIP_FROM_INC_PATH =
-SHORT_NAMES = NO
-JAVADOC_AUTOBRIEF = YES
-MULTILINE_CPP_IS_BRIEF = NO
-DETAILS_AT_TOP = YES
-INHERIT_DOCS = YES
-SEPARATE_MEMBER_PAGES = NO
-TAB_SIZE = 2
-ALIASES =
-OPTIMIZE_OUTPUT_FOR_C = NO
-OPTIMIZE_OUTPUT_JAVA = NO
-BUILTIN_STL_SUPPORT = NO
-DISTRIBUTE_GROUP_DOC = NO
-SUBGROUPING = YES
-#---------------------------------------------------------------------------
-# Build related configuration options
-#---------------------------------------------------------------------------
-EXTRACT_ALL = YES
-EXTRACT_PRIVATE = YES
-EXTRACT_STATIC = YES
-EXTRACT_LOCAL_CLASSES = NO
-EXTRACT_LOCAL_METHODS = NO
-HIDE_UNDOC_MEMBERS = NO
-HIDE_UNDOC_CLASSES = NO
-HIDE_FRIEND_COMPOUNDS = NO
-HIDE_IN_BODY_DOCS = NO
-INTERNAL_DOCS = YES
-CASE_SENSE_NAMES = YES
-HIDE_SCOPE_NAMES = NO
-SHOW_INCLUDE_FILES = NO
-INLINE_INFO = YES
-SORT_MEMBER_DOCS = YES
-SORT_BRIEF_DOCS = YES
-SORT_BY_SCOPE_NAME = YES
-SORT_MEMBERS_CTORS_1ST = YES
-GENERATE_TODOLIST = NO
-GENERATE_TESTLIST = NO
-GENERATE_BUGLIST = YES
-GENERATE_DEPRECATEDLIST= NO
-ENABLED_SECTIONS =
-MAX_INITIALIZER_LINES = 30
-SHOW_USED_FILES = YES
-SHOW_DIRECTORIES = YES
-FILE_VERSION_FILTER =
-#---------------------------------------------------------------------------
-# configuration options related to warning and progress messages
-#---------------------------------------------------------------------------
-QUIET = NO
-WARNINGS = YES
-WARN_IF_UNDOCUMENTED = YES
-WARN_IF_DOC_ERROR = YES
-WARN_NO_PARAMDOC = YES
-WARN_FORMAT = "$file:$line: $text"
-WARN_LOGFILE =
-#---------------------------------------------------------------------------
-# configuration options related to the input files
-#---------------------------------------------------------------------------
-INPUT = ./src/mlpack \
- ./doc/guide \
- ./doc/tutorials
-FILE_PATTERNS = *.c \
- *.cc \
- *.h \
- *.hpp \
- *.cpp \
- *.txt
-RECURSIVE = YES
-EXCLUDE =
-EXCLUDE_SYMLINKS = YES
-EXCLUDE_PATTERNS = */build/* \
- */test/* \
- */arma_extend/* \
- */.svn/* \
- *_impl.cc \
- *_impl.h \
- *_impl.hpp \
- *.cpp \
- *.cc \
- *_test.cc
-EXAMPLE_PATH =
-EXAMPLE_PATTERNS = *
-EXAMPLE_RECURSIVE = NO
-IMAGE_PATH =
-INPUT_FILTER =
-FILTER_PATTERNS =
-FILTER_SOURCE_FILES = NO
-#---------------------------------------------------------------------------
-# configuration options related to source browsing
-#---------------------------------------------------------------------------
-SOURCE_BROWSER = YES
-INLINE_SOURCES = NO
-STRIP_CODE_COMMENTS = YES
-REFERENCED_BY_RELATION = YES
-REFERENCES_RELATION = YES
-REFERENCES_LINK_SOURCE = YES
-USE_HTAGS = NO
-VERBATIM_HEADERS = YES
-#---------------------------------------------------------------------------
-# configuration options related to the alphabetical class index
-#---------------------------------------------------------------------------
-ALPHABETICAL_INDEX = YES
-COLS_IN_ALPHA_INDEX = 1
-IGNORE_PREFIX =
-#---------------------------------------------------------------------------
-# configuration options related to the HTML output
-#---------------------------------------------------------------------------
-GENERATE_HTML = YES
-HTML_OUTPUT = html
-HTML_FILE_EXTENSION = .html
-HTML_HEADER =
-HTML_FOOTER =
-HTML_STYLESHEET =
-HTML_ALIGN_MEMBERS = YES
-GENERATE_HTMLHELP = NO
-CHM_FILE =
-HHC_LOCATION =
-GENERATE_CHI = NO
-BINARY_TOC = NO
-TOC_EXPAND = NO
-DISABLE_INDEX = NO
-ENUM_VALUES_PER_LINE = 1
-GENERATE_TREEVIEW = NO
-TREEVIEW_WIDTH = 250
-#---------------------------------------------------------------------------
-# configuration options related to the LaTeX output
-#---------------------------------------------------------------------------
-GENERATE_LATEX = YES
-LATEX_OUTPUT = latex
-LATEX_CMD_NAME = latex
-MAKEINDEX_CMD_NAME = makeindex
-COMPACT_LATEX = NO
-PAPER_TYPE = letter
-EXTRA_PACKAGES =
-LATEX_HEADER =
-PDF_HYPERLINKS = NO
-USE_PDFLATEX = NO
-LATEX_BATCHMODE = NO
-LATEX_HIDE_INDICES = NO
-#---------------------------------------------------------------------------
-# configuration options related to the RTF output
-#---------------------------------------------------------------------------
-GENERATE_RTF = NO
-RTF_OUTPUT = rtf
-COMPACT_RTF = NO
-RTF_HYPERLINKS = NO
-RTF_STYLESHEET_FILE =
-RTF_EXTENSIONS_FILE =
-#---------------------------------------------------------------------------
-# configuration options related to the man page output
-#---------------------------------------------------------------------------
-GENERATE_MAN = YES
-MAN_OUTPUT = man
-MAN_EXTENSION = .3
-MAN_LINKS = NO
-#---------------------------------------------------------------------------
-# configuration options related to the XML output
-#---------------------------------------------------------------------------
-GENERATE_XML = NO
-XML_OUTPUT = xml
-XML_SCHEMA =
-XML_DTD =
-XML_PROGRAMLISTING = YES
-#---------------------------------------------------------------------------
-# configuration options for the AutoGen Definitions output
-#---------------------------------------------------------------------------
-GENERATE_AUTOGEN_DEF = NO
-#---------------------------------------------------------------------------
-# configuration options related to the Perl module output
-#---------------------------------------------------------------------------
-GENERATE_PERLMOD = NO
-PERLMOD_LATEX = NO
-PERLMOD_PRETTY = YES
-PERLMOD_MAKEVAR_PREFIX =
-#---------------------------------------------------------------------------
-# Configuration options related to the preprocessor
-#---------------------------------------------------------------------------
-ENABLE_PREPROCESSING = YES
-MACRO_EXPANSION = YES
-EXPAND_ONLY_PREDEF = NO
-SEARCH_INCLUDES = YES
-INCLUDE_PATH =
-INCLUDE_FILE_PATTERNS =
-PREDEFINED =
-EXPAND_AS_DEFINED =
-SKIP_FUNCTION_MACROS = YES
-#---------------------------------------------------------------------------
-# Configuration::additions related to external references
-#---------------------------------------------------------------------------
-TAGFILES =
-GENERATE_TAGFILE =
-ALLEXTERNALS = NO
-EXTERNAL_GROUPS = YES
-PERL_PATH = /usr/bin/perl
-#---------------------------------------------------------------------------
-# Configuration options related to the dot tool
-#---------------------------------------------------------------------------
-CLASS_DIAGRAMS = YES
-HIDE_UNDOC_RELATIONS = YES
-HAVE_DOT = YES
-CLASS_GRAPH = YES
-COLLABORATION_GRAPH = YES
-GROUP_GRAPHS = YES
-UML_LOOK = NO
-TEMPLATE_RELATIONS = YES
-INCLUDE_GRAPH = YES
-INCLUDED_BY_GRAPH = YES
-CALL_GRAPH = NO
-CALLER_GRAPH = NO
-GRAPHICAL_HIERARCHY = YES
-DIRECTORY_GRAPH = YES
-DOT_IMAGE_FORMAT = png
-# Hack dark color support in through the dot path. Kind of cheating...
-DOT_PATH = dot -Gbgcolor=black
-DOTFILE_DIRS =
-MAX_DOT_GRAPH_WIDTH = 800
-MAX_DOT_GRAPH_HEIGHT = 600
-MAX_DOT_GRAPH_DEPTH = 1000
-DOT_TRANSPARENT = NO
-DOT_MULTI_TARGETS = NO
-GENERATE_LEGEND = YES
-DOT_CLEANUP = YES
-#---------------------------------------------------------------------------
-# Configuration::additions related to the search engine
-#---------------------------------------------------------------------------
-SEARCHENGINE = NO
Copied: mlpack/tags/mlpack-1.0.5/Doxyfile (from rev 14997, mlpack/branches/mlpack-1.x/Doxyfile)
===================================================================
--- mlpack/tags/mlpack-1.0.5/Doxyfile (rev 0)
+++ mlpack/tags/mlpack-1.0.5/Doxyfile 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,253 @@
+# Doxyfile 1.4.7
+
+#---------------------------------------------------------------------------
+# Project related configuration options
+#---------------------------------------------------------------------------
+PROJECT_NAME = MLPACK
+PROJECT_NUMBER = 1.0.5
+OUTPUT_DIRECTORY = ./doc
+CREATE_SUBDIRS = NO
+OUTPUT_LANGUAGE = English
+USE_WINDOWS_ENCODING = NO
+BRIEF_MEMBER_DESC = YES
+REPEAT_BRIEF = YES
+ABBREVIATE_BRIEF = "The $name class" \
+ "The $name widget" \
+ "The $name file" \
+ is \
+ provides \
+ specifies \
+ contains \
+ represents \
+ a \
+ an \
+ the
+ALWAYS_DETAILED_SEC = YES
+INLINE_INHERITED_MEMB = NO
+FULL_PATH_NAMES = YES
+STRIP_FROM_PATH = ./
+STRIP_FROM_INC_PATH =
+SHORT_NAMES = NO
+JAVADOC_AUTOBRIEF = YES
+MULTILINE_CPP_IS_BRIEF = NO
+DETAILS_AT_TOP = YES
+INHERIT_DOCS = YES
+SEPARATE_MEMBER_PAGES = NO
+TAB_SIZE = 2
+ALIASES =
+OPTIMIZE_OUTPUT_FOR_C = NO
+OPTIMIZE_OUTPUT_JAVA = NO
+BUILTIN_STL_SUPPORT = NO
+DISTRIBUTE_GROUP_DOC = NO
+SUBGROUPING = YES
+#---------------------------------------------------------------------------
+# Build related configuration options
+#---------------------------------------------------------------------------
+EXTRACT_ALL = YES
+EXTRACT_PRIVATE = YES
+EXTRACT_STATIC = YES
+EXTRACT_LOCAL_CLASSES = NO
+EXTRACT_LOCAL_METHODS = NO
+HIDE_UNDOC_MEMBERS = NO
+HIDE_UNDOC_CLASSES = NO
+HIDE_FRIEND_COMPOUNDS = NO
+HIDE_IN_BODY_DOCS = NO
+INTERNAL_DOCS = YES
+CASE_SENSE_NAMES = YES
+HIDE_SCOPE_NAMES = NO
+SHOW_INCLUDE_FILES = NO
+INLINE_INFO = YES
+SORT_MEMBER_DOCS = YES
+SORT_BRIEF_DOCS = YES
+SORT_BY_SCOPE_NAME = YES
+SORT_MEMBERS_CTORS_1ST = YES
+GENERATE_TODOLIST = NO
+GENERATE_TESTLIST = NO
+GENERATE_BUGLIST = YES
+GENERATE_DEPRECATEDLIST= NO
+ENABLED_SECTIONS =
+MAX_INITIALIZER_LINES = 30
+SHOW_USED_FILES = YES
+SHOW_DIRECTORIES = YES
+FILE_VERSION_FILTER =
+#---------------------------------------------------------------------------
+# configuration options related to warning and progress messages
+#---------------------------------------------------------------------------
+QUIET = NO
+WARNINGS = YES
+WARN_IF_UNDOCUMENTED = YES
+WARN_IF_DOC_ERROR = YES
+WARN_NO_PARAMDOC = YES
+WARN_FORMAT = "$file:$line: $text"
+WARN_LOGFILE =
+#---------------------------------------------------------------------------
+# configuration options related to the input files
+#---------------------------------------------------------------------------
+INPUT = ./src/mlpack \
+ ./doc/guide \
+ ./doc/tutorials
+FILE_PATTERNS = *.c \
+ *.cc \
+ *.h \
+ *.hpp \
+ *.cpp \
+ *.txt
+RECURSIVE = YES
+EXCLUDE =
+EXCLUDE_SYMLINKS = YES
+EXCLUDE_PATTERNS = */build/* \
+ */test/* \
+ */arma_extend/* \
+ */.svn/* \
+ *_impl.cc \
+ *_impl.h \
+ *_impl.hpp \
+ *.cpp \
+ *.cc \
+ *_test.cc
+EXAMPLE_PATH =
+EXAMPLE_PATTERNS = *
+EXAMPLE_RECURSIVE = NO
+IMAGE_PATH =
+INPUT_FILTER =
+FILTER_PATTERNS =
+FILTER_SOURCE_FILES = NO
+#---------------------------------------------------------------------------
+# configuration options related to source browsing
+#---------------------------------------------------------------------------
+SOURCE_BROWSER = YES
+INLINE_SOURCES = NO
+STRIP_CODE_COMMENTS = YES
+REFERENCED_BY_RELATION = YES
+REFERENCES_RELATION = YES
+REFERENCES_LINK_SOURCE = YES
+USE_HTAGS = NO
+VERBATIM_HEADERS = YES
+#---------------------------------------------------------------------------
+# configuration options related to the alphabetical class index
+#---------------------------------------------------------------------------
+ALPHABETICAL_INDEX = YES
+COLS_IN_ALPHA_INDEX = 1
+IGNORE_PREFIX =
+#---------------------------------------------------------------------------
+# configuration options related to the HTML output
+#---------------------------------------------------------------------------
+GENERATE_HTML = YES
+HTML_OUTPUT = html
+HTML_FILE_EXTENSION = .html
+HTML_HEADER =
+HTML_FOOTER =
+HTML_STYLESHEET =
+HTML_ALIGN_MEMBERS = YES
+GENERATE_HTMLHELP = NO
+CHM_FILE =
+HHC_LOCATION =
+GENERATE_CHI = NO
+BINARY_TOC = NO
+TOC_EXPAND = NO
+DISABLE_INDEX = NO
+ENUM_VALUES_PER_LINE = 1
+GENERATE_TREEVIEW = NO
+TREEVIEW_WIDTH = 250
+#---------------------------------------------------------------------------
+# configuration options related to the LaTeX output
+#---------------------------------------------------------------------------
+GENERATE_LATEX = YES
+LATEX_OUTPUT = latex
+LATEX_CMD_NAME = latex
+MAKEINDEX_CMD_NAME = makeindex
+COMPACT_LATEX = NO
+PAPER_TYPE = letter
+EXTRA_PACKAGES =
+LATEX_HEADER =
+PDF_HYPERLINKS = NO
+USE_PDFLATEX = NO
+LATEX_BATCHMODE = NO
+LATEX_HIDE_INDICES = NO
+#---------------------------------------------------------------------------
+# configuration options related to the RTF output
+#---------------------------------------------------------------------------
+GENERATE_RTF = NO
+RTF_OUTPUT = rtf
+COMPACT_RTF = NO
+RTF_HYPERLINKS = NO
+RTF_STYLESHEET_FILE =
+RTF_EXTENSIONS_FILE =
+#---------------------------------------------------------------------------
+# configuration options related to the man page output
+#---------------------------------------------------------------------------
+GENERATE_MAN = YES
+MAN_OUTPUT = man
+MAN_EXTENSION = .3
+MAN_LINKS = NO
+#---------------------------------------------------------------------------
+# configuration options related to the XML output
+#---------------------------------------------------------------------------
+GENERATE_XML = NO
+XML_OUTPUT = xml
+XML_SCHEMA =
+XML_DTD =
+XML_PROGRAMLISTING = YES
+#---------------------------------------------------------------------------
+# configuration options for the AutoGen Definitions output
+#---------------------------------------------------------------------------
+GENERATE_AUTOGEN_DEF = NO
+#---------------------------------------------------------------------------
+# configuration options related to the Perl module output
+#---------------------------------------------------------------------------
+GENERATE_PERLMOD = NO
+PERLMOD_LATEX = NO
+PERLMOD_PRETTY = YES
+PERLMOD_MAKEVAR_PREFIX =
+#---------------------------------------------------------------------------
+# Configuration options related to the preprocessor
+#---------------------------------------------------------------------------
+ENABLE_PREPROCESSING = YES
+MACRO_EXPANSION = YES
+EXPAND_ONLY_PREDEF = NO
+SEARCH_INCLUDES = YES
+INCLUDE_PATH =
+INCLUDE_FILE_PATTERNS =
+PREDEFINED =
+EXPAND_AS_DEFINED =
+SKIP_FUNCTION_MACROS = YES
+#---------------------------------------------------------------------------
+# Configuration::additions related to external references
+#---------------------------------------------------------------------------
+TAGFILES =
+GENERATE_TAGFILE =
+ALLEXTERNALS = NO
+EXTERNAL_GROUPS = YES
+PERL_PATH = /usr/bin/perl
+#---------------------------------------------------------------------------
+# Configuration options related to the dot tool
+#---------------------------------------------------------------------------
+CLASS_DIAGRAMS = YES
+HIDE_UNDOC_RELATIONS = YES
+HAVE_DOT = YES
+CLASS_GRAPH = YES
+COLLABORATION_GRAPH = YES
+GROUP_GRAPHS = YES
+UML_LOOK = NO
+TEMPLATE_RELATIONS = YES
+INCLUDE_GRAPH = YES
+INCLUDED_BY_GRAPH = YES
+CALL_GRAPH = NO
+CALLER_GRAPH = NO
+GRAPHICAL_HIERARCHY = YES
+DIRECTORY_GRAPH = YES
+DOT_IMAGE_FORMAT = png
+# Hack dark color support in through the dot path. Kind of cheating...
+DOT_PATH = dot -Gbgcolor=black
+DOTFILE_DIRS =
+MAX_DOT_GRAPH_WIDTH = 800
+MAX_DOT_GRAPH_HEIGHT = 600
+MAX_DOT_GRAPH_DEPTH = 1000
+DOT_TRANSPARENT = NO
+DOT_MULTI_TARGETS = NO
+GENERATE_LEGEND = YES
+DOT_CLEANUP = YES
+#---------------------------------------------------------------------------
+# Configuration::additions related to the search engine
+#---------------------------------------------------------------------------
+SEARCHENGINE = NO
Deleted: mlpack/tags/mlpack-1.0.5/HISTORY.txt
===================================================================
--- mlpack/branches/mlpack-1.x/HISTORY.txt 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/HISTORY.txt 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,89 +0,0 @@
-2013-05-01 mlpack 1.0.5
-
- * Speedups of cover tree traversers (#243).
-
- * Addition of rank-approximate nearest neighbors (RANN), found in
- src/mlpack/methods/rann/.
-
- * Addition of fast exact max-kernel search (FastMKS), found in
- src/mlpack/methods/fastmks/.
-
- * Fix for EM covariance estimation; this should improve GMM training time.
-
- * More parameters for GMM estimation.
-
- * Force GMM and GaussianDistribution covariance matrices to be positive
- definite, so that training converges much more often.
-
- * Add parameter for the tolerance of the Baum-Welch algorithm for HMM
- training.
-
- * Fix for compilation with clang compiler.
-
-2013-02-08 mlpack 1.0.4
-
- * Force minimum Armadillo version to 2.4.2.
-
- * Better output of class types to streams; a class with a ToString() method
- implemented can be sent to a stream with operator<<. See #164.
-
- * Change return type of GMM::Estimate() to double (#266).
-
- * Style fixes for k-means and RADICAL.
-
- * Handle size_t support correctly with Armadillo 3.6.2 (#267).
-
- * Add locality-sensitive hashing (LSH), found in src/mlpack/methods/lsh/.
-
- * Better tests for SGD (stochastic gradient descent) and NCA (neighborhood
- components analysis).
-
-2012-09-16 mlpack 1.0.3
-
- * Remove internal sparse matrix support because Armadillo 3.4.0 now includes
- it. When using Armadillo versions older than 3.4.0, sparse matrix support
- is not available.
-
- * NCA (neighborhood components analysis) now supports an arbitrary optimizer
- (#254), including stochastic gradient descent (#258).
-
-2012-08-15 mlpack 1.0.2
-
- * Added density estimation trees, found in src/mlpack/methods/det/.
-
- * Added non-negative matrix factorization, found in src/mlpack/methods/nmf/.
-
- * Added experimental cover tree implementation, found in
- src/mlpack/core/tree/cover_tree/ (#156).
-
- * Better reporting of boost::program_options errors (#231).
-
- * Fix for timers on Windows (#218, #217).
-
- * Fix for allknn and allkfn output (#210).
-
- * Sparse coding dictionary initialization is now a template parameter (#226).
-
-2012-03-03 mlpack 1.0.1
-
- * Added kernel principal components analysis (kernel PCA), found in
- src/mlpack/methods/kernel_pca/ (#47).
-
- * Fix for Lovasz-Theta AugLagrangian tests (#188).
-
- * Fixes for allknn output (#191, #192).
-
- * Added range search executable (#198).
-
- * Adapted citations in documentation to BiBTeX; no citations in -h output
- (#201).
-
- * Stop use of 'const char*' and prefer 'std::string' (#183).
-
- * Support seeds for random numbers (#182).
-
-2011-12-17 mlpack 1.0.0
-
- * Initial release. See any resolved tickets numbered less than #196 or
- execute this query:
- http://www.mlpack.org/trac/query?status=closed&milestone=mlpack+1.0.0
Copied: mlpack/tags/mlpack-1.0.5/HISTORY.txt (from rev 15000, mlpack/branches/mlpack-1.x/HISTORY.txt)
===================================================================
--- mlpack/tags/mlpack-1.0.5/HISTORY.txt (rev 0)
+++ mlpack/tags/mlpack-1.0.5/HISTORY.txt 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,91 @@
+2013-05-01 mlpack 1.0.5
+
+ * Speedups of cover tree traversers (#243).
+
+ * Addition of rank-approximate nearest neighbors (RANN), found in
+ src/mlpack/methods/rann/.
+
+ * Addition of fast exact max-kernel search (FastMKS), found in
+ src/mlpack/methods/fastmks/.
+
+ * Fix for EM covariance estimation; this should improve GMM training time.
+
+ * More parameters for GMM estimation.
+
+ * Force GMM and GaussianDistribution covariance matrices to be positive
+ definite, so that training converges much more often.
+
+ * Add parameter for the tolerance of the Baum-Welch algorithm for HMM
+ training.
+
+ * Fix for compilation with clang compiler.
+
+ * Fix for k-furthest-neighbor-search.
+
+2013-02-08 mlpack 1.0.4
+
+ * Force minimum Armadillo version to 2.4.2.
+
+ * Better output of class types to streams; a class with a ToString() method
+ implemented can be sent to a stream with operator<<. See #164.
+
+ * Change return type of GMM::Estimate() to double (#266).
+
+ * Style fixes for k-means and RADICAL.
+
+ * Handle size_t support correctly with Armadillo 3.6.2 (#267).
+
+ * Add locality-sensitive hashing (LSH), found in src/mlpack/methods/lsh/.
+
+ * Better tests for SGD (stochastic gradient descent) and NCA (neighborhood
+ components analysis).
+
+2012-09-16 mlpack 1.0.3
+
+ * Remove internal sparse matrix support because Armadillo 3.4.0 now includes
+ it. When using Armadillo versions older than 3.4.0, sparse matrix support
+ is not available.
+
+ * NCA (neighborhood components analysis) now supports an arbitrary optimizer
+ (#254), including stochastic gradient descent (#258).
+
+2012-08-15 mlpack 1.0.2
+
+ * Added density estimation trees, found in src/mlpack/methods/det/.
+
+ * Added non-negative matrix factorization, found in src/mlpack/methods/nmf/.
+
+ * Added experimental cover tree implementation, found in
+ src/mlpack/core/tree/cover_tree/ (#156).
+
+ * Better reporting of boost::program_options errors (#231).
+
+ * Fix for timers on Windows (#218, #217).
+
+ * Fix for allknn and allkfn output (#210).
+
+ * Sparse coding dictionary initialization is now a template parameter (#226).
+
+2012-03-03 mlpack 1.0.1
+
+ * Added kernel principal components analysis (kernel PCA), found in
+ src/mlpack/methods/kernel_pca/ (#47).
+
+ * Fix for Lovasz-Theta AugLagrangian tests (#188).
+
+ * Fixes for allknn output (#191, #192).
+
+ * Added range search executable (#198).
+
+ * Adapted citations in documentation to BiBTeX; no citations in -h output
+ (#201).
+
+ * Stop use of 'const char*' and prefer 'std::string' (#183).
+
+ * Support seeds for random numbers (#182).
+
+2011-12-17 mlpack 1.0.0
+
+ * Initial release. See any resolved tickets numbered less than #196 or
+ execute this query:
+ http://www.mlpack.org/trac/query?status=closed&milestone=mlpack+1.0.0
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/data/load.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,78 +0,0 @@
-/**
- * @file load.hpp
- * @author Ryan Curtin
- *
- * Load an Armadillo matrix from file. This is necessary because Armadillo does
- * not transpose matrices on input, and it allows us to give better error
- * output.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_DATA_LOAD_HPP
-#define __MLPACK_CORE_DATA_LOAD_HPP
-
-#include <mlpack/core/util/log.hpp>
-#include <mlpack/core/arma_extend/arma_extend.hpp> // Includes Armadillo.
-#include <string>
-
-namespace mlpack {
-namespace data /** Functions to load and save matrices. */ {
-
-/**
- * Loads a matrix from file, guessing the filetype from the extension. This
- * will transpose the matrix at load time. If the filetype cannot be
- * determined, an error will be given.
- *
- * The supported types of files are the same as found in Armadillo:
- *
- * - CSV (csv_ascii), denoted by .csv, or optionally .txt
- * - ASCII (raw_ascii), denoted by .txt
- * - Armadillo ASCII (arma_ascii), also denoted by .txt
- * - PGM (pgm_binary), denoted by .pgm
- * - PPM (ppm_binary), denoted by .ppm
- * - Raw binary (raw_binary), denoted by .bin
- * - Armadillo binary (arma_binary), denoted by .bin
- * - HDF5, denoted by .hdf, .hdf5, .h5, or .he5
- *
- * If the file extension is not one of those types, an error will be given.
- * This is preferable to Armadillo's default behavior of loading an unknown
- * filetype as raw_binary, which can have very confusing effects.
- *
- * If the parameter 'fatal' is set to true, the program will exit with an error
- * if the matrix does not load successfully. The parameter 'transpose' controls
- * whether or not the matrix is transposed after loading. In most cases,
- * because data is generally stored in a row-major format and MLPACK requires
- * column-major matrices, this should be left at its default value of 'true'.
- *
- * @param filename Name of file to load.
- * @param matrix Matrix to load contents of file into.
- * @param fatal If an error should be reported as fatal (default false).
- * @param transpose If true, transpose the matrix after loading.
- * @return Boolean value indicating success or failure of load.
- */
-template<typename eT>
-bool Load(const std::string& filename,
- arma::Mat<eT>& matrix,
- bool fatal = false,
- bool transpose = true);
-
-}; // namespace data
-}; // namespace mlpack
-
-// Include implementation.
-#include "load_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/data/load.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,78 @@
+/**
+ * @file load.hpp
+ * @author Ryan Curtin
+ *
+ * Load an Armadillo matrix from file. This is necessary because Armadillo does
+ * not transpose matrices on input, and it allows us to give better error
+ * output.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_DATA_LOAD_HPP
+#define __MLPACK_CORE_DATA_LOAD_HPP
+
+#include <mlpack/core/util/log.hpp>
+#include <mlpack/core/arma_extend/arma_extend.hpp> // Includes Armadillo.
+#include <string>
+
+namespace mlpack {
+namespace data /** Functions to load and save matrices. */ {
+
+/**
+ * Loads a matrix from file, guessing the filetype from the extension. This
+ * will transpose the matrix at load time. If the filetype cannot be
+ * determined, an error will be given.
+ *
+ * The supported types of files are the same as found in Armadillo:
+ *
+ * - CSV (csv_ascii), denoted by .csv, or optionally .txt
+ * - ASCII (raw_ascii), denoted by .txt
+ * - Armadillo ASCII (arma_ascii), also denoted by .txt
+ * - PGM (pgm_binary), denoted by .pgm
+ * - PPM (ppm_binary), denoted by .ppm
+ * - Raw binary (raw_binary), denoted by .bin
+ * - Armadillo binary (arma_binary), denoted by .bin
+ * - HDF5, denoted by .hdf, .hdf5, .h5, or .he5
+ *
+ * If the file extension is not one of those types, an error will be given.
+ * This is preferable to Armadillo's default behavior of loading an unknown
+ * filetype as raw_binary, which can have very confusing effects.
+ *
+ * If the parameter 'fatal' is set to true, the program will exit with an error
+ * if the matrix does not load successfully. The parameter 'transpose' controls
+ * whether or not the matrix is transposed after loading. In most cases,
+ * because data is generally stored in a row-major format and MLPACK requires
+ * column-major matrices, this should be left at its default value of 'true'.
+ *
+ * @param filename Name of file to load.
+ * @param matrix Matrix to load contents of file into.
+ * @param fatal If an error should be reported as fatal (default false).
+ * @param transpose If true, transpose the matrix after loading.
+ * @return Boolean value indicating success or failure of load.
+ */
+template<typename eT>
+bool Load(const std::string& filename,
+ arma::Mat<eT>& matrix,
+ bool fatal = false,
+ bool transpose = true);
+
+}; // namespace data
+}; // namespace mlpack
+
+// Include implementation.
+#include "load_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/data/load_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,230 +0,0 @@
-/**
- * @file load_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of templatized load() function defined in load.hpp.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_DATA_LOAD_IMPL_HPP
-#define __MLPACK_CORE_DATA_LOAD_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "load.hpp"
-
-#include <algorithm>
-#include <mlpack/core/util/timers.hpp>
-
-namespace mlpack {
-namespace data {
-
-template<typename eT>
-bool Load(const std::string& filename,
- arma::Mat<eT>& matrix,
- bool fatal,
- bool transpose)
-{
- Timer::Start("loading_data");
-
- // First we will try to discriminate by file extension.
- size_t ext = filename.rfind('.');
- if (ext == std::string::npos)
- {
- if (fatal)
- Log::Fatal << "Cannot determine type of file '" << filename << "'; "
- << "no extension is present." << std::endl;
- else
- Log::Warn << "Cannot determine type of file '" << filename << "'; "
- << "no extension is present. Load failed." << std::endl;
-
- Timer::Stop("loading_data");
- return false;
- }
-
- // Get the extension and force it to lowercase.
- std::string extension = filename.substr(ext + 1);
- std::transform(extension.begin(), extension.end(), extension.begin(),
- ::tolower);
-
- // Catch nonexistent files by opening the stream ourselves.
- std::fstream stream;
- stream.open(filename.c_str(), std::fstream::in);
-
- if (!stream.is_open())
- {
- if (fatal)
- Log::Fatal << "Cannot open file '" << filename << "'. " << std::endl;
- else
- Log::Warn << "Cannot open file '" << filename << "'; load failed."
- << std::endl;
-
- Timer::Stop("loading_data");
- return false;
- }
-
- bool unknownType = false;
- arma::file_type loadType;
- std::string stringType;
-
- if (extension == "csv")
- {
- loadType = arma::csv_ascii;
- stringType = "CSV data";
- }
- else if (extension == "txt")
- {
- // This could be raw ASCII or Armadillo ASCII (ASCII with size header).
- // We'll let Armadillo do its guessing (although we have to check if it is
- // arma_ascii ourselves) and see what we come up with.
-
- // This is taken from load_auto_detect() in diskio_meat.hpp
- const std::string ARMA_MAT_TXT = "ARMA_MAT_TXT";
- char* rawHeader = new char[ARMA_MAT_TXT.length() + 1];
- std::streampos pos = stream.tellg();
-
- stream.read(rawHeader, std::streamsize(ARMA_MAT_TXT.length()));
- rawHeader[ARMA_MAT_TXT.length()] = '\0';
- stream.clear();
- stream.seekg(pos); // Reset stream position after peeking.
-
- if (std::string(rawHeader) == ARMA_MAT_TXT)
- {
- loadType = arma::arma_ascii;
- stringType = "Armadillo ASCII formatted data";
- }
- else // It's not arma_ascii. Now we let Armadillo guess.
- {
- loadType = arma::diskio::guess_file_type(stream);
-
- if (loadType == arma::raw_ascii) // Raw ASCII (space-separated).
- stringType = "raw ASCII formatted data";
- else if (loadType == arma::csv_ascii) // CSV can be .txt too.
- stringType = "CSV data";
- else // Unknown .txt... we will throw an error.
- unknownType = true;
- }
-
- delete[] rawHeader;
- }
- else if (extension == "bin")
- {
- // This could be raw binary or Armadillo binary (binary with header). We
- // will check to see if it is Armadillo binary.
- const std::string ARMA_MAT_BIN = "ARMA_MAT_BIN";
- char *rawHeader = new char[ARMA_MAT_BIN.length() + 1];
-
- std::streampos pos = stream.tellg();
-
- stream.read(rawHeader, std::streamsize(ARMA_MAT_BIN.length()));
- rawHeader[ARMA_MAT_BIN.length()] = '\0';
- stream.clear();
- stream.seekg(pos); // Reset stream position after peeking.
-
- if (std::string(rawHeader) == ARMA_MAT_BIN)
- {
- stringType = "Armadillo binary formatted data";
- loadType = arma::arma_binary;
- }
- else // We can only assume it's raw binary.
- {
- stringType = "raw binary formatted data";
- loadType = arma::raw_binary;
- }
-
- delete[] rawHeader;
- }
- else if (extension == "pgm")
- {
- loadType = arma::pgm_binary;
- stringType = "PGM data";
- }
- else if (extension == "h5" || extension == "hdf5" || extension == "hdf" ||
- extension == "he5")
- {
-#ifdef ARMA_USE_HDF5
- loadType = arma::hdf5_binary;
- stringType = "HDF5 data";
-#else
- if (fatal)
- Log::Fatal << "Attempted to load '" << filename << "' as HDF5 data, but "
- << "Armadillo was compiled without HDF5 support. Load failed."
- << std::endl;
- else
- Log::Warn << "Attempted to load '" << filename << "' as HDF5 data, but "
- << "Armadillo was compiled without HDF5 support. Load failed."
- << std::endl;
-
- Timer::Stop("loading_data");
- return false;
-#endif
- }
- else // Unknown extension...
- {
- unknownType = true;
- loadType = arma::raw_binary; // Won't be used; prevent a warning.
- stringType = "";
- }
-
- // Provide error if we don't know the type.
- if (unknownType)
- {
- if (fatal)
- Log::Fatal << "Unable to detect type of '" << filename << "'; "
- << "incorrect extension?" << std::endl;
- else
- Log::Warn << "Unable to detect type of '" << filename << "'; load failed."
- << " Incorrect extension?" << std::endl;
-
- Timer::Stop("loading_data");
- return false;
- }
-
- // Try to load the file; but if it's raw_binary, it could be a problem.
- if (loadType == arma::raw_binary)
- Log::Warn << "Loading '" << filename << "' as " << stringType << "; "
- << "but this may not be the actual filetype!" << std::endl;
- else
- Log::Info << "Loading '" << filename << "' as " << stringType << ". "
- << std::flush;
-
- const bool success = matrix.load(stream, loadType);
-
- if (!success)
- {
- Log::Info << std::endl;
- if (fatal)
- Log::Fatal << "Loading from '" << filename << "' failed." << std::endl;
- else
- Log::Warn << "Loading from '" << filename << "' failed." << std::endl;
- }
- else
- Log::Info << "Size is " << (transpose ? matrix.n_cols : matrix.n_rows)
- << " x " << (transpose ? matrix.n_rows : matrix.n_cols) << ".\n";
-
- // Now transpose the matrix, if necessary.
- if (transpose)
- matrix = trans(matrix);
-
- Timer::Stop("loading_data");
-
- // Finally, return the success indicator.
- return success;
-}
-
-}; // namespace data
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/data/load_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/load_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,230 @@
+/**
+ * @file load_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of templatized load() function defined in load.hpp.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_DATA_LOAD_IMPL_HPP
+#define __MLPACK_CORE_DATA_LOAD_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "load.hpp"
+
+#include <algorithm>
+#include <mlpack/core/util/timers.hpp>
+
+namespace mlpack {
+namespace data {
+
+template<typename eT>
+bool Load(const std::string& filename,
+ arma::Mat<eT>& matrix,
+ bool fatal,
+ bool transpose)
+{
+ Timer::Start("loading_data");
+
+ // First we will try to discriminate by file extension.
+ size_t ext = filename.rfind('.');
+ if (ext == std::string::npos)
+ {
+ if (fatal)
+ Log::Fatal << "Cannot determine type of file '" << filename << "'; "
+ << "no extension is present." << std::endl;
+ else
+ Log::Warn << "Cannot determine type of file '" << filename << "'; "
+ << "no extension is present. Load failed." << std::endl;
+
+ Timer::Stop("loading_data");
+ return false;
+ }
+
+ // Get the extension and force it to lowercase.
+ std::string extension = filename.substr(ext + 1);
+ std::transform(extension.begin(), extension.end(), extension.begin(),
+ ::tolower);
+
+ // Catch nonexistent files by opening the stream ourselves.
+ std::fstream stream;
+ stream.open(filename.c_str(), std::fstream::in);
+
+ if (!stream.is_open())
+ {
+ if (fatal)
+ Log::Fatal << "Cannot open file '" << filename << "'. " << std::endl;
+ else
+ Log::Warn << "Cannot open file '" << filename << "'; load failed."
+ << std::endl;
+
+ Timer::Stop("loading_data");
+ return false;
+ }
+
+ bool unknownType = false;
+ arma::file_type loadType;
+ std::string stringType;
+
+ if (extension == "csv")
+ {
+ loadType = arma::csv_ascii;
+ stringType = "CSV data";
+ }
+ else if (extension == "txt")
+ {
+ // This could be raw ASCII or Armadillo ASCII (ASCII with size header).
+ // We'll let Armadillo do its guessing (although we have to check if it is
+ // arma_ascii ourselves) and see what we come up with.
+
+ // This is taken from load_auto_detect() in diskio_meat.hpp
+ const std::string ARMA_MAT_TXT = "ARMA_MAT_TXT";
+ char* rawHeader = new char[ARMA_MAT_TXT.length() + 1];
+ std::streampos pos = stream.tellg();
+
+ stream.read(rawHeader, std::streamsize(ARMA_MAT_TXT.length()));
+ rawHeader[ARMA_MAT_TXT.length()] = '\0';
+ stream.clear();
+ stream.seekg(pos); // Reset stream position after peeking.
+
+ if (std::string(rawHeader) == ARMA_MAT_TXT)
+ {
+ loadType = arma::arma_ascii;
+ stringType = "Armadillo ASCII formatted data";
+ }
+ else // It's not arma_ascii. Now we let Armadillo guess.
+ {
+ loadType = arma::diskio::guess_file_type(stream);
+
+ if (loadType == arma::raw_ascii) // Raw ASCII (space-separated).
+ stringType = "raw ASCII formatted data";
+ else if (loadType == arma::csv_ascii) // CSV can be .txt too.
+ stringType = "CSV data";
+ else // Unknown .txt... we will throw an error.
+ unknownType = true;
+ }
+
+ delete[] rawHeader;
+ }
+ else if (extension == "bin")
+ {
+ // This could be raw binary or Armadillo binary (binary with header). We
+ // will check to see if it is Armadillo binary.
+ const std::string ARMA_MAT_BIN = "ARMA_MAT_BIN";
+ char *rawHeader = new char[ARMA_MAT_BIN.length() + 1];
+
+ std::streampos pos = stream.tellg();
+
+ stream.read(rawHeader, std::streamsize(ARMA_MAT_BIN.length()));
+ rawHeader[ARMA_MAT_BIN.length()] = '\0';
+ stream.clear();
+ stream.seekg(pos); // Reset stream position after peeking.
+
+ if (std::string(rawHeader) == ARMA_MAT_BIN)
+ {
+ stringType = "Armadillo binary formatted data";
+ loadType = arma::arma_binary;
+ }
+ else // We can only assume it's raw binary.
+ {
+ stringType = "raw binary formatted data";
+ loadType = arma::raw_binary;
+ }
+
+ delete[] rawHeader;
+ }
+ else if (extension == "pgm")
+ {
+ loadType = arma::pgm_binary;
+ stringType = "PGM data";
+ }
+ else if (extension == "h5" || extension == "hdf5" || extension == "hdf" ||
+ extension == "he5")
+ {
+#ifdef ARMA_USE_HDF5
+ loadType = arma::hdf5_binary;
+ stringType = "HDF5 data";
+#else
+ if (fatal)
+ Log::Fatal << "Attempted to load '" << filename << "' as HDF5 data, but "
+ << "Armadillo was compiled without HDF5 support. Load failed."
+ << std::endl;
+ else
+ Log::Warn << "Attempted to load '" << filename << "' as HDF5 data, but "
+ << "Armadillo was compiled without HDF5 support. Load failed."
+ << std::endl;
+
+ Timer::Stop("loading_data");
+ return false;
+#endif
+ }
+ else // Unknown extension...
+ {
+ unknownType = true;
+ loadType = arma::raw_binary; // Won't be used; prevent a warning.
+ stringType = "";
+ }
+
+ // Provide error if we don't know the type.
+ if (unknownType)
+ {
+ if (fatal)
+ Log::Fatal << "Unable to detect type of '" << filename << "'; "
+ << "incorrect extension?" << std::endl;
+ else
+ Log::Warn << "Unable to detect type of '" << filename << "'; load failed."
+ << " Incorrect extension?" << std::endl;
+
+ Timer::Stop("loading_data");
+ return false;
+ }
+
+ // Try to load the file; but if it's raw_binary, it could be a problem.
+ if (loadType == arma::raw_binary)
+ Log::Warn << "Loading '" << filename << "' as " << stringType << "; "
+ << "but this may not be the actual filetype!" << std::endl;
+ else
+ Log::Info << "Loading '" << filename << "' as " << stringType << ". "
+ << std::flush;
+
+ const bool success = matrix.load(stream, loadType);
+
+ if (!success)
+ {
+ Log::Info << std::endl;
+ if (fatal)
+ Log::Fatal << "Loading from '" << filename << "' failed." << std::endl;
+ else
+ Log::Warn << "Loading from '" << filename << "' failed." << std::endl;
+ }
+ else
+ Log::Info << "Size is " << (transpose ? matrix.n_cols : matrix.n_rows)
+ << " x " << (transpose ? matrix.n_rows : matrix.n_cols) << ".\n";
+
+ // Now transpose the matrix, if necessary.
+ if (transpose)
+ matrix = trans(matrix);
+
+ Timer::Stop("loading_data");
+
+ // Finally, return the success indicator.
+ return success;
+}
+
+}; // namespace data
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/data/save.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,75 +0,0 @@
-/**
- * @file save.hpp
- * @author Ryan Curtin
- *
- * Save an Armadillo matrix to file. This is necessary because Armadillo does
- * not transpose matrices upon saving, and it allows us to give better error
- * output.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_DATA_SAVE_HPP
-#define __MLPACK_CORE_DATA_SAVE_HPP
-
-#include <mlpack/core/util/log.hpp>
-#include <mlpack/core/arma_extend/arma_extend.hpp> // Includes Armadillo.
-#include <string>
-
-namespace mlpack {
-namespace data /** Functions to load and save matrices. */ {
-
-/**
- * Saves a matrix to file, guessing the filetype from the extension. This
- * will transpose the matrix at save time. If the filetype cannot be
- * determined, an error will be given.
- *
- * The supported types of files are the same as found in Armadillo:
- *
- * - CSV (csv_ascii), denoted by .csv, or optionally .txt
- * - ASCII (raw_ascii), denoted by .txt
- * - Armadillo ASCII (arma_ascii), also denoted by .txt
- * - PGM (pgm_binary), denoted by .pgm
- * - PPM (ppm_binary), denoted by .ppm
- * - Raw binary (raw_binary), denoted by .bin
- * - Armadillo binary (arma_binary), denoted by .bin
- * - HDF5 (hdf5_binary), denoted by .hdf5, .hdf, .h5, or .he5
- *
- * If the file extension is not one of those types, an error will be given. If
- * the 'fatal' parameter is set to true, an error will cause the program to
- * exit. If the 'transpose' parameter is set to true, the matrix will be
- * transposed before saving. Generally, because MLPACK stores matrices in a
- * column-major format and most datasets are stored on disk as row-major, this
- * parameter should be left at its default value of 'true'.
- *
- * @param filename Name of file to save to.
- * @param matrix Matrix to save into file.
- * @param fatal If an error should be reported as fatal (default false).
- * @param transpose If true, transpose the matrix before saving.
- * @return Boolean value indicating success or failure of save.
- */
-template<typename eT>
-bool Save(const std::string& filename,
- const arma::Mat<eT>& matrix,
- bool fatal = false,
- bool transpose = true);
-
-}; // namespace data
-}; // namespace mlpack
-
-// Include implementation.
-#include "save_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/data/save.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,75 @@
+/**
+ * @file save.hpp
+ * @author Ryan Curtin
+ *
+ * Save an Armadillo matrix to file. This is necessary because Armadillo does
+ * not transpose matrices upon saving, and it allows us to give better error
+ * output.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_DATA_SAVE_HPP
+#define __MLPACK_CORE_DATA_SAVE_HPP
+
+#include <mlpack/core/util/log.hpp>
+#include <mlpack/core/arma_extend/arma_extend.hpp> // Includes Armadillo.
+#include <string>
+
+namespace mlpack {
+namespace data /** Functions to load and save matrices. */ {
+
+/**
+ * Saves a matrix to file, guessing the filetype from the extension. This
+ * will transpose the matrix at save time. If the filetype cannot be
+ * determined, an error will be given.
+ *
+ * The supported types of files are the same as found in Armadillo:
+ *
+ * - CSV (csv_ascii), denoted by .csv, or optionally .txt
+ * - ASCII (raw_ascii), denoted by .txt
+ * - Armadillo ASCII (arma_ascii), also denoted by .txt
+ * - PGM (pgm_binary), denoted by .pgm
+ * - PPM (ppm_binary), denoted by .ppm
+ * - Raw binary (raw_binary), denoted by .bin
+ * - Armadillo binary (arma_binary), denoted by .bin
+ * - HDF5 (hdf5_binary), denoted by .hdf5, .hdf, .h5, or .he5
+ *
+ * If the file extension is not one of those types, an error will be given. If
+ * the 'fatal' parameter is set to true, an error will cause the program to
+ * exit. If the 'transpose' parameter is set to true, the matrix will be
+ * transposed before saving. Generally, because MLPACK stores matrices in a
+ * column-major format and most datasets are stored on disk as row-major, this
+ * parameter should be left at its default value of 'true'.
+ *
+ * @param filename Name of file to save to.
+ * @param matrix Matrix to save into file.
+ * @param fatal If an error should be reported as fatal (default false).
+ * @param transpose If true, transpose the matrix before saving.
+ * @return Boolean value indicating success or failure of save.
+ */
+template<typename eT>
+bool Save(const std::string& filename,
+ const arma::Mat<eT>& matrix,
+ bool fatal = false,
+ bool transpose = true);
+
+}; // namespace data
+}; // namespace mlpack
+
+// Include implementation.
+#include "save_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/data/save_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,178 +0,0 @@
-/**
- * @file save_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of save functionality.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_DATA_SAVE_IMPL_HPP
-#define __MLPACK_CORE_DATA_SAVE_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "save.hpp"
-
-namespace mlpack {
-namespace data {
-
-template<typename eT>
-bool Save(const std::string& filename,
- const arma::Mat<eT>& matrix,
- bool fatal,
- bool transpose)
-{
- Timer::Start("saving_data");
-
- // First we will try to discriminate by file extension.
- size_t ext = filename.rfind('.');
- if (ext == std::string::npos)
- {
- if (fatal)
- Log::Fatal << "No extension given with filename '" << filename << "'; "
- << "type unknown. Save failed." << std::endl;
- else
- Log::Warn << "No extension given with filename '" << filename << "'; "
- << "type unknown. Save failed." << std::endl;
-
- return false;
- }
-
- // Get the actual extension.
- std::string extension = filename.substr(ext + 1);
-
- // Catch errors opening the file.
- std::fstream stream;
- stream.open(filename.c_str(), std::fstream::out);
-
- if (!stream.is_open())
- {
- if (fatal)
- Log::Fatal << "Cannot open file '" << filename << "' for writing. "
- << "Save failed." << std::endl;
- else
- Log::Warn << "Cannot open file '" << filename << "' for writing; save "
- << "failed." << std::endl;
-
- Timer::Stop("saving_data");
- return false;
- }
-
- bool unknownType = false;
- arma::file_type saveType;
- std::string stringType;
-
- if (extension == "csv")
- {
- saveType = arma::csv_ascii;
- stringType = "CSV data";
- }
- else if (extension == "txt")
- {
- saveType = arma::raw_ascii;
- stringType = "raw ASCII formatted data";
- }
- else if (extension == "bin")
- {
- saveType = arma::arma_binary;
- stringType = "Armadillo binary formatted data";
- }
- else if (extension == "pgm")
- {
- saveType = arma::pgm_binary;
- stringType = "PGM data";
- }
- else if (extension == "h5" || extension == "hdf5" || extension == "hdf" ||
- extension == "he5")
- {
-#ifdef ARMA_USE_HDF5
- saveType = arma::hdf5_binary;
- stringType = "HDF5 data";
-#else
- if (fatal)
- Log::Fatal << "Attempted to save HDF5 data to '" << filename << "', but "
- << "Armadillo was compiled without HDF5 support. Save failed."
- << std::endl;
- else
- Log::Warn << "Attempted to save HDF5 data to '" << filename << "', but "
- << "Armadillo was compiled without HDF5 support. Save failed."
- << std::endl;
-
- Timer::Stop("saving_data");
- return false;
-#endif
- }
- else
- {
- unknownType = true;
- saveType = arma::raw_binary; // Won't be used; prevent a warning.
- stringType = "";
- }
-
- // Provide error if we don't know the type.
- if (unknownType)
- {
- if (fatal)
- Log::Fatal << "Unable to determine format to save to from filename '"
- << filename << "'. Save failed." << std::endl;
- else
- Log::Warn << "Unable to determine format to save to from filename '"
- << filename << "'. Save failed." << std::endl;
- }
-
- // Try to save the file.
- Log::Info << "Saving " << stringType << " to '" << filename << "'."
- << std::endl;
-
- // Transpose the matrix.
- if (transpose)
- {
- arma::Mat<eT> tmp = trans(matrix);
-
- if (!tmp.quiet_save(stream, saveType))
- {
- if (fatal)
- Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
- else
- Log::Warn << "Save to '" << filename << "' failed." << std::endl;
-
- Timer::Stop("saving_data");
- return false;
- }
- }
- else
- {
- if (!matrix.quiet_save(stream, saveType))
- {
- if (fatal)
- Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
- else
- Log::Warn << "Save to '" << filename << "' failed." << std::endl;
-
- Timer::Stop("saving_data");
- return false;
- }
- }
-
- Timer::Stop("saving_data");
-
- // Finally return success.
- return true;
-}
-
-}; // namespace data
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/data/save_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/data/save_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,178 @@
+/**
+ * @file save_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of save functionality.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_DATA_SAVE_IMPL_HPP
+#define __MLPACK_CORE_DATA_SAVE_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "save.hpp"
+
+namespace mlpack {
+namespace data {
+
+template<typename eT>
+bool Save(const std::string& filename,
+ const arma::Mat<eT>& matrix,
+ bool fatal,
+ bool transpose)
+{
+ Timer::Start("saving_data");
+
+ // First we will try to discriminate by file extension.
+ size_t ext = filename.rfind('.');
+ if (ext == std::string::npos)
+ {
+ if (fatal)
+ Log::Fatal << "No extension given with filename '" << filename << "'; "
+ << "type unknown. Save failed." << std::endl;
+ else
+ Log::Warn << "No extension given with filename '" << filename << "'; "
+ << "type unknown. Save failed." << std::endl;
+
+ return false;
+ }
+
+ // Get the actual extension.
+ std::string extension = filename.substr(ext + 1);
+
+ // Catch errors opening the file.
+ std::fstream stream;
+ stream.open(filename.c_str(), std::fstream::out);
+
+ if (!stream.is_open())
+ {
+ if (fatal)
+ Log::Fatal << "Cannot open file '" << filename << "' for writing. "
+ << "Save failed." << std::endl;
+ else
+ Log::Warn << "Cannot open file '" << filename << "' for writing; save "
+ << "failed." << std::endl;
+
+ Timer::Stop("saving_data");
+ return false;
+ }
+
+ bool unknownType = false;
+ arma::file_type saveType;
+ std::string stringType;
+
+ if (extension == "csv")
+ {
+ saveType = arma::csv_ascii;
+ stringType = "CSV data";
+ }
+ else if (extension == "txt")
+ {
+ saveType = arma::raw_ascii;
+ stringType = "raw ASCII formatted data";
+ }
+ else if (extension == "bin")
+ {
+ saveType = arma::arma_binary;
+ stringType = "Armadillo binary formatted data";
+ }
+ else if (extension == "pgm")
+ {
+ saveType = arma::pgm_binary;
+ stringType = "PGM data";
+ }
+ else if (extension == "h5" || extension == "hdf5" || extension == "hdf" ||
+ extension == "he5")
+ {
+#ifdef ARMA_USE_HDF5
+ saveType = arma::hdf5_binary;
+ stringType = "HDF5 data";
+#else
+ if (fatal)
+ Log::Fatal << "Attempted to save HDF5 data to '" << filename << "', but "
+ << "Armadillo was compiled without HDF5 support. Save failed."
+ << std::endl;
+ else
+ Log::Warn << "Attempted to save HDF5 data to '" << filename << "', but "
+ << "Armadillo was compiled without HDF5 support. Save failed."
+ << std::endl;
+
+ Timer::Stop("saving_data");
+ return false;
+#endif
+ }
+ else
+ {
+ unknownType = true;
+ saveType = arma::raw_binary; // Won't be used; prevent a warning.
+ stringType = "";
+ }
+
+ // Provide error if we don't know the type.
+ if (unknownType)
+ {
+ if (fatal)
+ Log::Fatal << "Unable to determine format to save to from filename '"
+ << filename << "'. Save failed." << std::endl;
+ else
+ Log::Warn << "Unable to determine format to save to from filename '"
+ << filename << "'. Save failed." << std::endl;
+ }
+
+ // Try to save the file.
+ Log::Info << "Saving " << stringType << " to '" << filename << "'."
+ << std::endl;
+
+ // Transpose the matrix.
+ if (transpose)
+ {
+ arma::Mat<eT> tmp = trans(matrix);
+
+ if (!tmp.quiet_save(stream, saveType))
+ {
+ if (fatal)
+ Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
+ else
+ Log::Warn << "Save to '" << filename << "' failed." << std::endl;
+
+ Timer::Stop("saving_data");
+ return false;
+ }
+ }
+ else
+ {
+ if (!matrix.quiet_save(stream, saveType))
+ {
+ if (fatal)
+ Log::Fatal << "Save to '" << filename << "' failed." << std::endl;
+ else
+ Log::Warn << "Save to '" << filename << "' failed." << std::endl;
+
+ Timer::Stop("saving_data");
+ return false;
+ }
+ }
+
+ Timer::Stop("saving_data");
+
+ // Finally return success.
+ return true;
+}
+
+}; // namespace data
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/dists/discrete_distribution.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,132 +0,0 @@
-/**
- * @file discrete_distribution.cpp
- * @author Ryan Curtin
- *
- * Implementation of DiscreteDistribution probability distribution.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "discrete_distribution.hpp"
-
-using namespace mlpack;
-using namespace mlpack::distribution;
-
-/**
- * Return a randomly generated observation according to the probability
- * distribution defined by this object.
- */
-arma::vec DiscreteDistribution::Random() const
-{
- // Generate a random number.
- double randObs = math::Random();
- arma::vec result(1);
-
- double sumProb = 0;
- for (size_t obs = 0; obs < probabilities.n_elem; obs++)
- {
- if ((sumProb += probabilities[obs]) >= randObs)
- {
- result[0] = obs;
- return result;
- }
- }
-
- // This shouldn't happen.
- result[0] = probabilities.n_elem - 1;
- return result;
-}
-
-/**
- * Estimate the probability distribution directly from the given observations.
- */
-void DiscreteDistribution::Estimate(const arma::mat& observations)
-{
- // Clear old probabilities.
- probabilities.zeros();
-
- // Add the probability of each observation. The addition of 0.5 to the
- // observation is to turn the default flooring operation of the size_t cast
- // into a rounding operation.
- for (size_t i = 0; i < observations.n_cols; i++)
- {
- const size_t obs = size_t(observations(0, i) + 0.5);
-
- // Ensure that the observation is within the bounds.
- if (obs >= probabilities.n_elem)
- {
- Log::Debug << "DiscreteDistribution::Estimate(): observation " << i
- << " (" << obs << ") is invalid; observation must be in [0, "
- << probabilities.n_elem << "] for this distribution." << std::endl;
- }
-
- probabilities(obs)++;
- }
-
- // Now normalize the distribution.
- double sum = accu(probabilities);
- if (sum > 0)
- probabilities /= sum;
- else
- probabilities.fill(1 / probabilities.n_elem); // Force normalization.
-}
-
-/**
- * Estimate the probability distribution from the given observations when also
- * given probabilities that each observation is from this distribution.
- */
-void DiscreteDistribution::Estimate(const arma::mat& observations,
- const arma::vec& probObs)
-{
- // Clear old probabilities.
- probabilities.zeros();
-
- // Add the probability of each observation. The addition of 0.5 to the
- // observation is to turn the default flooring operation of the size_t cast
- // into a rounding observation.
- for (size_t i = 0; i < observations.n_cols; i++)
- {
- const size_t obs = size_t(observations(0, i) + 0.5);
-
- // Ensure that the observation is within the bounds.
- if (obs >= probabilities.n_elem)
- {
- Log::Debug << "DiscreteDistribution::Estimate(): observation " << i
- << " (" << obs << ") is invalid; observation must be in [0, "
- << probabilities.n_elem << "] for this distribution." << std::endl;
- }
-
- probabilities(obs) += probObs[i];
- }
-
- // Now normalize the distribution.
- double sum = accu(probabilities);
- if (sum > 0)
- probabilities /= sum;
- else
- probabilities.fill(1 / probabilities.n_elem); // Force normalization.
-}
-
-/*
- * Returns a string representation of this object.
- */
-std::string DiscreteDistribution::ToString() const
-{
- std::ostringstream convert;
- convert << "DiscreteDistribution [" << this << "]" << std::endl;
- convert << "Probabilities" << std::endl << probabilities;
- return convert.str();
-}
-
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/dists/discrete_distribution.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,132 @@
+/**
+ * @file discrete_distribution.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of DiscreteDistribution probability distribution.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "discrete_distribution.hpp"
+
+using namespace mlpack;
+using namespace mlpack::distribution;
+
+/**
+ * Return a randomly generated observation according to the probability
+ * distribution defined by this object.
+ */
+arma::vec DiscreteDistribution::Random() const
+{
+ // Generate a random number.
+ double randObs = math::Random();
+ arma::vec result(1);
+
+ double sumProb = 0;
+ for (size_t obs = 0; obs < probabilities.n_elem; obs++)
+ {
+ if ((sumProb += probabilities[obs]) >= randObs)
+ {
+ result[0] = obs;
+ return result;
+ }
+ }
+
+ // This shouldn't happen.
+ result[0] = probabilities.n_elem - 1;
+ return result;
+}
+
+/**
+ * Estimate the probability distribution directly from the given observations.
+ */
+void DiscreteDistribution::Estimate(const arma::mat& observations)
+{
+ // Clear old probabilities.
+ probabilities.zeros();
+
+ // Add the probability of each observation. The addition of 0.5 to the
+ // observation is to turn the default flooring operation of the size_t cast
+ // into a rounding operation.
+ for (size_t i = 0; i < observations.n_cols; i++)
+ {
+ const size_t obs = size_t(observations(0, i) + 0.5);
+
+ // Ensure that the observation is within the bounds.
+ if (obs >= probabilities.n_elem)
+ {
+ Log::Debug << "DiscreteDistribution::Estimate(): observation " << i
+ << " (" << obs << ") is invalid; observation must be in [0, "
+ << probabilities.n_elem << "] for this distribution." << std::endl;
+ }
+
+ probabilities(obs)++;
+ }
+
+ // Now normalize the distribution.
+ double sum = accu(probabilities);
+ if (sum > 0)
+ probabilities /= sum;
+ else
+ probabilities.fill(1 / probabilities.n_elem); // Force normalization.
+}
+
+/**
+ * Estimate the probability distribution from the given observations when also
+ * given probabilities that each observation is from this distribution.
+ */
+void DiscreteDistribution::Estimate(const arma::mat& observations,
+ const arma::vec& probObs)
+{
+ // Clear old probabilities.
+ probabilities.zeros();
+
+ // Add the probability of each observation. The addition of 0.5 to the
+ // observation is to turn the default flooring operation of the size_t cast
+ // into a rounding observation.
+ for (size_t i = 0; i < observations.n_cols; i++)
+ {
+ const size_t obs = size_t(observations(0, i) + 0.5);
+
+ // Ensure that the observation is within the bounds.
+ if (obs >= probabilities.n_elem)
+ {
+ Log::Debug << "DiscreteDistribution::Estimate(): observation " << i
+ << " (" << obs << ") is invalid; observation must be in [0, "
+ << probabilities.n_elem << "] for this distribution." << std::endl;
+ }
+
+ probabilities(obs) += probObs[i];
+ }
+
+ // Now normalize the distribution.
+ double sum = accu(probabilities);
+ if (sum > 0)
+ probabilities /= sum;
+ else
+ probabilities.fill(1 / probabilities.n_elem); // Force normalization.
+}
+
+/*
+ * Returns a string representation of this object.
+ */
+std::string DiscreteDistribution::ToString() const
+{
+ std::ostringstream convert;
+ convert << "DiscreteDistribution [" << this << "]" << std::endl;
+ convert << "Probabilities" << std::endl << probabilities;
+ return convert.str();
+}
+
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/dists/discrete_distribution.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,169 +0,0 @@
-/**
- * @file discrete_distribution.hpp
- * @author Ryan Curtin
- *
- * Implementation of the discrete distribution, where each discrete observation
- * has a given probability.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_HMM_DISTRIBUTIONS_DISCRETE_DISTRIBUTION_HPP
-#define __MLPACK_METHODS_HMM_DISTRIBUTIONS_DISCRETE_DISTRIBUTION_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace distribution /** Probability distributions. */ {
-
-/**
- * A discrete distribution where the only observations are discrete
- * observations. This is useful (for example) with discrete Hidden Markov
- * Models, where observations are non-negative integers representing specific
- * emissions.
- *
- * No bounds checking is performed for observations, so if an invalid
- * observation is passed (i.e. observation > numObservations), a crash will
- * probably occur.
- *
- * This distribution only supports one-dimensional observations, so when passing
- * an arma::vec as an observation, it should only have one dimension
- * (vec.n_rows == 1). Any additional dimensions will simply be ignored.
- *
- * @note
- * This class, like every other class in MLPACK, uses arma::vec to represent
- * observations. While a discrete distribution only has positive integers
- * (size_t) as observations, these can be converted to doubles (which is what
- * arma::vec holds). This distribution internally converts those doubles back
- * into size_t before comparisons.
- * @endnote
- */
-class DiscreteDistribution
-{
- public:
- /**
- * Default constructor, which creates a distribution that has no observations.
- */
- DiscreteDistribution() { /* nothing to do */ }
-
- /**
- * Define the discrete distribution as having numObservations possible
- * observations. The probability in each state will be set to (1 /
- * numObservations).
- *
- * @param numObservations Number of possible observations this distribution
- * can have.
- */
- DiscreteDistribution(const size_t numObservations) :
- probabilities(arma::ones<arma::vec>(numObservations) / numObservations)
- { /* nothing to do */ }
-
- /**
- * Define the discrete distribution as having the given probabilities for each
- * observation.
- *
- * @param probabilities Probabilities of each possible observation.
- */
- DiscreteDistribution(const arma::vec& probabilities)
- {
- // We must be sure that our distribution is normalized.
- double sum = accu(probabilities);
- if (sum > 0)
- this->probabilities = probabilities / sum;
- else
- {
- this->probabilities.set_size(probabilities.n_elem);
- this->probabilities.fill(1 / probabilities.n_elem);
- }
- }
-
- /**
- * Get the dimensionality of the distribution.
- */
- size_t Dimensionality() const { return 1; }
-
- /**
- * Return the probability of the given observation. If the observation is
- * greater than the number of possible observations, then a crash will
- * probably occur -- bounds checking is not performed.
- *
- * @param observation Observation to return the probability of.
- * @return Probability of the given observation.
- */
- double Probability(const arma::vec& observation) const
- {
- // Adding 0.5 helps ensure that we cast the floating point to a size_t
- // correctly.
- const size_t obs = size_t(observation[0] + 0.5);
-
- // Ensure that the observation is within the bounds.
- if (obs >= probabilities.n_elem)
- {
- Log::Debug << "DiscreteDistribution::Probability(): received observation "
- << obs << "; observation must be in [0, " << probabilities.n_elem
- << "] for this distribution." << std::endl;
- }
-
- return probabilities(obs);
- }
-
- /**
- * Return a randomly generated observation (one-dimensional vector; one
- * observation) according to the probability distribution defined by this
- * object.
- *
- * @return Random observation.
- */
- arma::vec Random() const;
-
- /**
- * Estimate the probability distribution directly from the given observations.
- * If any of the observations is greater than numObservations, a crash is
- * likely to occur.
- *
- * @param observations List of observations.
- */
- void Estimate(const arma::mat& observations);
-
- /**
- * Estimate the probability distribution from the given observations, taking
- * into account the probability of each observation actually being from this
- * distribution.
- *
- * @param observations List of observations.
- * @param probabilities List of probabilities that each observation is
- * actually from this distribution.
- */
- void Estimate(const arma::mat& observations,
- const arma::vec& probabilities);
-
- //! Return the vector of probabilities.
- const arma::vec& Probabilities() const { return probabilities; }
- //! Modify the vector of probabilities.
- arma::vec& Probabilities() { return probabilities; }
-
- /*
- * Returns a string representation of this object.
- */
- std::string ToString() const;
-
- private:
- arma::vec probabilities;
-};
-
-}; // namespace distribution
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/dists/discrete_distribution.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/discrete_distribution.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,169 @@
+/**
+ * @file discrete_distribution.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the discrete distribution, where each discrete observation
+ * has a given probability.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_HMM_DISTRIBUTIONS_DISCRETE_DISTRIBUTION_HPP
+#define __MLPACK_METHODS_HMM_DISTRIBUTIONS_DISCRETE_DISTRIBUTION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace distribution /** Probability distributions. */ {
+
+/**
+ * A discrete distribution where the only observations are discrete
+ * observations. This is useful (for example) with discrete Hidden Markov
+ * Models, where observations are non-negative integers representing specific
+ * emissions.
+ *
+ * No bounds checking is performed for observations, so if an invalid
+ * observation is passed (i.e. observation > numObservations), a crash will
+ * probably occur.
+ *
+ * This distribution only supports one-dimensional observations, so when passing
+ * an arma::vec as an observation, it should only have one dimension
+ * (vec.n_rows == 1). Any additional dimensions will simply be ignored.
+ *
+ * @note
+ * This class, like every other class in MLPACK, uses arma::vec to represent
+ * observations. While a discrete distribution only has positive integers
+ * (size_t) as observations, these can be converted to doubles (which is what
+ * arma::vec holds). This distribution internally converts those doubles back
+ * into size_t before comparisons.
+ * @endnote
+ */
+class DiscreteDistribution
+{
+ public:
+ /**
+ * Default constructor, which creates a distribution that has no observations.
+ */
+ DiscreteDistribution() { /* nothing to do */ }
+
+ /**
+ * Define the discrete distribution as having numObservations possible
+ * observations. The probability in each state will be set to (1 /
+ * numObservations).
+ *
+ * @param numObservations Number of possible observations this distribution
+ * can have.
+ */
+ DiscreteDistribution(const size_t numObservations) :
+ probabilities(arma::ones<arma::vec>(numObservations) / numObservations)
+ { /* nothing to do */ }
+
+ /**
+ * Define the discrete distribution as having the given probabilities for each
+ * observation.
+ *
+ * @param probabilities Probabilities of each possible observation.
+ */
+ DiscreteDistribution(const arma::vec& probabilities)
+ {
+ // We must be sure that our distribution is normalized.
+ double sum = accu(probabilities);
+ if (sum > 0)
+ this->probabilities = probabilities / sum;
+ else
+ {
+ this->probabilities.set_size(probabilities.n_elem);
+ this->probabilities.fill(1 / probabilities.n_elem);
+ }
+ }
+
+ /**
+ * Get the dimensionality of the distribution.
+ */
+ size_t Dimensionality() const { return 1; }
+
+ /**
+ * Return the probability of the given observation. If the observation is
+ * greater than the number of possible observations, then a crash will
+ * probably occur -- bounds checking is not performed.
+ *
+ * @param observation Observation to return the probability of.
+ * @return Probability of the given observation.
+ */
+ double Probability(const arma::vec& observation) const
+ {
+ // Adding 0.5 helps ensure that we cast the floating point to a size_t
+ // correctly.
+ const size_t obs = size_t(observation[0] + 0.5);
+
+ // Ensure that the observation is within the bounds.
+ if (obs >= probabilities.n_elem)
+ {
+ Log::Debug << "DiscreteDistribution::Probability(): received observation "
+ << obs << "; observation must be in [0, " << probabilities.n_elem
+ << "] for this distribution." << std::endl;
+ }
+
+ return probabilities(obs);
+ }
+
+ /**
+ * Return a randomly generated observation (one-dimensional vector; one
+ * observation) according to the probability distribution defined by this
+ * object.
+ *
+ * @return Random observation.
+ */
+ arma::vec Random() const;
+
+ /**
+ * Estimate the probability distribution directly from the given observations.
+ * If any of the observations is greater than numObservations, a crash is
+ * likely to occur.
+ *
+ * @param observations List of observations.
+ */
+ void Estimate(const arma::mat& observations);
+
+ /**
+ * Estimate the probability distribution from the given observations, taking
+ * into account the probability of each observation actually being from this
+ * distribution.
+ *
+ * @param observations List of observations.
+ * @param probabilities List of probabilities that each observation is
+ * actually from this distribution.
+ */
+ void Estimate(const arma::mat& observations,
+ const arma::vec& probabilities);
+
+ //! Return the vector of probabilities.
+ const arma::vec& Probabilities() const { return probabilities; }
+ //! Modify the vector of probabilities.
+ arma::vec& Probabilities() { return probabilities; }
+
+ /*
+ * Returns a string representation of this object.
+ */
+ std::string ToString() const;
+
+ private:
+ arma::vec probabilities;
+};
+
+}; // namespace distribution
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/dists/gaussian_distribution.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,161 +0,0 @@
-/**
- * @file gaussian_distribution.cpp
- * @author Ryan Curtin
- *
- * Implementation of Gaussian distribution class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "gaussian_distribution.hpp"
-
-using namespace mlpack;
-using namespace mlpack::distribution;
-
-arma::vec GaussianDistribution::Random() const
-{
- // Should we store chol(covariance) for easier calculation later?
- return trans(chol(covariance)) * arma::randn<arma::vec>(mean.n_elem) + mean;
-}
-
-/**
- * Estimate the Gaussian distribution directly from the given observations.
- *
- * @param observations List of observations.
- */
-void GaussianDistribution::Estimate(const arma::mat& observations)
-{
- if (observations.n_cols > 0)
- {
- mean.zeros(observations.n_rows);
- covariance.zeros(observations.n_rows, observations.n_rows);
- }
- else // This will end up just being empty.
- {
- mean.zeros(0);
- covariance.zeros(0);
- return;
- }
-
- // Calculate the mean.
- for (size_t i = 0; i < observations.n_cols; i++)
- mean += observations.col(i);
-
- // Normalize the mean.
- mean /= observations.n_cols;
-
- // Now calculate the covariance.
- for (size_t i = 0; i < observations.n_cols; i++)
- {
- arma::vec obsNoMean = observations.col(i) - mean;
- covariance += obsNoMean * trans(obsNoMean);
- }
-
- // Finish estimating the covariance by normalizing, with the (1 / (n - 1)) so
- // that it is the unbiased estimator.
- covariance /= (observations.n_cols - 1);
-
- // Ensure that the covariance is positive definite.
- if (det(covariance) <= 1e-50)
- {
- Log::Debug << "GaussianDistribution::Estimate(): Covariance matrix is not "
- << "positive definite. Adding perturbation." << std::endl;
-
- double perturbation = 1e-30;
- while (det(covariance) <= 1e-50)
- {
- covariance.diag() += perturbation;
- perturbation *= 10; // Slow, but we don't want to add too much.
- }
- }
-}
-
-/**
- * Estimate the Gaussian distribution from the given observations, taking into
- * account the probability of each observation actually being from this
- * distribution.
- */
-void GaussianDistribution::Estimate(const arma::mat& observations,
- const arma::vec& probabilities)
-{
- if (observations.n_cols > 0)
- {
- mean.zeros(observations.n_rows);
- covariance.zeros(observations.n_rows, observations.n_rows);
- }
- else // This will end up just being empty.
- {
- mean.zeros(0);
- covariance.zeros(0);
- return;
- }
-
- double sumProb = 0;
-
- // First calculate the mean, and save the sum of all the probabilities for
- // later normalization.
- for (size_t i = 0; i < observations.n_cols; i++)
- {
- mean += probabilities[i] * observations.col(i);
- sumProb += probabilities[i];
- }
-
- if (sumProb == 0)
- {
- // Nothing in this Gaussian! At least set the covariance so that it's
- // invertible.
- covariance.diag() += 1e-50;
- return;
- }
-
- // Normalize.
- mean /= sumProb;
-
- // Now find the covariance.
- for (size_t i = 0; i < observations.n_cols; i++)
- {
- arma::vec obsNoMean = observations.col(i) - mean;
- covariance += probabilities[i] * (obsNoMean * trans(obsNoMean));
- }
-
- // This is probably biased, but I don't know how to unbias it.
- covariance /= sumProb;
-
- // Ensure that the covariance is positive definite.
- if (det(covariance) <= 1e-50)
- {
- Log::Debug << "GaussianDistribution::Estimate(): Covariance matrix is not "
- << "positive definite. Adding perturbation." << std::endl;
-
- double perturbation = 1e-30;
- while (det(covariance) <= 1e-50)
- {
- covariance.diag() += perturbation;
- perturbation *= 10; // Slow, but we don't want to add too much.
- }
- }
-}
-
-/**
- * Returns a string representation of this object.
- */
-std::string GaussianDistribution::ToString() const
-{
- std::ostringstream convert;
- convert << "GaussianDistribution: " << this << std::endl;
- convert << "mean: " << std::endl << mean << std::endl;
- convert << "covariance: " << std::endl << covariance << std::endl;
- return convert.str();
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/dists/gaussian_distribution.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,161 @@
+/**
+ * @file gaussian_distribution.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of Gaussian distribution class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "gaussian_distribution.hpp"
+
+using namespace mlpack;
+using namespace mlpack::distribution;
+
+arma::vec GaussianDistribution::Random() const
+{
+ // Should we store chol(covariance) for easier calculation later?
+ return trans(chol(covariance)) * arma::randn<arma::vec>(mean.n_elem) + mean;
+}
+
+/**
+ * Estimate the Gaussian distribution directly from the given observations.
+ *
+ * @param observations List of observations.
+ */
+void GaussianDistribution::Estimate(const arma::mat& observations)
+{
+ if (observations.n_cols > 0)
+ {
+ mean.zeros(observations.n_rows);
+ covariance.zeros(observations.n_rows, observations.n_rows);
+ }
+ else // This will end up just being empty.
+ {
+ mean.zeros(0);
+ covariance.zeros(0);
+ return;
+ }
+
+ // Calculate the mean.
+ for (size_t i = 0; i < observations.n_cols; i++)
+ mean += observations.col(i);
+
+ // Normalize the mean.
+ mean /= observations.n_cols;
+
+ // Now calculate the covariance.
+ for (size_t i = 0; i < observations.n_cols; i++)
+ {
+ arma::vec obsNoMean = observations.col(i) - mean;
+ covariance += obsNoMean * trans(obsNoMean);
+ }
+
+ // Finish estimating the covariance by normalizing, with the (1 / (n - 1)) so
+ // that it is the unbiased estimator.
+ covariance /= (observations.n_cols - 1);
+
+ // Ensure that the covariance is positive definite.
+ if (det(covariance) <= 1e-50)
+ {
+ Log::Debug << "GaussianDistribution::Estimate(): Covariance matrix is not "
+ << "positive definite. Adding perturbation." << std::endl;
+
+ double perturbation = 1e-30;
+ while (det(covariance) <= 1e-50)
+ {
+ covariance.diag() += perturbation;
+ perturbation *= 10; // Slow, but we don't want to add too much.
+ }
+ }
+}
+
+/**
+ * Estimate the Gaussian distribution from the given observations, taking into
+ * account the probability of each observation actually being from this
+ * distribution.
+ */
+void GaussianDistribution::Estimate(const arma::mat& observations,
+ const arma::vec& probabilities)
+{
+ if (observations.n_cols > 0)
+ {
+ mean.zeros(observations.n_rows);
+ covariance.zeros(observations.n_rows, observations.n_rows);
+ }
+ else // This will end up just being empty.
+ {
+ mean.zeros(0);
+ covariance.zeros(0);
+ return;
+ }
+
+ double sumProb = 0;
+
+ // First calculate the mean, and save the sum of all the probabilities for
+ // later normalization.
+ for (size_t i = 0; i < observations.n_cols; i++)
+ {
+ mean += probabilities[i] * observations.col(i);
+ sumProb += probabilities[i];
+ }
+
+ if (sumProb == 0)
+ {
+ // Nothing in this Gaussian! At least set the covariance so that it's
+ // invertible.
+ covariance.diag() += 1e-50;
+ return;
+ }
+
+ // Normalize.
+ mean /= sumProb;
+
+ // Now find the covariance.
+ for (size_t i = 0; i < observations.n_cols; i++)
+ {
+ arma::vec obsNoMean = observations.col(i) - mean;
+ covariance += probabilities[i] * (obsNoMean * trans(obsNoMean));
+ }
+
+ // This is probably biased, but I don't know how to unbias it.
+ covariance /= sumProb;
+
+ // Ensure that the covariance is positive definite.
+ if (det(covariance) <= 1e-50)
+ {
+ Log::Debug << "GaussianDistribution::Estimate(): Covariance matrix is not "
+ << "positive definite. Adding perturbation." << std::endl;
+
+ double perturbation = 1e-30;
+ while (det(covariance) <= 1e-50)
+ {
+ covariance.diag() += perturbation;
+ perturbation *= 10; // Slow, but we don't want to add too much.
+ }
+ }
+}
+
+/**
+ * Returns a string representation of this object.
+ */
+std::string GaussianDistribution::ToString() const
+{
+ std::ostringstream convert;
+ convert << "GaussianDistribution: " << this << std::endl;
+ convert << "mean: " << std::endl << mean << std::endl;
+ convert << "covariance: " << std::endl << covariance << std::endl;
+ return convert.str();
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/dists/gaussian_distribution.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,127 +0,0 @@
-/**
- * @file gaussian_distribution.hpp
- * @author Ryan Curtin
- *
- * Implementation of the Gaussian distribution.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_HMM_DISTRIBUTIONS_GAUSSIAN_DISTRIBUTION_HPP
-#define __MLPACK_METHODS_HMM_DISTRIBUTIONS_GAUSSIAN_DISTRIBUTION_HPP
-
-#include <mlpack/core.hpp>
-// Should be somewhere else, maybe in core.
-#include <mlpack/methods/gmm/phi.hpp>
-
-namespace mlpack {
-namespace distribution {
-
-/**
- * A single multivariate Gaussian distribution.
- */
-class GaussianDistribution
-{
- private:
- //! Mean of the distribution.
- arma::vec mean;
- //! Covariance of the distribution.
- arma::mat covariance;
-
- public:
- /**
- * Default constructor, which creates a Gaussian with zero dimension.
- */
- GaussianDistribution() { /* nothing to do */ }
-
- /**
- * Create a Gaussian distribution with zero mean and identity covariance with
- * the given dimensionality.
- */
- GaussianDistribution(const size_t dimension) :
- mean(arma::zeros<arma::vec>(dimension)),
- covariance(arma::eye<arma::mat>(dimension, dimension))
- { /* Nothing to do. */ }
-
- /**
- * Create a Gaussian distribution with the given mean and covariance.
- */
- GaussianDistribution(const arma::vec& mean, const arma::mat& covariance) :
- mean(mean), covariance(covariance) { /* Nothing to do. */ }
-
- //! Return the dimensionality of this distribution.
- size_t Dimensionality() const { return mean.n_elem; }
-
- /**
- * Return the probability of the given observation.
- */
- double Probability(const arma::vec& observation) const
- {
- return mlpack::gmm::phi(observation, mean, covariance);
- }
-
- /**
- * Return a randomly generated observation according to the probability
- * distribution defined by this object.
- *
- * @return Random observation from this Gaussian distribution.
- */
- arma::vec Random() const;
-
- /**
- * Estimate the Gaussian distribution directly from the given observations.
- *
- * @param observations List of observations.
- */
- void Estimate(const arma::mat& observations);
-
- /**
- * Estimate the Gaussian distribution from the given observations, taking into
- * account the probability of each observation actually being from this
- * distribution.
- */
- void Estimate(const arma::mat& observations,
- const arma::vec& probabilities);
-
- /**
- * Return the mean.
- */
- const arma::vec& Mean() const { return mean; }
-
- /**
- * Return a modifiable copy of the mean.
- */
- arma::vec& Mean() { return mean; }
-
- /**
- * Return the covariance matrix.
- */
- const arma::mat& Covariance() const { return covariance; }
-
- /**
- * Return a modifiable copy of the covariance.
- */
- arma::mat& Covariance() { return covariance; }
-
- /**
- * Returns a string representation of this object.
- */
- std::string ToString() const;
-};
-
-}; // namespace distribution
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/dists/gaussian_distribution.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/dists/gaussian_distribution.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,127 @@
+/**
+ * @file gaussian_distribution.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the Gaussian distribution.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_HMM_DISTRIBUTIONS_GAUSSIAN_DISTRIBUTION_HPP
+#define __MLPACK_METHODS_HMM_DISTRIBUTIONS_GAUSSIAN_DISTRIBUTION_HPP
+
+#include <mlpack/core.hpp>
+// Should be somewhere else, maybe in core.
+#include <mlpack/methods/gmm/phi.hpp>
+
+namespace mlpack {
+namespace distribution {
+
+/**
+ * A single multivariate Gaussian distribution.
+ */
+class GaussianDistribution
+{
+ private:
+ //! Mean of the distribution.
+ arma::vec mean;
+ //! Covariance of the distribution.
+ arma::mat covariance;
+
+ public:
+ /**
+ * Default constructor, which creates a Gaussian with zero dimension.
+ */
+ GaussianDistribution() { /* nothing to do */ }
+
+ /**
+ * Create a Gaussian distribution with zero mean and identity covariance with
+ * the given dimensionality.
+ */
+ GaussianDistribution(const size_t dimension) :
+ mean(arma::zeros<arma::vec>(dimension)),
+ covariance(arma::eye<arma::mat>(dimension, dimension))
+ { /* Nothing to do. */ }
+
+ /**
+ * Create a Gaussian distribution with the given mean and covariance.
+ */
+ GaussianDistribution(const arma::vec& mean, const arma::mat& covariance) :
+ mean(mean), covariance(covariance) { /* Nothing to do. */ }
+
+ //! Return the dimensionality of this distribution.
+ size_t Dimensionality() const { return mean.n_elem; }
+
+ /**
+ * Return the probability of the given observation.
+ */
+ double Probability(const arma::vec& observation) const
+ {
+ return mlpack::gmm::phi(observation, mean, covariance);
+ }
+
+ /**
+ * Return a randomly generated observation according to the probability
+ * distribution defined by this object.
+ *
+ * @return Random observation from this Gaussian distribution.
+ */
+ arma::vec Random() const;
+
+ /**
+ * Estimate the Gaussian distribution directly from the given observations.
+ *
+ * @param observations List of observations.
+ */
+ void Estimate(const arma::mat& observations);
+
+ /**
+ * Estimate the Gaussian distribution from the given observations, taking into
+ * account the probability of each observation actually being from this
+ * distribution.
+ */
+ void Estimate(const arma::mat& observations,
+ const arma::vec& probabilities);
+
+ /**
+ * Return the mean.
+ */
+ const arma::vec& Mean() const { return mean; }
+
+ /**
+ * Return a modifiable copy of the mean.
+ */
+ arma::vec& Mean() { return mean; }
+
+ /**
+ * Return the covariance matrix.
+ */
+ const arma::mat& Covariance() const { return covariance; }
+
+ /**
+ * Return a modifiable copy of the covariance.
+ */
+ arma::mat& Covariance() { return covariance; }
+
+ /**
+ * Returns a string representation of this object.
+ */
+ std::string ToString() const;
+};
+
+}; // namespace distribution
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/cosine_distance.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,60 +0,0 @@
-/**
- * @file cosine_distance.hpp
- * @author Ryan Curtin
- *
- * This implements the cosine distance (or cosine similarity) between two
- * vectors, which is a measure of the angle between the two vectors.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_COSINE_DISTANCE_HPP
-#define __MLPACK_CORE_KERNELS_COSINE_DISTANCE_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * The cosine distance (or cosine similarity). It is defined by
- *
- * @f[
- * d(a, b) = \frac{a^T b}{|| a || || b ||}
- * @f]
- *
- * and this class assumes the standard L2 inner product.
- */
-class CosineDistance
-{
- public:
- /**
- * Computes the cosine distance between two points.
- *
- * @param a First vector.
- * @param b Second vector.
- * @return d(a, b).
- */
- template<typename VecType>
- static double Evaluate(const VecType& a, const VecType& b);
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-// Include implementation.
-#include "cosine_distance_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/cosine_distance.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,60 @@
+/**
+ * @file cosine_distance.hpp
+ * @author Ryan Curtin
+ *
+ * This implements the cosine distance (or cosine similarity) between two
+ * vectors, which is a measure of the angle between the two vectors.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_COSINE_DISTANCE_HPP
+#define __MLPACK_CORE_KERNELS_COSINE_DISTANCE_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * The cosine distance (or cosine similarity). It is defined by
+ *
+ * @f[
+ * d(a, b) = \frac{a^T b}{|| a || || b ||}
+ * @f]
+ *
+ * and this class assumes the standard L2 inner product.
+ */
+class CosineDistance
+{
+ public:
+ /**
+ * Computes the cosine distance between two points.
+ *
+ * @param a First vector.
+ * @param b Second vector.
+ * @return d(a, b).
+ */
+ template<typename VecType>
+ static double Evaluate(const VecType& a, const VecType& b);
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+// Include implementation.
+#include "cosine_distance_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/cosine_distance_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,40 +0,0 @@
-/**
- * @file cosine_distance_impl.hpp
- * @author Ryan Curtin
- *
- * This implements the cosine distance.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_COSINE_DISTANCE_IMPL_HPP
-#define __MLPACK_CORE_KERNELS_COSINE_DISTANCE_IMPL_HPP
-
-#include "cosine_distance.hpp"
-
-namespace mlpack {
-namespace kernel {
-
-template<typename VecType>
-double CosineDistance::Evaluate(const VecType& a, const VecType& b)
-{
- // Since we are using the L2 inner product, this is easy.
- return dot(a, b) / (norm(a, 2) * norm(b, 2));
-}
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/cosine_distance_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/cosine_distance_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,40 @@
+/**
+ * @file cosine_distance_impl.hpp
+ * @author Ryan Curtin
+ *
+ * This implements the cosine distance.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_COSINE_DISTANCE_IMPL_HPP
+#define __MLPACK_CORE_KERNELS_COSINE_DISTANCE_IMPL_HPP
+
+#include "cosine_distance.hpp"
+
+namespace mlpack {
+namespace kernel {
+
+template<typename VecType>
+double CosineDistance::Evaluate(const VecType& a, const VecType& b)
+{
+ // Since we are using the L2 inner product, this is easy.
+ return dot(a, b) / (norm(a, 2) * norm(b, 2));
+}
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/epanechnikov_kernel.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,48 +0,0 @@
-/**
- * @file epanechnikov_kernel.cpp
- * @author Neil Slagle
- *
- * Implementation of non-template Epanechnikov kernels.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "epanechnikov_kernel.hpp"
-
-#include <boost/math/special_functions/gamma.hpp>
-
-using namespace mlpack;
-using namespace mlpack::kernel;
-
-/**
- * Compute the normalizer of this Epanechnikov kernel for the given dimension.
- *
- * @param dimension Dimension to calculate the normalizer for.
- */
-double EpanechnikovKernel::Normalizer(const size_t dimension)
-{
- return 2.0 * pow(bandwidth, (double) dimension) *
- std::pow(M_PI, dimension / 2.0) /
- (boost::math::tgamma(dimension / 2.0 + 1.0) * (dimension + 2.0));
-}
-
-/**
- * Evaluate the kernel not for two points but for a numerical value.
- */
-double EpanechnikovKernel::Evaluate(const double t)
-{
- double evaluatee = 1.0 - t * t * inverseBandwidthSquared;
- return (evaluatee > 0.0) ? evaluatee : 0.0;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/epanechnikov_kernel.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,48 @@
+/**
+ * @file epanechnikov_kernel.cpp
+ * @author Neil Slagle
+ *
+ * Implementation of non-template Epanechnikov kernels.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "epanechnikov_kernel.hpp"
+
+#include <boost/math/special_functions/gamma.hpp>
+
+using namespace mlpack;
+using namespace mlpack::kernel;
+
+/**
+ * Compute the normalizer of this Epanechnikov kernel for the given dimension.
+ *
+ * @param dimension Dimension to calculate the normalizer for.
+ */
+double EpanechnikovKernel::Normalizer(const size_t dimension)
+{
+ return 2.0 * pow(bandwidth, (double) dimension) *
+ std::pow(M_PI, dimension / 2.0) /
+ (boost::math::tgamma(dimension / 2.0 + 1.0) * (dimension + 2.0));
+}
+
+/**
+ * Evaluate the kernel not for two points but for a numerical value.
+ */
+double EpanechnikovKernel::Evaluate(const double t)
+{
+ double evaluatee = 1.0 - t * t * inverseBandwidthSquared;
+ return (evaluatee > 0.0) ? evaluatee : 0.0;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/epanechnikov_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,98 +0,0 @@
-/**
- * @file epanechnikov_kernel.hpp
- * @author Neil Slagle
- *
- * Definition of the Epanechnikov kernel.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_EPANECHNIKOV_KERNEL_HPP
-#define __MLPACK_CORE_KERNELS_EPANECHNIKOV_KERNEL_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * The Epanechnikov kernel, defined as
- *
- * @f[
- * K(x, y) = \max \{0, 1 - || x - y ||^2_2 / b^2 \}
- * @f]
- *
- * where @f$ b @f$ is the bandwidth the of the kernel (defaults to 1.0).
- */
-class EpanechnikovKernel
-{
- public:
- /**
- * Instantiate the Epanechnikov kernel with the given bandwidth (default 1.0).
- *
- * @param bandwidth Bandwidth of the kernel.
- */
- EpanechnikovKernel(const double bandwidth = 1.0) :
- bandwidth(bandwidth),
- inverseBandwidthSquared(1.0 / (bandwidth * bandwidth))
- { }
-
- /**
- * Evaluate the Epanechnikov kernel on the given two inputs.
- *
- * @param a One input vector.
- * @param b The other input vector.
- */
- template<typename Vec1Type, typename Vec2Type>
- double Evaluate(const Vec1Type& a, const Vec2Type& b);
-
- /**
- * Obtains the convolution integral [integral of K(||x-a||) K(||b-x||) dx]
- * for the two vectors.
- *
- * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
- * @param a First vector.
- * @param b Second vector.
- * @return the convolution integral value.
- */
- template<typename VecType>
- double ConvolutionIntegral(const VecType& a, const VecType& b);
-
- /**
- * Compute the normalizer of this Epanechnikov kernel for the given dimension.
- *
- * @param dimension Dimension to calculate the normalizer for.
- */
- double Normalizer(const size_t dimension);
-
- /**
- * Evaluate the kernel not for two points but for a numerical value.
- */
- double Evaluate(const double t);
-
- private:
- //! Bandwidth of the kernel.
- double bandwidth;
- //! Cached value of the inverse bandwidth squared (to speed up computation).
- double inverseBandwidthSquared;
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-// Include implementation.
-#include "epanechnikov_kernel_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/epanechnikov_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,98 @@
+/**
+ * @file epanechnikov_kernel.hpp
+ * @author Neil Slagle
+ *
+ * Definition of the Epanechnikov kernel.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_EPANECHNIKOV_KERNEL_HPP
+#define __MLPACK_CORE_KERNELS_EPANECHNIKOV_KERNEL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * The Epanechnikov kernel, defined as
+ *
+ * @f[
+ * K(x, y) = \max \{0, 1 - || x - y ||^2_2 / b^2 \}
+ * @f]
+ *
+ * where @f$ b @f$ is the bandwidth the of the kernel (defaults to 1.0).
+ */
+class EpanechnikovKernel
+{
+ public:
+ /**
+ * Instantiate the Epanechnikov kernel with the given bandwidth (default 1.0).
+ *
+ * @param bandwidth Bandwidth of the kernel.
+ */
+ EpanechnikovKernel(const double bandwidth = 1.0) :
+ bandwidth(bandwidth),
+ inverseBandwidthSquared(1.0 / (bandwidth * bandwidth))
+ { }
+
+ /**
+ * Evaluate the Epanechnikov kernel on the given two inputs.
+ *
+ * @param a One input vector.
+ * @param b The other input vector.
+ */
+ template<typename Vec1Type, typename Vec2Type>
+ double Evaluate(const Vec1Type& a, const Vec2Type& b);
+
+ /**
+ * Obtains the convolution integral [integral of K(||x-a||) K(||b-x||) dx]
+ * for the two vectors.
+ *
+ * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
+ * @param a First vector.
+ * @param b Second vector.
+ * @return the convolution integral value.
+ */
+ template<typename VecType>
+ double ConvolutionIntegral(const VecType& a, const VecType& b);
+
+ /**
+ * Compute the normalizer of this Epanechnikov kernel for the given dimension.
+ *
+ * @param dimension Dimension to calculate the normalizer for.
+ */
+ double Normalizer(const size_t dimension);
+
+ /**
+ * Evaluate the kernel not for two points but for a numerical value.
+ */
+ double Evaluate(const double t);
+
+ private:
+ //! Bandwidth of the kernel.
+ double bandwidth;
+ //! Cached value of the inverse bandwidth squared (to speed up computation).
+ double inverseBandwidthSquared;
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+// Include implementation.
+#include "epanechnikov_kernel_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/epanechnikov_kernel_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,88 +0,0 @@
-/**
- * @file epanechnikov_kernel_impl.hpp
- * @author Neil Slagle
- *
- * Implementation of template-based Epanechnikov kernel functions.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_EPANECHNIKOV_KERNEL_IMPL_HPP
-#define __MLPACK_CORE_KERNELS_EPANECHNIKOV_KERNEL_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "epanechnikov_kernel.hpp"
-
-#include <mlpack/core/metrics/lmetric.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-template<typename Vec1Type, typename Vec2Type>
-inline double EpanechnikovKernel::Evaluate(const Vec1Type& a, const Vec2Type& b)
-{
- return std::max(0.0, 1.0 - metric::SquaredEuclideanDistance::Evaluate(a, b)
- * inverseBandwidthSquared);
-}
-
-/**
- * Obtains the convolution integral [integral of K(||x-a||) K(||b-x||) dx]
- * for the two vectors.
- *
- * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
- * @param a First vector.
- * @param b Second vector.
- * @return the convolution integral value.
- */
-template<typename VecType>
-double EpanechnikovKernel::ConvolutionIntegral(const VecType& a,
- const VecType& b)
-{
- double distance = sqrt(metric::SquaredEuclideanDistance::Evaluate(a, b));
- if (distance >= 2.0 * bandwidth)
- return 0.0;
-
- double volumeSquared = std::pow(Normalizer(a.n_rows), 2.0);
-
- switch (a.n_rows)
- {
- case 1:
- return 1.0 / volumeSquared *
- (16.0 / 15.0 * bandwidth - 4.0 * distance * distance /
- (3.0 * bandwidth) + 2.0 * distance * distance * distance /
- (3.0 * bandwidth * bandwidth) -
- std::pow(distance, 5.0) / (30.0 * std::pow(bandwidth, 4.0)));
- break;
- case 2:
- return 1.0 / volumeSquared *
- ((2.0 / 3.0 * bandwidth * bandwidth - distance * distance) *
- asin(sqrt(1.0 - std::pow(distance / (2.0 * bandwidth), 2.0))) +
- sqrt(4.0 * bandwidth * bandwidth - distance * distance) *
- (distance / 6.0 + 2.0 / 9.0 * distance *
- std::pow(distance / bandwidth, 2.0) - distance / 72.0 *
- std::pow(distance / bandwidth, 4.0)));
- break;
- default:
- Log::Fatal << "EpanechnikovKernel::ConvolutionIntegral(): dimension "
- << a.n_rows << " not supported.";
- return -1.0; // This line will not execute.
- break;
- }
-}
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/epanechnikov_kernel_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/epanechnikov_kernel_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,88 @@
+/**
+ * @file epanechnikov_kernel_impl.hpp
+ * @author Neil Slagle
+ *
+ * Implementation of template-based Epanechnikov kernel functions.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_EPANECHNIKOV_KERNEL_IMPL_HPP
+#define __MLPACK_CORE_KERNELS_EPANECHNIKOV_KERNEL_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "epanechnikov_kernel.hpp"
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+template<typename Vec1Type, typename Vec2Type>
+inline double EpanechnikovKernel::Evaluate(const Vec1Type& a, const Vec2Type& b)
+{
+ return std::max(0.0, 1.0 - metric::SquaredEuclideanDistance::Evaluate(a, b)
+ * inverseBandwidthSquared);
+}
+
+/**
+ * Obtains the convolution integral [integral of K(||x-a||) K(||b-x||) dx]
+ * for the two vectors.
+ *
+ * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
+ * @param a First vector.
+ * @param b Second vector.
+ * @return the convolution integral value.
+ */
+template<typename VecType>
+double EpanechnikovKernel::ConvolutionIntegral(const VecType& a,
+ const VecType& b)
+{
+ double distance = sqrt(metric::SquaredEuclideanDistance::Evaluate(a, b));
+ if (distance >= 2.0 * bandwidth)
+ return 0.0;
+
+ double volumeSquared = std::pow(Normalizer(a.n_rows), 2.0);
+
+ switch (a.n_rows)
+ {
+ case 1:
+ return 1.0 / volumeSquared *
+ (16.0 / 15.0 * bandwidth - 4.0 * distance * distance /
+ (3.0 * bandwidth) + 2.0 * distance * distance * distance /
+ (3.0 * bandwidth * bandwidth) -
+ std::pow(distance, 5.0) / (30.0 * std::pow(bandwidth, 4.0)));
+ break;
+ case 2:
+ return 1.0 / volumeSquared *
+ ((2.0 / 3.0 * bandwidth * bandwidth - distance * distance) *
+ asin(sqrt(1.0 - std::pow(distance / (2.0 * bandwidth), 2.0))) +
+ sqrt(4.0 * bandwidth * bandwidth - distance * distance) *
+ (distance / 6.0 + 2.0 / 9.0 * distance *
+ std::pow(distance / bandwidth, 2.0) - distance / 72.0 *
+ std::pow(distance / bandwidth, 4.0)));
+ break;
+ default:
+ Log::Fatal << "EpanechnikovKernel::ConvolutionIntegral(): dimension "
+ << a.n_rows << " not supported.";
+ return -1.0; // This line will not execute.
+ break;
+ }
+}
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/example_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/example_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/example_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,151 +0,0 @@
-/**
- * @file example_kernel.hpp
- * @author Ryan Curtin
- *
- * This is an example kernel. If you are making your own kernel, follow the
- * outline specified in this file.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_EXAMPLE_KERNEL_HPP
-#define __MLPACK_CORE_KERNELS_EXAMPLE_KERNEL_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-
-/**
- * @brief Kernel functions.
- *
- * This namespace contains kernel functions, which evaluate some kernel function
- * @f$ K(x, y) @f$ for some arbitrary vectors @f$ x @f$ and @f$ y @f$ of the
- * same dimension. The single restriction on the function @f$ K(x, y) @f$ is
- * that it must satisfy Mercer's condition:
- *
- * @f[
- * \int \int K(x, y) g(x) g(y) dx dy \ge 0
- * @f]
- *
- * for all square integrable functions @f$ g(x) @f$.
- *
- * The kernels in this namespace all implement the same methods as the
- * ExampleKernel class. Any additional custom kernels should implement all the
- * methods that class implements; in addition, any method using a kernel should
- * rely on any arbitrary kernel function class having a default constructor and
- * a function
- *
- * @code
- * double Evaluate(arma::vec&, arma::vec&);
- * @endcode
- */
-namespace kernel {
-
-/**
- * An example kernel function. This is not a useful kernel, but it implements
- * the two functions necessary to satisfy the Kernel policy (so that a class can
- * be used whenever an MLPACK method calls for a `typename Kernel` template
- * parameter.
- *
- * All that is necessary is a constructor and an `Evaluate()` function. More
- * methods could be added; for instance, one useful idea is a constructor which
- * takes parameters for a kernel (for instance, the width of the Gaussian for a
- * Gaussian kernel). However, MLPACK methods cannot count on these various
- * constructors existing, which is why most methods allow passing an
- * already-instantiated kernel object (and by default the method will construct
- * the kernel with the default constructor). So, for instance,
- *
- * @code
- * GaussianKernel k(5.0);
- * KDE<GaussianKernel> kde(dataset, k);
- * @endcode
- *
- * will set up KDE using a Gaussian kernel with a width of 5.0, but
- *
- * @code
- * KDE<GaussianKernel> kde(dataset);
- * @endcode
- *
- * will create the kernel with the default constructor. It is important (but
- * not strictly mandatory) that your default constructor still gives a working
- * kernel.
- *
- * @note
- * Not all kernels require state. For instance, the regular dot product needs
- * no parameters. In that case, no local variables are necessary and
- * `Evaluate()` can (and should) be declared static. However, for greater
- * generalization, MLPACK methods expect all kernels to require state and hence
- * must store instantiated kernel functions; this is why a default constructor
- * is necessary.
- * @endnote
- */
-class ExampleKernel
-{
- public:
- /**
- * The default constructor, which takes no parameters. Because our simple
- * example kernel has no internal parameters that need to be stored, the
- * constructor does not need to do anything. For a more complex example, see
- * the GaussianKernel, which stores an internal parameter.
- */
- ExampleKernel() { }
-
- /**
- * Evaluates the kernel function for two given vectors. In this case, because
- * our simple example kernel has no internal parameters, we can declare the
- * function static. For a more complex example which cannot be declared
- * static, see the GaussianKernel, which stores an internal parameter.
- *
- * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
- * @param a First vector.
- * @param b Second vector.
- * @return K(a, b).
- */
- template<typename VecType>
- static double Evaluate(const VecType& a, const VecType& b) { return 0; }
-
- /**
- * Obtains the convolution integral [integral K(||x-a||)K(||b-x||)dx]
- * for the two vectors. In this case, because
- * our simple example kernel has no internal parameters, we can declare the
- * function static. For a more complex example which cannot be declared
- * static, see the GaussianKernel, which stores an internal parameter.
- *
- * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
- * @param a First vector.
- * @param b Second vector.
- * @return the convolution integral value.
- */
- template<typename VecType>
- static double ConvolutionIntegral(const VecType& a, const VecType& b)
- { return 0; }
-
- /**
- * Obtains the normalizing volume for the kernel with dimension $dimension$.
- * In this case, because our simple example kernel has no internal parameters,
- * we can declare the function static. For a more complex example which
- * cannot be declared static, see the GaussianKernel, which stores an internal
- * parameter.
- *
- * @param dimension the dimension of the space.
- * @return the normalization constant.
- */
- static double Normalizer(size_t dimension) { return 0; }
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/example_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/example_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/example_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/example_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,151 @@
+/**
+ * @file example_kernel.hpp
+ * @author Ryan Curtin
+ *
+ * This is an example kernel. If you are making your own kernel, follow the
+ * outline specified in this file.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_EXAMPLE_KERNEL_HPP
+#define __MLPACK_CORE_KERNELS_EXAMPLE_KERNEL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+
+/**
+ * @brief Kernel functions.
+ *
+ * This namespace contains kernel functions, which evaluate some kernel function
+ * @f$ K(x, y) @f$ for some arbitrary vectors @f$ x @f$ and @f$ y @f$ of the
+ * same dimension. The single restriction on the function @f$ K(x, y) @f$ is
+ * that it must satisfy Mercer's condition:
+ *
+ * @f[
+ * \int \int K(x, y) g(x) g(y) dx dy \ge 0
+ * @f]
+ *
+ * for all square integrable functions @f$ g(x) @f$.
+ *
+ * The kernels in this namespace all implement the same methods as the
+ * ExampleKernel class. Any additional custom kernels should implement all the
+ * methods that class implements; in addition, any method using a kernel should
+ * rely on any arbitrary kernel function class having a default constructor and
+ * a function
+ *
+ * @code
+ * double Evaluate(arma::vec&, arma::vec&);
+ * @endcode
+ */
+namespace kernel {
+
+/**
+ * An example kernel function. This is not a useful kernel, but it implements
+ * the two functions necessary to satisfy the Kernel policy (so that a class can
+ * be used whenever an MLPACK method calls for a `typename Kernel` template
+ * parameter.
+ *
+ * All that is necessary is a constructor and an `Evaluate()` function. More
+ * methods could be added; for instance, one useful idea is a constructor which
+ * takes parameters for a kernel (for instance, the width of the Gaussian for a
+ * Gaussian kernel). However, MLPACK methods cannot count on these various
+ * constructors existing, which is why most methods allow passing an
+ * already-instantiated kernel object (and by default the method will construct
+ * the kernel with the default constructor). So, for instance,
+ *
+ * @code
+ * GaussianKernel k(5.0);
+ * KDE<GaussianKernel> kde(dataset, k);
+ * @endcode
+ *
+ * will set up KDE using a Gaussian kernel with a width of 5.0, but
+ *
+ * @code
+ * KDE<GaussianKernel> kde(dataset);
+ * @endcode
+ *
+ * will create the kernel with the default constructor. It is important (but
+ * not strictly mandatory) that your default constructor still gives a working
+ * kernel.
+ *
+ * @note
+ * Not all kernels require state. For instance, the regular dot product needs
+ * no parameters. In that case, no local variables are necessary and
+ * `Evaluate()` can (and should) be declared static. However, for greater
+ * generalization, MLPACK methods expect all kernels to require state and hence
+ * must store instantiated kernel functions; this is why a default constructor
+ * is necessary.
+ * @endnote
+ */
+class ExampleKernel
+{
+ public:
+ /**
+ * The default constructor, which takes no parameters. Because our simple
+ * example kernel has no internal parameters that need to be stored, the
+ * constructor does not need to do anything. For a more complex example, see
+ * the GaussianKernel, which stores an internal parameter.
+ */
+ ExampleKernel() { }
+
+ /**
+ * Evaluates the kernel function for two given vectors. In this case, because
+ * our simple example kernel has no internal parameters, we can declare the
+ * function static. For a more complex example which cannot be declared
+ * static, see the GaussianKernel, which stores an internal parameter.
+ *
+ * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
+ * @param a First vector.
+ * @param b Second vector.
+ * @return K(a, b).
+ */
+ template<typename VecType>
+ static double Evaluate(const VecType& a, const VecType& b) { return 0; }
+
+ /**
+ * Obtains the convolution integral [integral K(||x-a||)K(||b-x||)dx]
+ * for the two vectors. In this case, because
+ * our simple example kernel has no internal parameters, we can declare the
+ * function static. For a more complex example which cannot be declared
+ * static, see the GaussianKernel, which stores an internal parameter.
+ *
+ * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
+ * @param a First vector.
+ * @param b Second vector.
+ * @return the convolution integral value.
+ */
+ template<typename VecType>
+ static double ConvolutionIntegral(const VecType& a, const VecType& b)
+ { return 0; }
+
+ /**
+ * Obtains the normalizing volume for the kernel with dimension $dimension$.
+ * In this case, because our simple example kernel has no internal parameters,
+ * we can declare the function static. For a more complex example which
+ * cannot be declared static, see the GaussianKernel, which stores an internal
+ * parameter.
+ *
+ * @param dimension the dimension of the space.
+ * @return the normalization constant.
+ */
+ static double Normalizer(size_t dimension) { return 0; }
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/gaussian_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/gaussian_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/gaussian_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,143 +0,0 @@
-/**
- * @file gaussian_kernel.hpp
- * @author Wei Guan
- * @author James Cline
- * @author Ryan Curtin
- *
- * Implementation of the Gaussian kernel (GaussianKernel).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_GAUSSIAN_KERNEL_HPP
-#define __MLPACK_CORE_KERNELS_GAUSSIAN_KERNEL_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * The standard Gaussian kernel. Given two vectors @f$ x @f$, @f$ y @f$, and a
- * bandwidth @f$ \mu @f$ (set in the constructor),
- *
- * @f[
- * K(x, y) = \exp(-\frac{|| x - y ||^2}{2 \mu^2}).
- * @f]
- *
- * The implementation is all in the header file because it is so simple.
- */
-class GaussianKernel
-{
- public:
- /**
- * Default constructor; sets bandwidth to 1.0.
- */
- GaussianKernel() : bandwidth(1.0), gamma(-0.5)
- { }
-
- /**
- * Construct the Gaussian kernel with a custom bandwidth.
- *
 - * @param bandwidth The bandwidth of the kernel (@f$\mu@f$).
- */
- GaussianKernel(double bandwidth) :
- bandwidth(bandwidth),
- gamma(-0.5 * pow(bandwidth, -2.0))
- { }
-
- /**
- * Evaluation of the Gaussian kernel. This could be generalized to use any
- * distance metric, not the Euclidean distance, but for now, the Euclidean
- * distance is used.
- *
- * @tparam VecType Type of vector (likely arma::vec or arma::spvec).
- * @param a First vector.
- * @param b Second vector.
 - * @return K(a, b) using the bandwidth (@f$\mu@f$) specified in the
- * constructor.
- */
- template<typename VecType>
- double Evaluate(const VecType& a, const VecType& b) const
- {
- // The precalculation of gamma saves us a little computation time.
- return exp(gamma * metric::SquaredEuclideanDistance::Evaluate(a, b));
- }
-
- /**
- * Evaluation of the Gaussian kernel using a double precision argument.
- *
- * @param t double value.
 - * @return K(t) using the bandwidth (@f$\mu@f$) specified in the
- * constructor.
- */
- double Evaluate(double t) const
- {
- // The precalculation of gamma saves us a little computation time.
- return exp(gamma * t * t);
- }
- /**
- * Obtain the normalization constant of the Gaussian kernel.
- *
- * @param dimension
- * @return the normalization constant
- */
- double Normalizer(size_t dimension)
- {
- return pow(sqrt(2.0 * M_PI) * bandwidth, (double) dimension);
- }
- /**
- * Obtain a convolution integral of the Gaussian kernel.
- *
- * @param a, first vector
- * @param b, second vector
- * @return the convolution integral
- */
- template<typename VecType>
- double ConvolutionIntegral(const VecType& a, const VecType& b)
- {
- return Evaluate(sqrt(metric::SquaredEuclideanDistance::Evaluate(a, b) / 2.0)) /
- (Normalizer(a.n_rows) * pow(2.0, (double) a.n_rows / 2.0));
- }
-
-
- //! Get the bandwidth.
- double Bandwidth() const { return bandwidth; }
-
- //! Modify the bandwidth. This takes an argument because we must update the
- //! precalculated constant (gamma).
- void Bandwidth(const double bandwidth)
- {
- this->bandwidth = bandwidth;
- this->gamma = -0.5 * pow(bandwidth, -2.0);
- }
-
- //! Get the precalculated constant.
- double Gamma() const { return gamma; }
-
- private:
- //! Kernel bandwidth.
- double bandwidth;
-
- //! Precalculated constant depending on the bandwidth;
- //! @f$ \gamma = -\frac{1}{2 \mu^2} @f$.
- double gamma;
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/gaussian_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/gaussian_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/gaussian_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/gaussian_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,143 @@
+/**
+ * @file gaussian_kernel.hpp
+ * @author Wei Guan
+ * @author James Cline
+ * @author Ryan Curtin
+ *
+ * Implementation of the Gaussian kernel (GaussianKernel).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_GAUSSIAN_KERNEL_HPP
+#define __MLPACK_CORE_KERNELS_GAUSSIAN_KERNEL_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * The standard Gaussian kernel. Given two vectors @f$ x @f$, @f$ y @f$, and a
+ * bandwidth @f$ \mu @f$ (set in the constructor),
+ *
+ * @f[
+ * K(x, y) = \exp(-\frac{|| x - y ||^2}{2 \mu^2}).
+ * @f]
+ *
+ * The implementation is all in the header file because it is so simple.
+ */
+class GaussianKernel
+{
+ public:
+ /**
+ * Default constructor; sets bandwidth to 1.0.
+ */
+ GaussianKernel() : bandwidth(1.0), gamma(-0.5)
+ { }
+
+ /**
+ * Construct the Gaussian kernel with a custom bandwidth.
+ *
 + * @param bandwidth The bandwidth of the kernel (@f$\mu@f$).
+ */
+ GaussianKernel(double bandwidth) :
+ bandwidth(bandwidth),
+ gamma(-0.5 * pow(bandwidth, -2.0))
+ { }
+
+ /**
+ * Evaluation of the Gaussian kernel. This could be generalized to use any
+ * distance metric, not the Euclidean distance, but for now, the Euclidean
+ * distance is used.
+ *
+ * @tparam VecType Type of vector (likely arma::vec or arma::spvec).
+ * @param a First vector.
+ * @param b Second vector.
 + * @return K(a, b) using the bandwidth (@f$\mu@f$) specified in the
+ * constructor.
+ */
+ template<typename VecType>
+ double Evaluate(const VecType& a, const VecType& b) const
+ {
+ // The precalculation of gamma saves us a little computation time.
+ return exp(gamma * metric::SquaredEuclideanDistance::Evaluate(a, b));
+ }
+
+ /**
+ * Evaluation of the Gaussian kernel using a double precision argument.
+ *
+ * @param t double value.
 + * @return K(t) using the bandwidth (@f$\mu@f$) specified in the
+ * constructor.
+ */
+ double Evaluate(double t) const
+ {
+ // The precalculation of gamma saves us a little computation time.
+ return exp(gamma * t * t);
+ }
+ /**
+ * Obtain the normalization constant of the Gaussian kernel.
+ *
+ * @param dimension
+ * @return the normalization constant
+ */
+ double Normalizer(size_t dimension)
+ {
+ return pow(sqrt(2.0 * M_PI) * bandwidth, (double) dimension);
+ }
+ /**
+ * Obtain a convolution integral of the Gaussian kernel.
+ *
+ * @param a, first vector
+ * @param b, second vector
+ * @return the convolution integral
+ */
+ template<typename VecType>
+ double ConvolutionIntegral(const VecType& a, const VecType& b)
+ {
+ return Evaluate(sqrt(metric::SquaredEuclideanDistance::Evaluate(a, b) / 2.0)) /
+ (Normalizer(a.n_rows) * pow(2.0, (double) a.n_rows / 2.0));
+ }
+
+
+ //! Get the bandwidth.
+ double Bandwidth() const { return bandwidth; }
+
+ //! Modify the bandwidth. This takes an argument because we must update the
+ //! precalculated constant (gamma).
+ void Bandwidth(const double bandwidth)
+ {
+ this->bandwidth = bandwidth;
+ this->gamma = -0.5 * pow(bandwidth, -2.0);
+ }
+
+ //! Get the precalculated constant.
+ double Gamma() const { return gamma; }
+
+ private:
+ //! Kernel bandwidth.
+ double bandwidth;
+
+ //! Precalculated constant depending on the bandwidth;
+ //! @f$ \gamma = -\frac{1}{2 \mu^2} @f$.
+ double gamma;
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/hyperbolic_tangent_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/hyperbolic_tangent_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/hyperbolic_tangent_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,91 +0,0 @@
-/**
- * @file hyperbolic_tangent_kernel.hpp
- * @author Ajinkya Kale <kaleajinkya at gmail.com>
- *
- * Implementation of the hyperbolic tangent kernel.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_HYPERBOLIC_TANGENT_KERNEL_HPP
-#define __MLPACK_CORE_KERNELS_HYPERBOLIC_TANGENT_KERNEL_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * Hyperbolic tangent kernel. For any two vectors @f$ x @f$, @f$ y @f$ and a
- * given scale @f$ s @f$ and offset @f$ t @f$
- *
- * @f[
- * K(x, y) = \tanh(s <x, y> + t)
- * @f]
- */
-class HyperbolicTangentKernel
-{
- public:
- /**
- * This constructor sets the default scale to 1.0 and offset to 0.0.
- */
- HyperbolicTangentKernel() : scale(1.0), offset(0.0)
- { }
-
- /**
- * Construct the hyperbolic tangent kernel with custom scale factor and
- * offset.
- *
- * @param scale Scaling factor for <x, y>.
- * @param offset Kernel offset.
- */
- HyperbolicTangentKernel(double scale, double offset) :
- scale(scale), offset(offset)
- { }
-
- /**
- * Evaluate the hyperbolic tangent kernel. This evaluation uses Armadillo's
- * dot() function.
- *
- * @tparam VecType Type of vector (should be arma::vec or arma::spvec).
- * @param a First vector.
- * @param b Second vector.
- * @return K(a, b).
- */
- template<typename VecType>
- double Evaluate(const VecType& a, const VecType& b)
- {
- return tanh(scale * arma::dot(a, b) + offset);
- }
-
- //! Get scale factor.
- double Scale() const { return scale; }
- //! Modify scale factor.
- double& Scale() { return scale; }
-
- //! Get offset for the kernel.
- double Offset() const { return offset; }
- //! Modify offset for the kernel.
- double& Offset() { return offset; }
-
- private:
- double scale;
- double offset;
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/hyperbolic_tangent_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/hyperbolic_tangent_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/hyperbolic_tangent_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/hyperbolic_tangent_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,91 @@
+/**
+ * @file hyperbolic_tangent_kernel.hpp
+ * @author Ajinkya Kale <kaleajinkya at gmail.com>
+ *
+ * Implementation of the hyperbolic tangent kernel.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_HYPERBOLIC_TANGENT_KERNEL_HPP
+#define __MLPACK_CORE_KERNELS_HYPERBOLIC_TANGENT_KERNEL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * Hyperbolic tangent kernel. For any two vectors @f$ x @f$, @f$ y @f$ and a
+ * given scale @f$ s @f$ and offset @f$ t @f$
+ *
+ * @f[
+ * K(x, y) = \tanh(s <x, y> + t)
+ * @f]
+ */
+class HyperbolicTangentKernel
+{
+ public:
+ /**
+ * This constructor sets the default scale to 1.0 and offset to 0.0.
+ */
+ HyperbolicTangentKernel() : scale(1.0), offset(0.0)
+ { }
+
+ /**
+ * Construct the hyperbolic tangent kernel with custom scale factor and
+ * offset.
+ *
+ * @param scale Scaling factor for <x, y>.
+ * @param offset Kernel offset.
+ */
+ HyperbolicTangentKernel(double scale, double offset) :
+ scale(scale), offset(offset)
+ { }
+
+ /**
+ * Evaluate the hyperbolic tangent kernel. This evaluation uses Armadillo's
+ * dot() function.
+ *
+ * @tparam VecType Type of vector (should be arma::vec or arma::spvec).
+ * @param a First vector.
+ * @param b Second vector.
+ * @return K(a, b).
+ */
+ template<typename VecType>
+ double Evaluate(const VecType& a, const VecType& b)
+ {
+ return tanh(scale * arma::dot(a, b) + offset);
+ }
+
+ //! Get scale factor.
+ double Scale() const { return scale; }
+ //! Modify scale factor.
+ double& Scale() { return scale; }
+
+ //! Get offset for the kernel.
+ double Offset() const { return offset; }
+ //! Modify offset for the kernel.
+ double& Offset() { return offset; }
+
+ private:
+ double scale;
+ double offset;
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/laplacian_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/laplacian_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/laplacian_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,102 +0,0 @@
-/**
- * @file laplacian_kernel.hpp
- * @author Ajinkya Kale <kaleajinkya at gmail.com>
- *
- * Implementation of the Laplacian kernel (LaplacianKernel).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_LAPLACIAN_KERNEL_HPP
-#define __MLPACK_CORE_KERNELS_LAPLACIAN_KERNEL_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * The standard Laplacian kernel. Given two vectors @f$ x @f$, @f$ y @f$, and a
- * bandwidth @f$ \mu @f$ (set in the constructor),
- *
- * @f[
- * K(x, y) = \exp(-\frac{|| x - y ||}{\mu}).
- * @f]
- *
- * The implementation is all in the header file because it is so simple.
- */
-class LaplacianKernel
-{
- public:
- /**
- * Default constructor; sets bandwidth to 1.0.
- */
- LaplacianKernel() : bandwidth(1.0)
- { }
-
- /**
- * Construct the Laplacian kernel with a custom bandwidth.
- *
 - * @param bandwidth The bandwidth of the kernel (@f$\mu@f$).
- */
- LaplacianKernel(double bandwidth) :
- bandwidth(bandwidth)
- { }
-
- /**
- * Evaluation of the Laplacian kernel. This could be generalized to use any
- * distance metric, not the Euclidean distance, but for now, the Euclidean
- * distance is used.
- *
- * @tparam VecType Type of vector (likely arma::vec or arma::spvec).
- * @param a First vector.
- * @param b Second vector.
 - * @return K(a, b) using the bandwidth (@f$\mu@f$) specified in the
- * constructor.
- */
- template<typename VecType>
- double Evaluate(const VecType& a, const VecType& b) const
- {
- // The precalculation of gamma saves us a little computation time.
- return exp(-metric::EuclideanDistance::Evaluate(a, b) / bandwidth);
- }
-
- /**
- * Evaluation of the Laplacian kernel using a double precision argument.
- *
- * @param t double value.
 - * @return K(t) using the bandwidth (@f$\mu@f$) specified in the
- * constructor.
- */
- double Evaluate(double t) const
- {
- // The precalculation of gamma saves us a little computation time.
- return exp(-t / bandwidth);
- }
-
- //! Get the bandwidth.
- double Bandwidth() const { return bandwidth; }
- //! Modify the bandwidth.
- double& Bandwidth() { return bandwidth; }
-
- private:
- //! Kernel bandwidth.
- double bandwidth;
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/laplacian_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/laplacian_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/laplacian_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/laplacian_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,102 @@
+/**
+ * @file laplacian_kernel.hpp
+ * @author Ajinkya Kale <kaleajinkya at gmail.com>
+ *
+ * Implementation of the Laplacian kernel (LaplacianKernel).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_LAPLACIAN_KERNEL_HPP
+#define __MLPACK_CORE_KERNELS_LAPLACIAN_KERNEL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * The standard Laplacian kernel. Given two vectors @f$ x @f$, @f$ y @f$, and a
+ * bandwidth @f$ \mu @f$ (set in the constructor),
+ *
+ * @f[
+ * K(x, y) = \exp(-\frac{|| x - y ||}{\mu}).
+ * @f]
+ *
+ * The implementation is all in the header file because it is so simple.
+ */
+class LaplacianKernel
+{
+ public:
+ /**
+ * Default constructor; sets bandwidth to 1.0.
+ */
+ LaplacianKernel() : bandwidth(1.0)
+ { }
+
+ /**
+ * Construct the Laplacian kernel with a custom bandwidth.
+ *
+ * @param bandwidth The bandwidth of the kernel (@f$\mu@f$).
+ */
+ LaplacianKernel(double bandwidth) :
+ bandwidth(bandwidth)
+ { }
+
+ /**
+ * Evaluation of the Laplacian kernel. This could be generalized to use any
+ * distance metric, not the Euclidean distance, but for now, the Euclidean
+ * distance is used.
+ *
+ * @tparam VecType Type of vector (likely arma::vec or arma::spvec).
+ * @param a First vector.
+ * @param b Second vector.
+ * @return K(a, b) using the bandwidth (@f$\mu@f$) specified in the
+ * constructor.
+ */
+ template<typename VecType>
+ double Evaluate(const VecType& a, const VecType& b) const
+ {
+ // The precalculation of gamma saves us a little computation time.
+ return exp(-metric::EuclideanDistance::Evaluate(a, b) / bandwidth);
+ }
+
+ /**
+ * Evaluation of the Laplacian kernel using a double precision argument.
+ *
+ * @param t double value.
+ * @return K(t) using the bandwidth (@f$\mu@f$) specified in the
+ * constructor.
+ */
+ double Evaluate(double t) const
+ {
+ // The precalculation of gamma saves us a little computation time.
+ return exp(-t / bandwidth);
+ }
+
+ //! Get the bandwidth.
+ double Bandwidth() const { return bandwidth; }
+ //! Modify the bandwidth.
+ double& Bandwidth() { return bandwidth; }
+
+ private:
+ //! Kernel bandwidth.
+ double bandwidth;
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/linear_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/linear_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/linear_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,70 +0,0 @@
-/**
- * @file linear_kernel.hpp
- * @author Wei Guan
- * @author James Cline
- * @author Ryan Curtin
- *
- * Implementation of the linear kernel (just the standard dot product).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_LINEAR_KERNEL_HPP
-#define __MLPACK_CORE_KERNELS_LINEAR_KERNEL_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * The simple linear kernel (dot product). For any two vectors @f$ x @f$ and
- * @f$ y @f$,
- *
- * @f[
- * K(x, y) = x^T y
- * @f]
- *
- * This kernel has no parameters and therefore the evaluation can be static.
- */
-class LinearKernel
-{
- public:
- /**
- * This constructor does nothing; the linear kernel has no parameters to
- * store.
- */
- LinearKernel() { }
-
- /**
- * Simple evaluation of the dot product. This evaluation uses Armadillo's
- * dot() function.
- *
- * @tparam VecType Type of vector (should be arma::vec or arma::spvec).
- * @param a First vector.
- * @param b Second vector.
- * @return K(a, b).
- */
- template<typename VecType>
- static double Evaluate(const VecType& a, const VecType& b)
- {
- return arma::dot(a, b);
- }
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/linear_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/linear_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/linear_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/linear_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,70 @@
+/**
+ * @file linear_kernel.hpp
+ * @author Wei Guan
+ * @author James Cline
+ * @author Ryan Curtin
+ *
+ * Implementation of the linear kernel (just the standard dot product).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_LINEAR_KERNEL_HPP
+#define __MLPACK_CORE_KERNELS_LINEAR_KERNEL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * The simple linear kernel (dot product). For any two vectors @f$ x @f$ and
+ * @f$ y @f$,
+ *
+ * @f[
+ * K(x, y) = x^T y
+ * @f]
+ *
+ * This kernel has no parameters and therefore the evaluation can be static.
+ */
+class LinearKernel
+{
+ public:
+ /**
+ * This constructor does nothing; the linear kernel has no parameters to
+ * store.
+ */
+ LinearKernel() { }
+
+ /**
+ * Simple evaluation of the dot product. This evaluation uses Armadillo's
+ * dot() function.
+ *
+ * @tparam VecType Type of vector (should be arma::vec or arma::spvec).
+ * @param a First vector.
+ * @param b Second vector.
+ * @return K(a, b).
+ */
+ template<typename VecType>
+ static double Evaluate(const VecType& a, const VecType& b)
+ {
+ return arma::dot(a, b);
+ }
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/polynomial_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/polynomial_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/polynomial_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,88 +0,0 @@
-/**
- * @file polynomial_kernel.hpp
- * @author Ajinkya Kale <kaleajinkya at gmail.com>
- *
- * Implementation of the polynomial kernel (just the standard dot product).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_POLYNOMIAL_KERNEL_HPP
-#define __MLPACK_CORE_KERNELS_POLYNOMIAL_KERNEL_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * The simple polynomial kernel. For any two vectors @f$ x @f$, @f$ y @f$,
- * @f$ degree @f$ and @f$ offset @f$,
- *
- * @f[
- * K(x, y) = (x^T * y + offset) ^ {degree}.
- * @f]
- */
-class PolynomialKernel
-{
- public:
- /**
- * Construct the Polynomial Kernel with the given offset and degree. If the
- * arguments are omitted, the default degree is 2 and the default offset is 0.
- *
- * @param offset Offset of the dot product of the arguments.
- * @param degree Degree of the polynomial.
- */
- PolynomialKernel(const double degree = 2.0, const double offset = 0.0) :
- degree(degree),
- offset(offset)
- { }
-
- /**
- * Simple evaluation of the dot product. This evaluation uses Armadillo's
- * dot() function.
- *
- * @tparam VecType Type of vector (should be arma::vec or arma::spvec).
- * @param a First vector.
- * @param b Second vector.
- * @return K(a, b).
- */
- template<typename VecType>
- double Evaluate(const VecType& a, const VecType& b) const
- {
- return pow((arma::dot(a, b) + offset), degree);
- }
-
- //! Get the degree of the polynomial.
- const double& Degree() const { return degree; }
- //! Modify the degree of the polynomial.
- double& Degree() { return degree; }
-
- //! Get the offset of the dot product of the arguments.
- const double& Offset() const { return offset; }
- //! Modify the offset of the dot product of the arguments.
- double& Offset() { return offset; }
-
- private:
- //! The degree of the polynomial.
- double degree;
- //! The offset of the dot product of the arguments.
- double offset;
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/polynomial_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/polynomial_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/polynomial_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/polynomial_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,88 @@
+/**
+ * @file polynomial_kernel.hpp
+ * @author Ajinkya Kale <kaleajinkya at gmail.com>
+ *
+ * Implementation of the polynomial kernel (just the standard dot product).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_POLYNOMIAL_KERNEL_HPP
+#define __MLPACK_CORE_KERNELS_POLYNOMIAL_KERNEL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * The simple polynomial kernel. For any two vectors @f$ x @f$, @f$ y @f$,
+ * @f$ degree @f$ and @f$ offset @f$,
+ *
+ * @f[
+ * K(x, y) = (x^T * y + offset) ^ {degree}.
+ * @f]
+ */
+class PolynomialKernel
+{
+ public:
+ /**
+ * Construct the Polynomial Kernel with the given offset and degree. If the
+ * arguments are omitted, the default degree is 2 and the default offset is 0.
+ *
+ * @param offset Offset of the dot product of the arguments.
+ * @param degree Degree of the polynomial.
+ */
+ PolynomialKernel(const double degree = 2.0, const double offset = 0.0) :
+ degree(degree),
+ offset(offset)
+ { }
+
+ /**
+ * Simple evaluation of the dot product. This evaluation uses Armadillo's
+ * dot() function.
+ *
+ * @tparam VecType Type of vector (should be arma::vec or arma::spvec).
+ * @param a First vector.
+ * @param b Second vector.
+ * @return K(a, b).
+ */
+ template<typename VecType>
+ double Evaluate(const VecType& a, const VecType& b) const
+ {
+ return pow((arma::dot(a, b) + offset), degree);
+ }
+
+ //! Get the degree of the polynomial.
+ const double& Degree() const { return degree; }
+ //! Modify the degree of the polynomial.
+ double& Degree() { return degree; }
+
+ //! Get the offset of the dot product of the arguments.
+ const double& Offset() const { return offset; }
+ //! Modify the offset of the dot product of the arguments.
+ double& Offset() { return offset; }
+
+ private:
+ //! The degree of the polynomial.
+ double degree;
+ //! The offset of the dot product of the arguments.
+ double offset;
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/pspectrum_string_kernel.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,98 +0,0 @@
-/**
- * @file pspectrum_string_kernel.cpp
- * @author Ryan Curtin
- *
- * Implementation of the p-spectrum string kernel, created for use with FastMKS.
- * Instead of passing a data matrix to FastMKS which stores the kernels, pass a
- * one-dimensional data matrix (data vector) to FastMKS which stores indices of
- * strings; then, the actual strings are given to the PSpectrumStringKernel at
- * construction time, and the kernel knows to map the indices to actual strings.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "pspectrum_string_kernel.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::kernel;
-
-/**
- * Initialize the PSpectrumStringKernel with the given string datasets. For
- * more information on this, see the general class documentation.
- *
- * @param datasets Sets of string data. @param p The length of substrings to
- * search.
- */
-mlpack::kernel::PSpectrumStringKernel::PSpectrumStringKernel(
- const std::vector<std::vector<std::string> >& datasets,
- const size_t p) :
- datasets(datasets),
- p(p)
-{
- // We have to assemble the counts of substrings. This is not a particularly
- // fast operation, unfortunately, but it only needs to be done once.
- Log::Info << "Assembling counts of substrings of length " << p << "."
- << std::endl;
-
- // Resize for number of datasets.
- counts.resize(datasets.size());
-
- for (size_t dataset = 0; dataset < datasets.size(); ++dataset)
- {
- const std::vector<std::string>& set = datasets[dataset];
-
- // Resize for number of strings in dataset.
- counts[dataset].resize(set.size());
-
- // Inspect each string in the dataset.
- for (size_t index = 0; index < set.size(); ++index)
- {
- // Convenience references.
- const std::string& str = set[index];
- std::map<std::string, int>& mapping = counts[dataset][index];
-
- size_t start = 0;
- while ((start + p) <= str.length())
- {
- string sub = str.substr(start, p);
-
- // Convert all characters to lowercase.
- bool invalid = false;
- for (size_t j = 0; j < p; ++j)
- {
- if (!isalnum(sub[j]))
- {
- invalid = true;
- break; // Only consider substrings with alphanumerics.
- }
-
- sub[j] = tolower(sub[j]);
- }
-
- // Increment position in string.
- ++start;
-
- if (!invalid)
- {
- // Add to the map.
- ++mapping[sub];
- }
- }
- }
- }
-
- Log::Info << "Substring extraction complete." << std::endl;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/pspectrum_string_kernel.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,98 @@
+/**
+ * @file pspectrum_string_kernel.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the p-spectrum string kernel, created for use with FastMKS.
+ * Instead of passing a data matrix to FastMKS which stores the kernels, pass a
+ * one-dimensional data matrix (data vector) to FastMKS which stores indices of
+ * strings; then, the actual strings are given to the PSpectrumStringKernel at
+ * construction time, and the kernel knows to map the indices to actual strings.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "pspectrum_string_kernel.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::kernel;
+
+/**
+ * Initialize the PSpectrumStringKernel with the given string datasets. For
+ * more information on this, see the general class documentation.
+ *
+ * @param datasets Sets of string data. @param p The length of substrings to
+ * search.
+ */
+mlpack::kernel::PSpectrumStringKernel::PSpectrumStringKernel(
+ const std::vector<std::vector<std::string> >& datasets,
+ const size_t p) :
+ datasets(datasets),
+ p(p)
+{
+ // We have to assemble the counts of substrings. This is not a particularly
+ // fast operation, unfortunately, but it only needs to be done once.
+ Log::Info << "Assembling counts of substrings of length " << p << "."
+ << std::endl;
+
+ // Resize for number of datasets.
+ counts.resize(datasets.size());
+
+ for (size_t dataset = 0; dataset < datasets.size(); ++dataset)
+ {
+ const std::vector<std::string>& set = datasets[dataset];
+
+ // Resize for number of strings in dataset.
+ counts[dataset].resize(set.size());
+
+ // Inspect each string in the dataset.
+ for (size_t index = 0; index < set.size(); ++index)
+ {
+ // Convenience references.
+ const std::string& str = set[index];
+ std::map<std::string, int>& mapping = counts[dataset][index];
+
+ size_t start = 0;
+ while ((start + p) <= str.length())
+ {
+ string sub = str.substr(start, p);
+
+ // Convert all characters to lowercase.
+ bool invalid = false;
+ for (size_t j = 0; j < p; ++j)
+ {
+ if (!isalnum(sub[j]))
+ {
+ invalid = true;
+ break; // Only consider substrings with alphanumerics.
+ }
+
+ sub[j] = tolower(sub[j]);
+ }
+
+ // Increment position in string.
+ ++start;
+
+ if (!invalid)
+ {
+ // Add to the map.
+ ++mapping[sub];
+ }
+ }
+ }
+ }
+
+ Log::Info << "Substring extraction complete." << std::endl;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/pspectrum_string_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,131 +0,0 @@
-/**
- * @file pspectrum_string_kernel.hpp
- * @author Ryan Curtin
- *
- * Implementation of the p-spectrum string kernel, created for use with FastMKS.
- * Instead of passing a data matrix to FastMKS which stores the kernels, pass a
- * one-dimensional data matrix (data vector) to FastMKS which stores indices of
- * strings; then, the actual strings are given to the PSpectrumStringKernel at
- * construction time, and the kernel knows to map the indices to actual strings.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_PSPECTRUM_STRING_KERNEL_HPP
-#define __MLPACK_CORE_KERNELS_PSPECTRUM_STRING_KERNEL_HPP
-
-#include <map>
-#include <string>
-#include <vector>
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * The p-spectrum string kernel. Given a length p, the p-spectrum kernel finds
- * the contiguous subsequence match count between two strings. The kernel will
- * take every possible substring of length p of one string and count how many
- * times it appears in the other string.
- *
- * The string kernel, when created, must be passed a reference to a series of
- * string datasets (std::vector<std::vector<std::string> >&). This is because
- * MLPACK only supports datasets which are Armadillo matrices -- and a dataset
- * of variable-length strings cannot be easily cast into an Armadillo matrix.
- *
- * Therefore, once the PSpectrumStringKernel is created with a reference to the
- * string datasets, a "fake" Armadillo data matrix must be created, which simply
- * holds indices to the strings they represent. This "fake" matrix has two rows
- * and n columns (where n is the number of strings in the dataset). The first
- * row holds the index of the dataset (remember, the kernel can have multiple
- * datasets), and the second row holds the index of the string. A fake matrix
- * containing only strings from dataset 0 might look like this:
- *
- * [[0 0 0 0 0 0 0 0 0]
- * [0 1 2 3 4 5 6 7 8]]
- *
- * This fake matrix is then given to the machine learning method, which will
- * eventually call PSpectrumStringKernel::Evaluate(a, b), where a and b are two
- * columns of the fake matrix. The string kernel will then map these fake
- * columns back to the strings they represent, and then correctly evaluate the
- * kernel.
- *
- * Unfortunately, not every machine learning method will work with this kernel.
- * Only machine learning methods which do not ever operate on the explicit
- * representation of points can use this kernel. So, for instance, one cannot
- * build a kd-tree on strings, because the BinarySpaceTree<> class will split
- * the data according to the fake data matrix -- resulting in a meaningless
- * tree. This kernel was originally written for the FastMKS method; so, at the
- * very least, it will work with that.
- */
-class PSpectrumStringKernel
-{
- public:
- /**
- * Initialize the PSpectrumStringKernel with the given string datasets. For
- * more information on this, see the general class documentation.
- *
- * @param datasets Sets of string data.
- * @param p The length of substrings to search.
- */
- PSpectrumStringKernel(const std::vector<std::vector<std::string> >& datasets,
- const size_t p);
-
- /**
- * Evaluate the kernel for the string indices given. As mentioned in the
- * class documentation, a and b should be 2-element vectors, where the first
- * element contains the index of the dataset and the second element contains
- * the index of the string. Therefore, if [2 3] is passed for a, the string
- * used will be datasets[2][3] (datasets is of type
- * std::vector<std::vector<std::string> >&).
- *
- * @param a Index of string and dataset for first string.
- * @param b Index of string and dataset for second string.
- */
- template<typename VecType>
- double Evaluate(const VecType& a, const VecType& b) const;
-
- //! Access the lists of substrings.
- const std::vector<std::vector<std::map<std::string, int> > >& Counts() const
- { return counts; }
- //! Modify the lists of substrings.
- std::vector<std::vector<std::map<std::string, int> > >& Counts()
- { return counts; }
-
- //! Access the value of p.
- size_t P() const { return p; }
- //! Modify the value of p.
- size_t& P() { return p; }
-
- private:
- //! The datasets.
- const std::vector<std::vector<std::string> >& datasets;
-
- //! Mappings of the datasets to counts of substrings. Such a huge structure
- //! is not wonderful...
- std::vector<std::vector<std::map<std::string, int> > > counts;
-
- //! The value of p to use in calculation.
- size_t p;
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-// Include implementation of templated Evaluate().
-#include "pspectrum_string_kernel_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/pspectrum_string_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,131 @@
+/**
+ * @file pspectrum_string_kernel.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the p-spectrum string kernel, created for use with FastMKS.
+ * Instead of passing a data matrix to FastMKS which stores the kernels, pass a
+ * one-dimensional data matrix (data vector) to FastMKS which stores indices of
+ * strings; then, the actual strings are given to the PSpectrumStringKernel at
+ * construction time, and the kernel knows to map the indices to actual strings.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_PSPECTRUM_STRING_KERNEL_HPP
+#define __MLPACK_CORE_KERNELS_PSPECTRUM_STRING_KERNEL_HPP
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * The p-spectrum string kernel. Given a length p, the p-spectrum kernel finds
+ * the contiguous subsequence match count between two strings. The kernel will
+ * take every possible substring of length p of one string and count how many
+ * times it appears in the other string.
+ *
+ * The string kernel, when created, must be passed a reference to a series of
+ * string datasets (std::vector<std::vector<std::string> >&). This is because
+ * MLPACK only supports datasets which are Armadillo matrices -- and a dataset
+ * of variable-length strings cannot be easily cast into an Armadillo matrix.
+ *
+ * Therefore, once the PSpectrumStringKernel is created with a reference to the
+ * string datasets, a "fake" Armadillo data matrix must be created, which simply
+ * holds indices to the strings they represent. This "fake" matrix has two rows
+ * and n columns (where n is the number of strings in the dataset). The first
+ * row holds the index of the dataset (remember, the kernel can have multiple
+ * datasets), and the second row holds the index of the string. A fake matrix
+ * containing only strings from dataset 0 might look like this:
+ *
+ * [[0 0 0 0 0 0 0 0 0]
+ * [0 1 2 3 4 5 6 7 8]]
+ *
+ * This fake matrix is then given to the machine learning method, which will
+ * eventually call PSpectrumStringKernel::Evaluate(a, b), where a and b are two
+ * columns of the fake matrix. The string kernel will then map these fake
+ * columns back to the strings they represent, and then correctly evaluate the
+ * kernel.
+ *
+ * Unfortunately, not every machine learning method will work with this kernel.
+ * Only machine learning methods which do not ever operate on the explicit
+ * representation of points can use this kernel. So, for instance, one cannot
+ * build a kd-tree on strings, because the BinarySpaceTree<> class will split
+ * the data according to the fake data matrix -- resulting in a meaningless
+ * tree. This kernel was originally written for the FastMKS method; so, at the
+ * very least, it will work with that.
+ */
+class PSpectrumStringKernel
+{
+ public:
+ /**
+ * Initialize the PSpectrumStringKernel with the given string datasets. For
+ * more information on this, see the general class documentation.
+ *
+ * @param datasets Sets of string data.
+ * @param p The length of substrings to search.
+ */
+ PSpectrumStringKernel(const std::vector<std::vector<std::string> >& datasets,
+ const size_t p);
+
+ /**
+ * Evaluate the kernel for the string indices given. As mentioned in the
+ * class documentation, a and b should be 2-element vectors, where the first
+ * element contains the index of the dataset and the second element contains
+ * the index of the string. Therefore, if [2 3] is passed for a, the string
+ * used will be datasets[2][3] (datasets is of type
+ * std::vector<std::vector<std::string> >&).
+ *
+ * @param a Index of string and dataset for first string.
+ * @param b Index of string and dataset for second string.
+ */
+ template<typename VecType>
+ double Evaluate(const VecType& a, const VecType& b) const;
+
+ //! Access the lists of substrings.
+ const std::vector<std::vector<std::map<std::string, int> > >& Counts() const
+ { return counts; }
+ //! Modify the lists of substrings.
+ std::vector<std::vector<std::map<std::string, int> > >& Counts()
+ { return counts; }
+
+ //! Access the value of p.
+ size_t P() const { return p; }
+ //! Modify the value of p.
+ size_t& P() { return p; }
+
+ private:
+ //! The datasets.
+ const std::vector<std::vector<std::string> >& datasets;
+
+ //! Mappings of the datasets to counts of substrings. Such a huge structure
+ //! is not wonderful...
+ std::vector<std::vector<std::map<std::string, int> > > counts;
+
+ //! The value of p to use in calculation.
+ size_t p;
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+// Include implementation of templated Evaluate().
+#include "pspectrum_string_kernel_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/pspectrum_string_kernel_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,91 +0,0 @@
-/**
- * @file pspectrum_string_kernel_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of the p-spectrum string kernel, created for use with FastMKS.
- * Instead of passing a data matrix to FastMKS which stores the kernels, pass a
- * one-dimensional data matrix (data vector) to FastMKS which stores indices of
- * strings; then, the actual strings are given to the PSpectrumStringKernel at
- * construction time, and the kernel knows to map the indices to actual strings.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_PSPECTRUM_STRING_KERNEL_IMPL_HPP
-#define __MLPACK_CORE_KERNELS_PSPECTRUM_STRING_KERNEL_IMPL_HPP
-
-// In case it has not been included yet.
-#include "pspectrum_string_kernel.hpp"
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * Evaluate the kernel for the string indices given. As mentioned in the class
- * documentation, a and b should be 2-element vectors, where the first element
- * contains the index of the dataset and the second element contains the index
- * of the string. Therefore, if [2 3] is passed for a, the string used will be
- * datasets[2][3] (datasets is of type std::vector<std::vector<std::string> >&).
- *
- * @param a Index of string and dataset for first string.
- * @param b Index of string and dataset for second string.
- */
-template<typename VecType>
-double PSpectrumStringKernel::Evaluate(const VecType& a,
- const VecType& b) const
-{
- // Get the map of substrings for the two strings we are interested in.
- const std::map<std::string, int>& aMap = counts[a[0]][a[1]];
- const std::map<std::string, int>& bMap = counts[b[0]][b[1]];
-
- double eval = 0;
-
- // Loop through the two maps (which, when iterated through, are sorted
- // alphabetically).
- std::map<std::string, int>::const_iterator aIt = aMap.begin();
- std::map<std::string, int>::const_iterator bIt = bMap.begin();
-
- while ((aIt != aMap.end()) && (bIt != bMap.end()))
- {
- // Compare alphabetically (this is how std::map is ordered).
- int result = (*aIt).first.compare((*bIt).first);
-
- if (result == 0) // The same substring.
- {
- eval += ((*aIt).second * (*bIt).second);
-
- // Now increment both.
- ++aIt;
- ++bIt;
- }
- else if (result > 0)
- {
- // aIt is "ahead" of bIt (alphabetically); so increment bIt to "catch up".
- ++bIt;
- }
- else
- {
- // bIt is "ahead" of aIt (alphabetically); so increment aIt to "catch up".
- ++aIt;
- }
- }
-
- return eval;
-}
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/pspectrum_string_kernel_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/pspectrum_string_kernel_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,91 @@
+/**
+ * @file pspectrum_string_kernel_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the p-spectrum string kernel, created for use with FastMKS.
+ * Instead of passing a data matrix to FastMKS which stores the kernels, pass a
+ * one-dimensional data matrix (data vector) to FastMKS which stores indices of
+ * strings; then, the actual strings are given to the PSpectrumStringKernel at
+ * construction time, and the kernel knows to map the indices to actual strings.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_PSPECTRUM_STRING_KERNEL_IMPL_HPP
+#define __MLPACK_CORE_KERNELS_PSPECTRUM_STRING_KERNEL_IMPL_HPP
+
+// In case it has not been included yet.
+#include "pspectrum_string_kernel.hpp"
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * Evaluate the kernel for the string indices given. As mentioned in the class
+ * documentation, a and b should be 2-element vectors, where the first element
+ * contains the index of the dataset and the second element contains the index
+ * of the string. Therefore, if [2 3] is passed for a, the string used will be
+ * datasets[2][3] (datasets is of type std::vector<std::vector<std::string> >&).
+ *
+ * @param a Index of string and dataset for first string.
+ * @param b Index of string and dataset for second string.
+ */
+template<typename VecType>
+double PSpectrumStringKernel::Evaluate(const VecType& a,
+ const VecType& b) const
+{
+ // Get the map of substrings for the two strings we are interested in.
+ const std::map<std::string, int>& aMap = counts[a[0]][a[1]];
+ const std::map<std::string, int>& bMap = counts[b[0]][b[1]];
+
+ double eval = 0;
+
+ // Loop through the two maps (which, when iterated through, are sorted
+ // alphabetically).
+ std::map<std::string, int>::const_iterator aIt = aMap.begin();
+ std::map<std::string, int>::const_iterator bIt = bMap.begin();
+
+ while ((aIt != aMap.end()) && (bIt != bMap.end()))
+ {
+ // Compare alphabetically (this is how std::map is ordered).
+ int result = (*aIt).first.compare((*bIt).first);
+
+ if (result == 0) // The same substring.
+ {
+ eval += ((*aIt).second * (*bIt).second);
+
+ // Now increment both.
+ ++aIt;
+ ++bIt;
+ }
+ else if (result > 0)
+ {
+ // aIt is "ahead" of bIt (alphabetically); so increment bIt to "catch up".
+ ++bIt;
+ }
+ else
+ {
+ // bIt is "ahead" of aIt (alphabetically); so increment aIt to "catch up".
+ ++aIt;
+ }
+ }
+
+ return eval;
+}
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/spherical_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/spherical_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/spherical_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,105 +0,0 @@
-/**
- * @file spherical_kernel.hpp
- * @author Neil Slagle
- *
- * This is an example kernel. If you are making your own kernel, follow the
- * outline specified in this file.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_SPHERICAL_KERNEL_H
-#define __MLPACK_CORE_KERNELS_SPHERICAL_KERNEL_H
-
-#include <boost/math/special_functions/gamma.hpp>
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-class SphericalKernel
-{
- public:
- SphericalKernel() :
- bandwidth(1.0),
- bandwidthSquared(1.0) {}
- SphericalKernel(double b) :
- bandwidth(b),
- bandwidthSquared(b*b) {}
-
- template<typename VecType>
- double Evaluate(const VecType& a, const VecType& b)
- {
- return
- (metric::SquaredEuclideanDistance::Evaluate(a, b) <= bandwidthSquared) ?
- 1.0 : 0.0;
- }
- /**
- * Obtains the convolution integral [integral K(||x-a||)K(||b-x||)dx]
- * for the two vectors. In this case, because
- * our simple example kernel has no internal parameters, we can declare the
- * function static. For a more complex example which cannot be declared
- * static, see the GaussianKernel, which stores an internal parameter.
- *
- * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
- * @param a First vector.
- * @param b Second vector.
- * @return the convolution integral value.
- */
- template<typename VecType>
- double ConvolutionIntegral(const VecType& a, const VecType& b)
- {
- double distance = sqrt(metric::SquaredEuclideanDistance::Evaluate(a, b));
- if (distance >= 2.0 * bandwidth)
- {
- return 0.0;
- }
- double volumeSquared = pow(Normalizer(a.n_rows), 2.0);
-
- switch(a.n_rows)
- {
- case 1:
- return 1.0 / volumeSquared * (2.0 * bandwidth - distance);
- break;
- case 2:
- return 1.0 / volumeSquared *
- (2.0 * bandwidth * bandwidth * acos(distance/(2.0 * bandwidth)) -
- distance / 4.0 * sqrt(4.0*bandwidth*bandwidth-distance*distance));
- break;
- default:
- Log::Fatal << "The spherical kernel does not support convolution\
- integrals above dimension two, yet..." << std::endl;
- return -1.0;
- break;
- }
- }
- double Normalizer(size_t dimension)
- {
- return pow(bandwidth, (double) dimension) * pow(M_PI, dimension / 2.0) /
- boost::math::tgamma(dimension / 2.0 + 1.0);
- }
- double Evaluate(double t)
- {
- return (t <= bandwidth) ? 1.0 : 0.0;
- }
- private:
- double bandwidth;
- double bandwidthSquared;
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/spherical_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/spherical_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/spherical_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/spherical_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,105 @@
+/**
+ * @file spherical_kernel.hpp
+ * @author Neil Slagle
+ *
+ * This is an example kernel. If you are making your own kernel, follow the
+ * outline specified in this file.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_SPHERICAL_KERNEL_H
+#define __MLPACK_CORE_KERNELS_SPHERICAL_KERNEL_H
+
+#include <boost/math/special_functions/gamma.hpp>
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+class SphericalKernel
+{
+ public:
+ SphericalKernel() :
+ bandwidth(1.0),
+ bandwidthSquared(1.0) {}
+ SphericalKernel(double b) :
+ bandwidth(b),
+ bandwidthSquared(b*b) {}
+
+ template<typename VecType>
+ double Evaluate(const VecType& a, const VecType& b)
+ {
+ return
+ (metric::SquaredEuclideanDistance::Evaluate(a, b) <= bandwidthSquared) ?
+ 1.0 : 0.0;
+ }
+ /**
+ * Obtains the convolution integral [integral K(||x-a||)K(||b-x||)dx]
+ * for the two vectors.  Note that this kernel
+ * stores internal parameters (the bandwidth), so this function cannot be
+ * declared static; the same is true of the GaussianKernel, which also
+ * stores an internal parameter.
+ *
+ * @tparam VecType Type of vector (arma::vec, arma::spvec should be expected).
+ * @param a First vector.
+ * @param b Second vector.
+ * @return the convolution integral value.
+ */
+ template<typename VecType>
+ double ConvolutionIntegral(const VecType& a, const VecType& b)
+ {
+ double distance = sqrt(metric::SquaredEuclideanDistance::Evaluate(a, b));
+ if (distance >= 2.0 * bandwidth)
+ {
+ return 0.0;
+ }
+ double volumeSquared = pow(Normalizer(a.n_rows), 2.0);
+
+ switch(a.n_rows)
+ {
+ case 1:
+ return 1.0 / volumeSquared * (2.0 * bandwidth - distance);
+ break;
+ case 2:
+ return 1.0 / volumeSquared *
+ (2.0 * bandwidth * bandwidth * acos(distance/(2.0 * bandwidth)) -
+ distance / 4.0 * sqrt(4.0*bandwidth*bandwidth-distance*distance));
+ break;
+ default:
+ Log::Fatal << "The spherical kernel does not support convolution\
+ integrals above dimension two, yet..." << std::endl;
+ return -1.0;
+ break;
+ }
+ }
+ double Normalizer(size_t dimension)
+ {
+ return pow(bandwidth, (double) dimension) * pow(M_PI, dimension / 2.0) /
+ boost::math::tgamma(dimension / 2.0 + 1.0);
+ }
+ double Evaluate(double t)
+ {
+ return (t <= bandwidth) ? 1.0 : 0.0;
+ }
+ private:
+ double bandwidth;
+ double bandwidthSquared;
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/triangular_kernel.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/triangular_kernel.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/triangular_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,76 +0,0 @@
-/**
- * @file triangular_kernel.hpp
- * @author Ryan Curtin
- *
- * Definition and implementation of the trivially simple triangular kernel.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_KERNELS_TRIANGULAR_KERNEL_HPP
-#define __MLPACK_CORE_KERNELS_TRIANGULAR_KERNEL_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-namespace mlpack {
-namespace kernel {
-
-/**
- * The trivially simple triangular kernel, defined by
- *
- * @f[
- * K(x, y) = \max \{ 0, 1 - \frac{|| x - y ||_2}{b} \}
- * @f]
- *
- * where \f$ b \f$ is the bandwidth of the kernel.
- */
-class TriangularKernel
-{
- public:
- /**
- * Initialize the triangular kernel with the given bandwidth (default 1.0).
- *
- * @param bandwidth Bandwidth of the triangular kernel.
- */
- TriangularKernel(const double bandwidth = 1.0) : bandwidth(bandwidth) { }
-
- /**
- * Evaluate the triangular kernel for the two given vectors.
- *
- * @param a First vector.
- * @param b Second vector.
- */
- template<typename Vec1Type, typename Vec2Type>
- double Evaluate(const Vec1Type& a, const Vec2Type& b)
- {
- return std::max(0.0, (1 - metric::EuclideanDistance::Evaluate(a, b) /
- bandwidth));
- }
-
- //! Get the bandwidth of the kernel.
- double Bandwidth() const { return bandwidth; }
- //! Modify the bandwidth of the kernel.
- double& Bandwidth() { return bandwidth; }
-
- private:
- //! The bandwidth of the kernel.
- double bandwidth;
-};
-
-}; // namespace kernel
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/triangular_kernel.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/kernels/triangular_kernel.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/triangular_kernel.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/kernels/triangular_kernel.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,76 @@
+/**
+ * @file triangular_kernel.hpp
+ * @author Ryan Curtin
+ *
+ * Definition and implementation of the trivially simple triangular kernel.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_KERNELS_TRIANGULAR_KERNEL_HPP
+#define __MLPACK_CORE_KERNELS_TRIANGULAR_KERNEL_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace kernel {
+
+/**
+ * The trivially simple triangular kernel, defined by
+ *
+ * @f[
+ * K(x, y) = \max \{ 0, 1 - \frac{|| x - y ||_2}{b} \}
+ * @f]
+ *
+ * where \f$ b \f$ is the bandwidth of the kernel.
+ */
+class TriangularKernel
+{
+ public:
+ /**
+ * Initialize the triangular kernel with the given bandwidth (default 1.0).
+ *
+ * @param bandwidth Bandwidth of the triangular kernel.
+ */
+ TriangularKernel(const double bandwidth = 1.0) : bandwidth(bandwidth) { }
+
+ /**
+ * Evaluate the triangular kernel for the two given vectors.
+ *
+ * @param a First vector.
+ * @param b Second vector.
+ */
+ template<typename Vec1Type, typename Vec2Type>
+ double Evaluate(const Vec1Type& a, const Vec2Type& b)
+ {
+ return std::max(0.0, (1 - metric::EuclideanDistance::Evaluate(a, b) /
+ bandwidth));
+ }
+
+ //! Get the bandwidth of the kernel.
+ double Bandwidth() const { return bandwidth; }
+ //! Modify the bandwidth of the kernel.
+ double& Bandwidth() { return bandwidth; }
+
+ private:
+ //! The bandwidth of the kernel.
+ double bandwidth;
+};
+
+}; // namespace kernel
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/clamp.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/math/clamp.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/clamp.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,77 +0,0 @@
-/**
- * @file clamp.hpp
- *
- * Miscellaneous math clamping routines.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_MATH_CLAMP_HPP
-#define __MLPACK_CORE_MATH_CLAMP_HPP
-
-#include <stdlib.h>
-#include <math.h>
-#include <float.h>
-
-namespace mlpack {
-namespace math /** Miscellaneous math routines. */ {
-
-/**
- * Forces a number to be non-negative, turning negative numbers into zero.
- * Avoids branching costs (this is a measurable improvement).
- *
- * @param d Double to clamp.
- * @return 0 if d < 0, d otherwise.
- */
-inline double ClampNonNegative(const double d)
-{
- return (d + fabs(d)) / 2;
-}
-
-/**
- * Forces a number to be non-positive, turning positive numbers into zero.
- * Avoids branching costs (this is a measurable improvement).
- *
- * @param d Double to clamp.
- * @param 0 if d > 0, d otherwise.
- */
-inline double ClampNonPositive(const double d)
-{
- return (d - fabs(d)) / 2;
-}
-
-/**
- * Clamp a number between a particular range.
- *
- * @param value The number to clamp.
- * @param rangeMin The first of the range.
- * @param rangeMax The last of the range.
- * @return max(rangeMin, min(rangeMax, d)).
- */
-inline double ClampRange(double value,
- const double rangeMin,
- const double rangeMax)
-{
- value -= rangeMax;
- value = ClampNonPositive(value) + rangeMax;
- value -= rangeMin;
- value = ClampNonNegative(value) + rangeMin;
- return value;
-}
-
-}; // namespace math
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_MATH_CLAMP_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/clamp.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/math/clamp.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/clamp.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/clamp.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,77 @@
+/**
+ * @file clamp.hpp
+ *
+ * Miscellaneous math clamping routines.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_MATH_CLAMP_HPP
+#define __MLPACK_CORE_MATH_CLAMP_HPP
+
+#include <stdlib.h>
+#include <math.h>
+#include <float.h>
+
+namespace mlpack {
+namespace math /** Miscellaneous math routines. */ {
+
+/**
+ * Forces a number to be non-negative, turning negative numbers into zero.
+ * Avoids branching costs (this is a measurable improvement).
+ *
+ * @param d Double to clamp.
+ * @return 0 if d < 0, d otherwise.
+ */
+inline double ClampNonNegative(const double d)
+{
+ return (d + fabs(d)) / 2;
+}
+
+/**
+ * Forces a number to be non-positive, turning positive numbers into zero.
+ * Avoids branching costs (this is a measurable improvement).
+ *
+ * @param d Double to clamp.
+ * @return 0 if d > 0, d otherwise.
+ */
+inline double ClampNonPositive(const double d)
+{
+ return (d - fabs(d)) / 2;
+}
+
+/**
+ * Clamp a number between a particular range.
+ *
+ * @param value The number to clamp.
+ * @param rangeMin The first of the range.
+ * @param rangeMax The last of the range.
+ * @return max(rangeMin, min(rangeMax, value)).
+ */
+inline double ClampRange(double value,
+ const double rangeMin,
+ const double rangeMax)
+{
+ value -= rangeMax;
+ value = ClampNonPositive(value) + rangeMax;
+ value -= rangeMin;
+ value = ClampNonNegative(value) + rangeMin;
+ return value;
+}
+
+}; // namespace math
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_MATH_CLAMP_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/math/lin_alg.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,229 +0,0 @@
-/**
- * @file lin_alg.cpp
- * @author Nishant Mehta
- *
- * Linear algebra utilities.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "lin_alg.hpp"
-
-#define max_rand_i 100000
-
-using namespace mlpack;
-using namespace math;
-
-/**
- * Auxiliary function to raise vector elements to a specific power. The sign
- * is ignored in the power operation and then re-added. Useful for
- * eigenvalues.
- */
-void mlpack::math::VectorPower(arma::vec& vec, double power)
-{
- for (size_t i = 0; i < vec.n_elem; i++)
- {
- if (std::abs(vec(i)) > 1e-12)
- vec(i) = (vec(i) > 0) ? std::pow(vec(i), (double) power) :
- -std::pow(-vec(i), (double) power);
- else
- vec(i) = 0;
- }
-}
-
-/**
- * Creates a centered matrix, where centering is done by subtracting
- * the sum over the columns (a column vector) from each column of the matrix.
- *
- * @param x Input matrix
- * @param xCentered Matrix to write centered output into
- */
-void mlpack::math::Center(const arma::mat& x, arma::mat& xCentered)
-{
- // Sum matrix along dimension 0 (that is, sum elements in each row).
- arma::vec rowVectorSum = arma::sum(x, 1);
- rowVectorSum /= x.n_cols; // scale
-
- xCentered.set_size(x.n_rows, x.n_cols);
- for (size_t i = 0; i < x.n_rows; i++)
- xCentered.row(i) = x.row(i) - rowVectorSum(i);
-}
-
-/**
- * Whitens a matrix using the singular value decomposition of the covariance
- * matrix. Whitening means the covariance matrix of the result is the identity
- * matrix.
- */
-void mlpack::math::WhitenUsingSVD(const arma::mat& x,
- arma::mat& xWhitened,
- arma::mat& whiteningMatrix)
-{
- arma::mat covX, u, v, invSMatrix, temp1;
- arma::vec sVector;
-
- covX = ccov(x);
-
- svd(u, sVector, v, covX);
-
- size_t d = sVector.n_elem;
- invSMatrix.zeros(d, d);
- invSMatrix.diag() = 1 / sqrt(sVector);
-
- whiteningMatrix = v * invSMatrix * trans(u);
-
- xWhitened = whiteningMatrix * x;
-}
-
-/**
- * Whitens a matrix using the eigendecomposition of the covariance matrix.
- * Whitening means the covariance matrix of the result is the identity matrix.
- */
-void mlpack::math::WhitenUsingEig(const arma::mat& x,
- arma::mat& xWhitened,
- arma::mat& whiteningMatrix)
-{
- arma::mat diag, eigenvectors;
- arma::vec eigenvalues;
-
- // Get eigenvectors of covariance of input matrix.
- eig_sym(eigenvalues, eigenvectors, ccov(x));
-
- // Generate diagonal matrix using 1 / sqrt(eigenvalues) for each value.
- VectorPower(eigenvalues, -0.5);
- diag.zeros(eigenvalues.n_elem, eigenvalues.n_elem);
- diag.diag() = eigenvalues;
-
- // Our whitening matrix is diag(1 / sqrt(eigenvectors)) * eigenvalues.
- whiteningMatrix = diag * trans(eigenvectors);
-
- // Now apply the whitening matrix.
- xWhitened = whiteningMatrix * x;
-}
-
-/**
- * Overwrites a dimension-N vector to a random vector on the unit sphere in R^N.
- */
-void mlpack::math::RandVector(arma::vec& v)
-{
- v.zeros();
-
- for (size_t i = 0; i + 1 < v.n_elem; i += 2)
- {
- double a = Random();
- double b = Random();
- double first_term = sqrt(-2 * log(a));
- double second_term = 2 * M_PI * b;
- v[i] = first_term * cos(second_term);
- v[i + 1] = first_term * sin(second_term);
- }
-
- if ((v.n_elem % 2) == 1)
- {
- v[v.n_elem - 1] = sqrt(-2 * log(math::Random())) * cos(2 * M_PI *
- math::Random());
- }
-
- v /= sqrt(dot(v, v));
-}
-
-/**
- * Orthogonalize x and return the result in W, using eigendecomposition.
- * We will be using the formula \f$ W = x (x^T x)^{-0.5} \f$.
- */
-void mlpack::math::Orthogonalize(const arma::mat& x, arma::mat& W)
-{
- // For a matrix A, A^N = V * D^N * V', where VDV' is the
- // eigendecomposition of the matrix A.
- arma::mat eigenvalues, eigenvectors;
- arma::vec egval;
- eig_sym(egval, eigenvectors, ccov(x));
- VectorPower(egval, -0.5);
-
- eigenvalues.zeros(egval.n_elem, egval.n_elem);
- eigenvalues.diag() = egval;
-
- arma::mat at = (eigenvectors * eigenvalues * trans(eigenvectors));
-
- W = at * x;
-}
-
-/**
- * Orthogonalize x in-place. This could be sped up by a custom
- * implementation.
- */
-void mlpack::math::Orthogonalize(arma::mat& x)
-{
- Orthogonalize(x, x);
-}
-
-/**
- * Remove a certain set of rows in a matrix while copying to a second matrix.
- *
- * @param input Input matrix to copy.
- * @param rowsToRemove Vector containing indices of rows to be removed.
- * @param output Matrix to copy non-removed rows into.
- */
-void mlpack::math::RemoveRows(const arma::mat& input,
- const std::vector<size_t>& rowsToRemove,
- arma::mat& output)
-{
- const size_t nRemove = rowsToRemove.size();
- const size_t nKeep = input.n_rows - nRemove;
-
- if (nRemove == 0)
- {
- output = input; // Copy everything.
- }
- else
- {
- output.set_size(nKeep, input.n_cols);
-
- size_t curRow = 0;
- size_t removeInd = 0;
- // First, check 0 to first row to remove.
- if (rowsToRemove[0] > 0)
- {
- // Note that this implies that n_rows > 1.
- output.rows(0, rowsToRemove[0] - 1) = input.rows(0, rowsToRemove[0] - 1);
- curRow += rowsToRemove[0];
- }
-
- // Now, check i'th row to remove to (i + 1)'th row to remove, until i is the
- // penultimate row.
- while (removeInd < nRemove - 1)
- {
- const size_t height = rowsToRemove[removeInd + 1] -
- rowsToRemove[removeInd] - 1;
-
- if (height > 0)
- {
- output.rows(curRow, curRow + height - 1) =
- input.rows(rowsToRemove[removeInd] + 1,
- rowsToRemove[removeInd + 1] - 1);
- curRow += height;
- }
-
- removeInd++;
- }
-
- // Now that i is the last row to remove, check last row to remove to last
- // row.
- if (rowsToRemove[removeInd] < input.n_rows - 1)
- {
- output.rows(curRow, nKeep - 1) = input.rows(rowsToRemove[removeInd] + 1,
- input.n_rows - 1);
- }
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/math/lin_alg.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,229 @@
+/**
+ * @file lin_alg.cpp
+ * @author Nishant Mehta
+ *
+ * Linear algebra utilities.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "lin_alg.hpp"
+
+#define max_rand_i 100000
+
+using namespace mlpack;
+using namespace math;
+
+/**
+ * Auxiliary function to raise vector elements to a specific power. The sign
+ * is ignored in the power operation and then re-added. Useful for
+ * eigenvalues.
+ */
+void mlpack::math::VectorPower(arma::vec& vec, double power)
+{
+ for (size_t i = 0; i < vec.n_elem; i++)
+ {
+ if (std::abs(vec(i)) > 1e-12)
+ vec(i) = (vec(i) > 0) ? std::pow(vec(i), (double) power) :
+ -std::pow(-vec(i), (double) power);
+ else
+ vec(i) = 0;
+ }
+}
+
+/**
+ * Creates a centered matrix, where centering is done by subtracting
+ * the mean over the columns (a column vector of row means) from each column.
+ *
+ * @param x Input matrix
+ * @param xCentered Matrix to write centered output into
+ */
+void mlpack::math::Center(const arma::mat& x, arma::mat& xCentered)
+{
+ // Sum matrix along dimension 0 (that is, sum elements in each row).
+ arma::vec rowVectorSum = arma::sum(x, 1);
+ rowVectorSum /= x.n_cols; // scale
+
+ xCentered.set_size(x.n_rows, x.n_cols);
+ for (size_t i = 0; i < x.n_rows; i++)
+ xCentered.row(i) = x.row(i) - rowVectorSum(i);
+}
+
+/**
+ * Whitens a matrix using the singular value decomposition of the covariance
+ * matrix. Whitening means the covariance matrix of the result is the identity
+ * matrix.
+ */
+void mlpack::math::WhitenUsingSVD(const arma::mat& x,
+ arma::mat& xWhitened,
+ arma::mat& whiteningMatrix)
+{
+ arma::mat covX, u, v, invSMatrix, temp1;
+ arma::vec sVector;
+
+ covX = ccov(x);
+
+ svd(u, sVector, v, covX);
+
+ size_t d = sVector.n_elem;
+ invSMatrix.zeros(d, d);
+ invSMatrix.diag() = 1 / sqrt(sVector);
+
+ whiteningMatrix = v * invSMatrix * trans(u);
+
+ xWhitened = whiteningMatrix * x;
+}
+
+/**
+ * Whitens a matrix using the eigendecomposition of the covariance matrix.
+ * Whitening means the covariance matrix of the result is the identity matrix.
+ */
+void mlpack::math::WhitenUsingEig(const arma::mat& x,
+ arma::mat& xWhitened,
+ arma::mat& whiteningMatrix)
+{
+ arma::mat diag, eigenvectors;
+ arma::vec eigenvalues;
+
+ // Get eigenvectors of covariance of input matrix.
+ eig_sym(eigenvalues, eigenvectors, ccov(x));
+
+ // Generate diagonal matrix using 1 / sqrt(eigenvalues) for each value.
+ VectorPower(eigenvalues, -0.5);
+ diag.zeros(eigenvalues.n_elem, eigenvalues.n_elem);
+ diag.diag() = eigenvalues;
+
+ // Our whitening matrix is diag(1 / sqrt(eigenvalues)) * eigenvectors'.
+ whiteningMatrix = diag * trans(eigenvectors);
+
+ // Now apply the whitening matrix.
+ xWhitened = whiteningMatrix * x;
+}
+
+/**
+ * Overwrites a dimension-N vector to a random vector on the unit sphere in R^N.
+ */
+void mlpack::math::RandVector(arma::vec& v)
+{
+ v.zeros();
+
+ for (size_t i = 0; i + 1 < v.n_elem; i += 2)
+ {
+ double a = Random();
+ double b = Random();
+ double first_term = sqrt(-2 * log(a));
+ double second_term = 2 * M_PI * b;
+ v[i] = first_term * cos(second_term);
+ v[i + 1] = first_term * sin(second_term);
+ }
+
+ if ((v.n_elem % 2) == 1)
+ {
+ v[v.n_elem - 1] = sqrt(-2 * log(math::Random())) * cos(2 * M_PI *
+ math::Random());
+ }
+
+ v /= sqrt(dot(v, v));
+}
+
+/**
+ * Orthogonalize x and return the result in W, using eigendecomposition.
+ * We will be using the formula \f$ W = x (x^T x)^{-0.5} \f$.
+ */
+void mlpack::math::Orthogonalize(const arma::mat& x, arma::mat& W)
+{
+ // For a matrix A, A^N = V * D^N * V', where VDV' is the
+ // eigendecomposition of the matrix A.
+ arma::mat eigenvalues, eigenvectors;
+ arma::vec egval;
+ eig_sym(egval, eigenvectors, ccov(x));
+ VectorPower(egval, -0.5);
+
+ eigenvalues.zeros(egval.n_elem, egval.n_elem);
+ eigenvalues.diag() = egval;
+
+ arma::mat at = (eigenvectors * eigenvalues * trans(eigenvectors));
+
+ W = at * x;
+}
+
+/**
+ * Orthogonalize x in-place. This could be sped up by a custom
+ * implementation.
+ */
+void mlpack::math::Orthogonalize(arma::mat& x)
+{
+ Orthogonalize(x, x);
+}
+
+/**
+ * Remove a certain set of rows in a matrix while copying to a second matrix.
+ *
+ * @param input Input matrix to copy.
+ * @param rowsToRemove Vector containing indices of rows to be removed.
+ * @param output Matrix to copy non-removed rows into.
+ */
+void mlpack::math::RemoveRows(const arma::mat& input,
+ const std::vector<size_t>& rowsToRemove,
+ arma::mat& output)
+{
+ const size_t nRemove = rowsToRemove.size();
+ const size_t nKeep = input.n_rows - nRemove;
+
+ if (nRemove == 0)
+ {
+ output = input; // Copy everything.
+ }
+ else
+ {
+ output.set_size(nKeep, input.n_cols);
+
+ size_t curRow = 0;
+ size_t removeInd = 0;
+ // First, check 0 to first row to remove.
+ if (rowsToRemove[0] > 0)
+ {
+ // Note that this implies that n_rows > 1.
+ output.rows(0, rowsToRemove[0] - 1) = input.rows(0, rowsToRemove[0] - 1);
+ curRow += rowsToRemove[0];
+ }
+
+ // Now, check i'th row to remove to (i + 1)'th row to remove, until i is the
+ // penultimate row.
+ while (removeInd < nRemove - 1)
+ {
+ const size_t height = rowsToRemove[removeInd + 1] -
+ rowsToRemove[removeInd] - 1;
+
+ if (height > 0)
+ {
+ output.rows(curRow, curRow + height - 1) =
+ input.rows(rowsToRemove[removeInd] + 1,
+ rowsToRemove[removeInd + 1] - 1);
+ curRow += height;
+ }
+
+ removeInd++;
+ }
+
+ // Now that i is the last row to remove, check last row to remove to last
+ // row.
+ if (rowsToRemove[removeInd] < input.n_rows - 1)
+ {
+ output.rows(curRow, nKeep - 1) = input.rows(rowsToRemove[removeInd] + 1,
+ input.n_rows - 1);
+ }
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/math/lin_alg.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,97 +0,0 @@
-/**
- * @file lin_alg.hpp
- * @author Nishant Mehta
- *
- * Linear algebra utilities.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_MATH_LIN_ALG_HPP
-#define __MLPACK_CORE_MATH_LIN_ALG_HPP
-
-#include <mlpack/core.hpp>
-
-/**
- * Linear algebra utility functions, generally performed on matrices or vectors.
- */
-namespace mlpack {
-namespace math {
-
-/**
- * Auxiliary function to raise vector elements to a specific power. The sign
- * is ignored in the power operation and then re-added. Useful for
- * eigenvalues.
- */
-void VectorPower(arma::vec& vec, double power);
-
-/**
- * Creates a centered matrix, where centering is done by subtracting
- * the sum over the columns (a column vector) from each column of the matrix.
- *
- * @param x Input matrix
- * @param xCentered Matrix to write centered output into
- */
-void Center(const arma::mat& x, arma::mat& xCentered);
-
-/**
- * Whitens a matrix using the singular value decomposition of the covariance
- * matrix. Whitening means the covariance matrix of the result is the identity
- * matrix.
- */
-void WhitenUsingSVD(const arma::mat& x,
- arma::mat& xWhitened,
- arma::mat& whiteningMatrix);
-
-/**
- * Whitens a matrix using the eigendecomposition of the covariance matrix.
- * Whitening means the covariance matrix of the result is the identity matrix.
- */
-void WhitenUsingEig(const arma::mat& x,
- arma::mat& xWhitened,
- arma::mat& whiteningMatrix);
-
-/**
- * Overwrites a dimension-N vector to a random vector on the unit sphere in R^N.
- */
-void RandVector(arma::vec& v);
-
-/**
- * Orthogonalize x and return the result in W, using eigendecomposition.
- * We will be using the formula \f$ W = x (x^T x)^{-0.5} \f$.
- */
-void Orthogonalize(const arma::mat& x, arma::mat& W);
-
-/**
- * Orthogonalize x in-place. This could be sped up by a custom
- * implementation.
- */
-void Orthogonalize(arma::mat& x);
-
-/**
- * Remove a certain set of rows in a matrix while copying to a second matrix.
- *
- * @param input Input matrix to copy.
- * @param rowsToRemove Vector containing indices of rows to be removed.
- * @param output Matrix to copy non-removed rows into.
- */
-void RemoveRows(const arma::mat& input,
- const std::vector<size_t>& rowsToRemove,
- arma::mat& output);
-
-}; // namespace math
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_MATH_LIN_ALG_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/math/lin_alg.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/lin_alg.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,97 @@
+/**
+ * @file lin_alg.hpp
+ * @author Nishant Mehta
+ *
+ * Linear algebra utilities.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_MATH_LIN_ALG_HPP
+#define __MLPACK_CORE_MATH_LIN_ALG_HPP
+
+#include <mlpack/core.hpp>
+
+/**
+ * Linear algebra utility functions, generally performed on matrices or vectors.
+ */
+namespace mlpack {
+namespace math {
+
+/**
+ * Auxiliary function to raise vector elements to a specific power. The sign
+ * is ignored in the power operation and then re-added. Useful for
+ * eigenvalues.
+ */
+void VectorPower(arma::vec& vec, double power);
+
+/**
+ * Creates a centered matrix, where centering is done by subtracting
+ * the sum over the columns (a column vector) from each column of the matrix.
+ *
+ * @param x Input matrix
+ * @param xCentered Matrix to write centered output into
+ */
+void Center(const arma::mat& x, arma::mat& xCentered);
+
+/**
+ * Whitens a matrix using the singular value decomposition of the covariance
+ * matrix. Whitening means the covariance matrix of the result is the identity
+ * matrix.
+ */
+void WhitenUsingSVD(const arma::mat& x,
+ arma::mat& xWhitened,
+ arma::mat& whiteningMatrix);
+
+/**
+ * Whitens a matrix using the eigendecomposition of the covariance matrix.
+ * Whitening means the covariance matrix of the result is the identity matrix.
+ */
+void WhitenUsingEig(const arma::mat& x,
+ arma::mat& xWhitened,
+ arma::mat& whiteningMatrix);
+
+/**
+ * Overwrites a dimension-N vector to a random vector on the unit sphere in R^N.
+ */
+void RandVector(arma::vec& v);
+
+/**
+ * Orthogonalize x and return the result in W, using eigendecomposition.
+ * We will be using the formula \f$ W = x (x^T x)^{-0.5} \f$.
+ */
+void Orthogonalize(const arma::mat& x, arma::mat& W);
+
+/**
+ * Orthogonalize x in-place. This could be sped up by a custom
+ * implementation.
+ */
+void Orthogonalize(arma::mat& x);
+
+/**
+ * Remove a certain set of rows in a matrix while copying to a second matrix.
+ *
+ * @param input Input matrix to copy.
+ * @param rowsToRemove Vector containing indices of rows to be removed.
+ * @param output Matrix to copy non-removed rows into.
+ */
+void RemoveRows(const arma::mat& input,
+ const std::vector<size_t>& rowsToRemove,
+ arma::mat& output);
+
+}; // namespace math
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_MATH_LIN_ALG_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/math/random.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,51 +0,0 @@
-/**
- * @file random.cpp
- *
- * Declarations of global Boost random number generators.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <boost/random.hpp>
-#include <boost/version.hpp>
-
-namespace mlpack {
-namespace math {
-
-#if BOOST_VERSION >= 104700
- // Global random object.
- boost::random::mt19937 randGen;
- // Global uniform distribution.
- boost::random::uniform_01<> randUniformDist;
- // Global normal distribution.
- boost::random::normal_distribution<> randNormalDist;
-#else
- // Global random object.
- boost::mt19937 randGen;
-
- #if BOOST_VERSION >= 103900
- // Global uniform distribution.
- boost::uniform_01<> randUniformDist;
- #else
- // Pre-1.39 Boost.Random did not give default template parameter values.
- boost::uniform_01<boost::mt19937, double> randUniformDist(randGen);
- #endif
-
- // Global normal distribution.
- boost::normal_distribution<> randNormalDist;
-#endif
-
-}; // namespace math
-}; // namespace mlpack
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/math/random.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,51 @@
+/**
+ * @file random.cpp
+ *
+ * Declarations of global Boost random number generators.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <boost/random.hpp>
+#include <boost/version.hpp>
+
+namespace mlpack {
+namespace math {
+
+#if BOOST_VERSION >= 104700
+ // Global random object.
+ boost::random::mt19937 randGen;
+ // Global uniform distribution.
+ boost::random::uniform_01<> randUniformDist;
+ // Global normal distribution.
+ boost::random::normal_distribution<> randNormalDist;
+#else
+ // Global random object.
+ boost::mt19937 randGen;
+
+ #if BOOST_VERSION >= 103900
+ // Global uniform distribution.
+ boost::uniform_01<> randUniformDist;
+ #else
+ // Pre-1.39 Boost.Random did not give default template parameter values.
+ boost::uniform_01<boost::mt19937, double> randUniformDist(randGen);
+ #endif
+
+ // Global normal distribution.
+ boost::normal_distribution<> randNormalDist;
+#endif
+
+}; // namespace math
+}; // namespace mlpack
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/math/random.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,154 +0,0 @@
-/**
- * @file random.hpp
- *
- * Miscellaneous math random-related routines.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_MATH_RANDOM_HPP
-#define __MLPACK_CORE_MATH_RANDOM_HPP
-
-#include <stdlib.h>
-#include <math.h>
-#include <float.h>
-
-#include <boost/random.hpp>
-
-namespace mlpack {
-namespace math /** Miscellaneous math routines. */ {
-
-// Annoying Boost versioning issues.
-#include <boost/version.hpp>
-
-#if BOOST_VERSION >= 104700
- // Global random object.
- extern boost::random::mt19937 randGen;
- // Global uniform distribution.
- extern boost::random::uniform_01<> randUniformDist;
- // Global normal distribution.
- extern boost::random::normal_distribution<> randNormalDist;
-#else
- // Global random object.
- extern boost::mt19937 randGen;
-
- #if BOOST_VERSION >= 103900
- // Global uniform distribution.
- extern boost::uniform_01<> randUniformDist;
- #else
- // Pre-1.39 Boost.Random did not give default template parameter values.
- extern boost::uniform_01<boost::mt19937, double> randUniformDist;
- #endif
-
- // Global normal distribution.
- extern boost::normal_distribution<> randNormalDist;
-#endif
-
-/**
- * Set the random seed used by the random functions (Random() and RandInt()).
- * The seed is casted to a 32-bit integer before being given to the random
- * number generator, but a size_t is taken as a parameter for API consistency.
- *
- * @param seed Seed for the random number generator.
- */
-inline void RandomSeed(const size_t seed)
-{
- randGen.seed((uint32_t) seed);
- srand((unsigned int) seed);
-}
-
-/**
- * Generates a uniform random number between 0 and 1.
- */
-inline double Random()
-{
-#if BOOST_VERSION >= 103900
- return randUniformDist(randGen);
-#else
- // Before Boost 1.39, we did not give the random object when we wanted a
- // random number; that gets given at construction time.
- return randUniformDist();
-#endif
-}
-
-/**
- * Generates a uniform random number in the specified range.
- */
-inline double Random(const double lo, const double hi)
-{
-#if BOOST_VERSION >= 103900
- return lo + (hi - lo) * randUniformDist(randGen);
-#else
- // Before Boost 1.39, we did not give the random object when we wanted a
- // random number; that gets given at construction time.
- return lo + (hi - lo) * randUniformDist();
-#endif
-}
-
-/**
- * Generates a uniform random integer.
- */
-inline int RandInt(const int hiExclusive)
-{
-#if BOOST_VERSION >= 103900
- return (int) std::floor((double) hiExclusive * randUniformDist(randGen));
-#else
- // Before Boost 1.39, we did not give the random object when we wanted a
- // random number; that gets given at construction time.
- return (int) std::floor((double) hiExclusive * randUniformDist());
-#endif
-}
-
-/**
- * Generates a uniform random integer.
- */
-inline int RandInt(const int lo, const int hiExclusive)
-{
-#if BOOST_VERSION >= 103900
- return lo + (int) std::floor((double) (hiExclusive - lo)
- * randUniformDist(randGen));
-#else
- // Before Boost 1.39, we did not give the random object when we wanted a
- // random number; that gets given at construction time.
- return lo + (int) std::floor((double) (hiExclusive - lo)
- * randUniformDist());
-#endif
-
-}
-
-/**
- * Generates a normally distributed random number with mean 0 and variance 1.
- */
-inline double RandNormal()
-{
- return randNormalDist(randGen);
-}
-
-/**
- * Generates a normally distributed random number with specified mean and
- * variance.
- *
- * @param mean Mean of distribution.
- * @param variance Variance of distribution.
- */
-inline double RandNormal(const double mean, const double variance)
-{
- return variance * randNormalDist(randGen) + mean;
-}
-
-}; // namespace math
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_MATH_MATH_LIB_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/math/random.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/random.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,154 @@
+/**
+ * @file random.hpp
+ *
+ * Miscellaneous math random-related routines.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_MATH_RANDOM_HPP
+#define __MLPACK_CORE_MATH_RANDOM_HPP
+
+#include <stdlib.h>
+#include <math.h>
+#include <float.h>
+
+#include <boost/random.hpp>
+
+namespace mlpack {
+namespace math /** Miscellaneous math routines. */ {
+
+// Annoying Boost versioning issues.
+#include <boost/version.hpp>
+
+#if BOOST_VERSION >= 104700
+ // Global random object.
+ extern boost::random::mt19937 randGen;
+ // Global uniform distribution.
+ extern boost::random::uniform_01<> randUniformDist;
+ // Global normal distribution.
+ extern boost::random::normal_distribution<> randNormalDist;
+#else
+ // Global random object.
+ extern boost::mt19937 randGen;
+
+ #if BOOST_VERSION >= 103900
+ // Global uniform distribution.
+ extern boost::uniform_01<> randUniformDist;
+ #else
+ // Pre-1.39 Boost.Random did not give default template parameter values.
+ extern boost::uniform_01<boost::mt19937, double> randUniformDist;
+ #endif
+
+ // Global normal distribution.
+ extern boost::normal_distribution<> randNormalDist;
+#endif
+
+/**
+ * Set the random seed used by the random functions (Random() and RandInt()).
+ * The seed is casted to a 32-bit integer before being given to the random
+ * number generator, but a size_t is taken as a parameter for API consistency.
+ *
+ * @param seed Seed for the random number generator.
+ */
+inline void RandomSeed(const size_t seed)
+{
+ randGen.seed((uint32_t) seed);
+ srand((unsigned int) seed);
+}
+
+/**
+ * Generates a uniform random number between 0 and 1.
+ */
+inline double Random()
+{
+#if BOOST_VERSION >= 103900
+ return randUniformDist(randGen);
+#else
+ // Before Boost 1.39, we did not give the random object when we wanted a
+ // random number; that gets given at construction time.
+ return randUniformDist();
+#endif
+}
+
+/**
+ * Generates a uniform random number in the specified range.
+ */
+inline double Random(const double lo, const double hi)
+{
+#if BOOST_VERSION >= 103900
+ return lo + (hi - lo) * randUniformDist(randGen);
+#else
+ // Before Boost 1.39, we did not give the random object when we wanted a
+ // random number; that gets given at construction time.
+ return lo + (hi - lo) * randUniformDist();
+#endif
+}
+
+/**
+ * Generates a uniform random integer.
+ */
+inline int RandInt(const int hiExclusive)
+{
+#if BOOST_VERSION >= 103900
+ return (int) std::floor((double) hiExclusive * randUniformDist(randGen));
+#else
+ // Before Boost 1.39, we did not give the random object when we wanted a
+ // random number; that gets given at construction time.
+ return (int) std::floor((double) hiExclusive * randUniformDist());
+#endif
+}
+
+/**
+ * Generates a uniform random integer.
+ */
+inline int RandInt(const int lo, const int hiExclusive)
+{
+#if BOOST_VERSION >= 103900
+ return lo + (int) std::floor((double) (hiExclusive - lo)
+ * randUniformDist(randGen));
+#else
+ // Before Boost 1.39, we did not give the random object when we wanted a
+ // random number; that gets given at construction time.
+ return lo + (int) std::floor((double) (hiExclusive - lo)
+ * randUniformDist());
+#endif
+
+}
+
+/**
+ * Generates a normally distributed random number with mean 0 and variance 1.
+ */
+inline double RandNormal()
+{
+ return randNormalDist(randGen);
+}
+
+/**
+ * Generates a normally distributed random number with specified mean and
+ * variance.
+ *
+ * @param mean Mean of distribution.
+ * @param variance Variance of distribution.
+ */
+inline double RandNormal(const double mean, const double variance)
+{
+ return variance * randNormalDist(randGen) + mean;
+}
+
+}; // namespace math
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_MATH_RANDOM_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/math/range.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,187 +0,0 @@
-/**
- * @file range.hpp
- *
- * Definition of the Range class, which represents a simple range with a lower
- * and upper bound.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_MATH_RANGE_HPP
-#define __MLPACK_CORE_MATH_RANGE_HPP
-
-namespace mlpack {
-namespace math {
-
-/**
- * Simple real-valued range. It contains an upper and lower bound.
- */
-class Range
-{
- private:
- double lo; /// The lower bound.
- double hi; /// The upper bound.
-
- public:
- /** Initialize to an empty set (where lo > hi). */
- inline Range();
-
- /***
- * Initialize a range to enclose only the given point (lo = point, hi =
- * point).
- *
- * @param point Point that this range will enclose.
- */
- inline Range(const double point);
-
- /**
- * Initializes to specified range.
- *
- * @param lo Lower bound of the range.
- * @param hi Upper bound of the range.
- */
- inline Range(const double lo, const double hi);
-
- //! Get the lower bound.
- inline double Lo() const { return lo; }
- //! Modify the lower bound.
- inline double& Lo() { return lo; }
-
- //! Get the upper bound.
- inline double Hi() const { return hi; }
- //! Modify the upper bound.
- inline double& Hi() { return hi; }
-
- /**
- * Gets the span of the range (hi - lo).
- */
- inline double Width() const;
-
- /**
- * Gets the midpoint of this range.
- */
- inline double Mid() const;
-
- /**
- * Expands this range to include another range.
- *
- * @param rhs Range to include.
- */
- inline Range& operator|=(const Range& rhs);
-
- /**
- * Expands this range to include another range.
- *
- * @param rhs Range to include.
- */
- inline Range operator|(const Range& rhs) const;
-
- /**
- * Shrinks this range to be the overlap with another range; this makes an
- * empty set if there is no overlap.
- *
- * @param rhs Other range.
- */
- inline Range& operator&=(const Range& rhs);
-
- /**
- * Shrinks this range to be the overlap with another range; this makes an
- * empty set if there is no overlap.
- *
- * @param rhs Other range.
- */
- inline Range operator&(const Range& rhs) const;
-
- /**
- * Scale the bounds by the given double.
- *
- * @param d Scaling factor.
- */
- inline Range& operator*=(const double d);
-
- /**
- * Scale the bounds by the given double.
- *
- * @param d Scaling factor.
- */
- inline Range operator*(const double d) const;
-
- /**
- * Scale the bounds by the given double.
- *
- * @param d Scaling factor.
- */
- friend inline Range operator*(const double d, const Range& r); // Symmetric.
-
- /**
- * Compare with another range for strict equality.
- *
- * @param rhs Other range.
- */
- inline bool operator==(const Range& rhs) const;
-
- /**
- * Compare with another range for strict equality.
- *
- * @param rhs Other range.
- */
- inline bool operator!=(const Range& rhs) const;
-
- /**
- * Compare with another range. For Range objects x and y, x < y means that x
- * is strictly less than y and does not overlap at all.
- *
- * @param rhs Other range.
- */
- inline bool operator<(const Range& rhs) const;
-
- /**
- * Compare with another range. For Range objects x and y, x < y means that x
- * is strictly less than y and does not overlap at all.
- *
- * @param rhs Other range.
- */
- inline bool operator>(const Range& rhs) const;
-
- /**
- * Determines if a point is contained within the range.
- *
- * @param d Point to check.
- */
- inline bool Contains(const double d) const;
-
- /**
- * Determines if another range overlaps with this one.
- *
- * @param r Other range.
- *
- * @return true if ranges overlap at all.
- */
- inline bool Contains(const Range& r) const;
-
- /**
- * Returns a string representation of an object.
- */
- inline std::string ToString() const;
-
-};
-
-}; // namespace math
-}; // namespace mlpack
-
-// Include inlined implementation.
-#include "range_impl.hpp"
-
-#endif // __MLPACK_CORE_MATH_RANGE_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/math/range.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,187 @@
+/**
+ * @file range.hpp
+ *
+ * Definition of the Range class, which represents a simple range with a lower
+ * and upper bound.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_MATH_RANGE_HPP
+#define __MLPACK_CORE_MATH_RANGE_HPP
+
+namespace mlpack {
+namespace math {
+
+/**
+ * Simple real-valued range. It contains an upper and lower bound.
+ */
+class Range
+{
+ private:
+ double lo; /// The lower bound.
+ double hi; /// The upper bound.
+
+ public:
+ /** Initialize to an empty set (where lo > hi). */
+ inline Range();
+
+ /***
+ * Initialize a range to enclose only the given point (lo = point, hi =
+ * point).
+ *
+ * @param point Point that this range will enclose.
+ */
+ inline Range(const double point);
+
+ /**
+ * Initializes to specified range.
+ *
+ * @param lo Lower bound of the range.
+ * @param hi Upper bound of the range.
+ */
+ inline Range(const double lo, const double hi);
+
+ //! Get the lower bound.
+ inline double Lo() const { return lo; }
+ //! Modify the lower bound.
+ inline double& Lo() { return lo; }
+
+ //! Get the upper bound.
+ inline double Hi() const { return hi; }
+ //! Modify the upper bound.
+ inline double& Hi() { return hi; }
+
+ /**
+ * Gets the span of the range (hi - lo).
+ */
+ inline double Width() const;
+
+ /**
+ * Gets the midpoint of this range.
+ */
+ inline double Mid() const;
+
+ /**
+ * Expands this range to include another range.
+ *
+ * @param rhs Range to include.
+ */
+ inline Range& operator|=(const Range& rhs);
+
+ /**
+ * Expands this range to include another range.
+ *
+ * @param rhs Range to include.
+ */
+ inline Range operator|(const Range& rhs) const;
+
+ /**
+ * Shrinks this range to be the overlap with another range; this makes an
+ * empty set if there is no overlap.
+ *
+ * @param rhs Other range.
+ */
+ inline Range& operator&=(const Range& rhs);
+
+ /**
+ * Shrinks this range to be the overlap with another range; this makes an
+ * empty set if there is no overlap.
+ *
+ * @param rhs Other range.
+ */
+ inline Range operator&(const Range& rhs) const;
+
+ /**
+ * Scale the bounds by the given double.
+ *
+ * @param d Scaling factor.
+ */
+ inline Range& operator*=(const double d);
+
+ /**
+ * Scale the bounds by the given double.
+ *
+ * @param d Scaling factor.
+ */
+ inline Range operator*(const double d) const;
+
+ /**
+ * Scale the bounds by the given double.
+ *
+ * @param d Scaling factor.
+ */
+ friend inline Range operator*(const double d, const Range& r); // Symmetric.
+
+ /**
+ * Compare with another range for strict equality.
+ *
+ * @param rhs Other range.
+ */
+ inline bool operator==(const Range& rhs) const;
+
+ /**
+ * Compare with another range for strict equality.
+ *
+ * @param rhs Other range.
+ */
+ inline bool operator!=(const Range& rhs) const;
+
+ /**
+ * Compare with another range. For Range objects x and y, x < y means that x
+ * is strictly less than y and does not overlap at all.
+ *
+ * @param rhs Other range.
+ */
+ inline bool operator<(const Range& rhs) const;
+
+ /**
+ * Compare with another range. For Range objects x and y, x < y means that x
+ * is strictly less than y and does not overlap at all.
+ *
+ * @param rhs Other range.
+ */
+ inline bool operator>(const Range& rhs) const;
+
+ /**
+ * Determines if a point is contained within the range.
+ *
+ * @param d Point to check.
+ */
+ inline bool Contains(const double d) const;
+
+ /**
+ * Determines if another range overlaps with this one.
+ *
+ * @param r Other range.
+ *
+ * @return true if ranges overlap at all.
+ */
+ inline bool Contains(const Range& r) const;
+
+ /**
+ * Returns a string representation of an object.
+ */
+ inline std::string ToString() const;
+
+};
+
+}; // namespace math
+}; // namespace mlpack
+
+// Include inlined implementation.
+#include "range_impl.hpp"
+
+#endif // __MLPACK_CORE_MATH_RANGE_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/math/range_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,204 +0,0 @@
-/**
- * @file range_impl.hpp
- *
- * Implementation of the (inlined) Range class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_MATH_RANGE_IMPL_HPP
-#define __MLPACK_CORE_MATH_RANGE_IMPL_HPP
-
-#include "range.hpp"
-#include <float.h>
-#include <sstream>
-
-namespace mlpack {
-namespace math {
-
-/**
- * Initialize the range to 0.
- */
-inline Range::Range() :
- lo(DBL_MAX), hi(-DBL_MAX) { /* nothing else to do */ }
-
-/**
- * Initialize a range to enclose only the given point.
- */
-inline Range::Range(const double point) :
- lo(point), hi(point) { /* nothing else to do */ }
-
-/**
- * Initializes the range to the specified values.
- */
-inline Range::Range(const double lo, const double hi) :
- lo(lo), hi(hi) { /* nothing else to do */ }
-
-/**
- * Gets the span of the range, hi - lo. Returns 0 if the range is negative.
- */
-inline double Range::Width() const
-{
- if (lo < hi)
- return (hi - lo);
- else
- return 0.0;
-}
-
-/**
- * Gets the midpoint of this range.
- */
-inline double Range::Mid() const
-{
- return (hi + lo) / 2;
-}
-
-/**
- * Expands range to include the other range.
- */
-inline Range& Range::operator|=(const Range& rhs)
-{
- if (rhs.lo < lo)
- lo = rhs.lo;
- if (rhs.hi > hi)
- hi = rhs.hi;
-
- return *this;
-}
-
-inline Range Range::operator|(const Range& rhs) const
-{
- return Range((rhs.lo < lo) ? rhs.lo : lo,
- (rhs.hi > hi) ? rhs.hi : hi);
-}
-
-/**
- * Shrinks range to be the overlap with another range, becoming an empty
- * set if there is no overlap.
- */
-inline Range& Range::operator&=(const Range& rhs)
-{
- if (rhs.lo > lo)
- lo = rhs.lo;
- if (rhs.hi < hi)
- hi = rhs.hi;
-
- return *this;
-}
-
-inline Range Range::operator&(const Range& rhs) const
-{
- return Range((rhs.lo > lo) ? rhs.lo : lo,
- (rhs.hi < hi) ? rhs.hi : hi);
-}
-
-/**
- * Scale the bounds by the given double.
- */
-inline Range& Range::operator*=(const double d)
-{
- lo *= d;
- hi *= d;
-
- // Now if we've negated, we need to flip things around so the bound is valid.
- if (lo > hi)
- {
- double tmp = hi;
- hi = lo;
- lo = tmp;
- }
-
- return *this;
-}
-
-inline Range Range::operator*(const double d) const
-{
- double nlo = lo * d;
- double nhi = hi * d;
-
- if (nlo <= nhi)
- return Range(nlo, nhi);
- else
- return Range(nhi, nlo);
-}
-
-// Symmetric case.
-inline Range operator*(const double d, const Range& r)
-{
- double nlo = r.lo * d;
- double nhi = r.hi * d;
-
- if (nlo <= nhi)
- return Range(nlo, nhi);
- else
- return Range(nhi, nlo);
-}
-
-/**
- * Compare with another range for strict equality.
- */
-inline bool Range::operator==(const Range& rhs) const
-{
- return (lo == rhs.lo) && (hi == rhs.hi);
-}
-
-inline bool Range::operator!=(const Range& rhs) const
-{
- return (lo != rhs.lo) || (hi != rhs.hi);
-}
-
-/**
- * Compare with another range. For Range objects x and y, x < y means that x is
- * strictly less than y and does not overlap at all.
- */
-inline bool Range::operator<(const Range& rhs) const
-{
- return hi < rhs.lo;
-}
-
-inline bool Range::operator>(const Range& rhs) const
-{
- return lo > rhs.hi;
-}
-
-/**
- * Determines if a point is contained within the range.
- */
-inline bool Range::Contains(const double d) const
-{
- return d >= lo && d <= hi;
-}
-
-/**
- * Determines if this range overlaps with another range.
- */
-inline bool Range::Contains(const Range& r) const
-{
- return lo <= r.hi;
-}
-/**
- * Returns a string representation of an object.
- */
-std::string Range::ToString() const
-{
- std::ostringstream convert;
- convert << "[" << lo << ", " << hi << "]";
- return convert.str();
-}
-
-}; // namesapce math
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/math/range_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/range_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,204 @@
+/**
+ * @file range_impl.hpp
+ *
+ * Implementation of the (inlined) Range class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_MATH_RANGE_IMPL_HPP
+#define __MLPACK_CORE_MATH_RANGE_IMPL_HPP
+
+#include "range.hpp"
+#include <float.h>
+#include <sstream>
+
+namespace mlpack {
+namespace math {
+
+/**
+ * Initialize the range to an empty set (where lo > hi).
+ */
+inline Range::Range() :
+ lo(DBL_MAX), hi(-DBL_MAX) { /* nothing else to do */ }
+
+/**
+ * Initialize a range to enclose only the given point.
+ */
+inline Range::Range(const double point) :
+ lo(point), hi(point) { /* nothing else to do */ }
+
+/**
+ * Initializes the range to the specified values.
+ */
+inline Range::Range(const double lo, const double hi) :
+ lo(lo), hi(hi) { /* nothing else to do */ }
+
+/**
+ * Gets the span of the range, hi - lo. Returns 0 if the range is negative.
+ */
+inline double Range::Width() const
+{
+ if (lo < hi)
+ return (hi - lo);
+ else
+ return 0.0;
+}
+
+/**
+ * Gets the midpoint of this range.
+ */
+inline double Range::Mid() const
+{
+ return (hi + lo) / 2;
+}
+
+/**
+ * Expands range to include the other range.
+ */
+inline Range& Range::operator|=(const Range& rhs)
+{
+ if (rhs.lo < lo)
+ lo = rhs.lo;
+ if (rhs.hi > hi)
+ hi = rhs.hi;
+
+ return *this;
+}
+
+inline Range Range::operator|(const Range& rhs) const
+{
+ return Range((rhs.lo < lo) ? rhs.lo : lo,
+ (rhs.hi > hi) ? rhs.hi : hi);
+}
+
+/**
+ * Shrinks range to be the overlap with another range, becoming an empty
+ * set if there is no overlap.
+ */
+inline Range& Range::operator&=(const Range& rhs)
+{
+ if (rhs.lo > lo)
+ lo = rhs.lo;
+ if (rhs.hi < hi)
+ hi = rhs.hi;
+
+ return *this;
+}
+
+inline Range Range::operator&(const Range& rhs) const
+{
+ return Range((rhs.lo > lo) ? rhs.lo : lo,
+ (rhs.hi < hi) ? rhs.hi : hi);
+}
+
+/**
+ * Scale the bounds by the given double.
+ */
+inline Range& Range::operator*=(const double d)
+{
+ lo *= d;
+ hi *= d;
+
+ // Now if we've negated, we need to flip things around so the bound is valid.
+ if (lo > hi)
+ {
+ double tmp = hi;
+ hi = lo;
+ lo = tmp;
+ }
+
+ return *this;
+}
+
+inline Range Range::operator*(const double d) const
+{
+ double nlo = lo * d;
+ double nhi = hi * d;
+
+ if (nlo <= nhi)
+ return Range(nlo, nhi);
+ else
+ return Range(nhi, nlo);
+}
+
+// Symmetric case.
+inline Range operator*(const double d, const Range& r)
+{
+ double nlo = r.lo * d;
+ double nhi = r.hi * d;
+
+ if (nlo <= nhi)
+ return Range(nlo, nhi);
+ else
+ return Range(nhi, nlo);
+}
+
+/**
+ * Compare with another range for strict equality.
+ */
+inline bool Range::operator==(const Range& rhs) const
+{
+ return (lo == rhs.lo) && (hi == rhs.hi);
+}
+
+inline bool Range::operator!=(const Range& rhs) const
+{
+ return (lo != rhs.lo) || (hi != rhs.hi);
+}
+
+/**
+ * Compare with another range. For Range objects x and y, x < y means that x is
+ * strictly less than y and does not overlap at all.
+ */
+inline bool Range::operator<(const Range& rhs) const
+{
+ return hi < rhs.lo;
+}
+
+inline bool Range::operator>(const Range& rhs) const
+{
+ return lo > rhs.hi;
+}
+
+/**
+ * Determines if a point is contained within the range.
+ */
+inline bool Range::Contains(const double d) const
+{
+ return d >= lo && d <= hi;
+}
+
+/**
+ * Determines if this range overlaps with another range.
+ */
+inline bool Range::Contains(const Range& r) const
+{
+ return lo <= r.hi;
+}
+/**
+ * Returns a string representation of an object.
+ */
+std::string Range::ToString() const
+{
+ std::ostringstream convert;
+ convert << "[" << lo << ", " << hi << "]";
+ return convert.str();
+}
+
+}; // namespace math
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/round.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/math/round.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/round.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,41 +0,0 @@
-/**
- * @file round.hpp
- * @author Ryan Curtin
- *
- * Implementation of round() for use on Visual Studio, where C99 isn't
- * implemented.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_MATH_ROUND_HPP
-#define __MLPACK_CORE_MATH_ROUND_HPP
-
-// _MSC_VER should only be defined for Visual Studio, which doesn't implement
-// C99.
-#ifdef _MSC_VER
-
-// This function ends up going into the global namespace, so it can be used in
-// place of C99's round().
-
-//! Round a number to the nearest integer.
-inline double round(double a)
-{
- return floor(a + 0.5);
-}
-
-#endif
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/round.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/math/round.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/round.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/math/round.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,41 @@
+/**
+ * @file round.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of round() for use on Visual Studio, where C99 isn't
+ * implemented.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_MATH_ROUND_HPP
+#define __MLPACK_CORE_MATH_ROUND_HPP
+
+// _MSC_VER should only be defined for Visual Studio, which doesn't implement
+// C99.
+#ifdef _MSC_VER
+
+// This function ends up going into the global namespace, so it can be used in
+// place of C99's round().
+
+//! Round a number to the nearest integer.
+inline double round(double a)
+{
+ return floor(a + 0.5);
+}
+
+#endif
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/metrics/lmetric.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,112 +0,0 @@
-/**
- * @file lmetric.hpp
- * @author Ryan Curtin
- *
- * Generalized L-metric, allowing both squared distances to be returned as well
- * as non-squared distances. The squared distances are faster to compute.
- *
- * This also gives several convenience typedefs for commonly used L-metrics.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_METRICS_LMETRIC_HPP
-#define __MLPACK_CORE_METRICS_LMETRIC_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace metric {
-
-/**
- * The L_p metric for arbitrary integer p, with an option to take the root.
- *
- * This class implements the standard L_p metric for two arbitary vectors @f$ x
- * @f$ and @f$ y @f$ of dimensionality @f$ n @f$:
- *
- * @f[
- * d(x, y) = \left( \sum_{i = 1}^{n} | x_i - y_i |^p \right)^{\frac{1}{p}}.
- * @f]
- *
- * The value of p is given as a template parameter.
- *
- * In addition, the function @f$ d(x, y) @f$ can be simplified, neglecting the
- * p-root calculation. This is done by specifying the TakeRoot template
- * parameter to be false. Then,
- *
- * @f[
- * d(x, y) = \sum_{i = 1}^{n} | x_i - y_i |^p
- * @f]
- *
- * It is faster to compute that distance, so TakeRoot is by default off.
- * However, when TakeRoot is false, the distance given is not actually a true
- * metric -- it does not satisfy the triangle inequality. Some MLPACK methods
- * do not require the triangle inequality to operate correctly (such as the
- * BinarySpaceTree), but setting TakeRoot = false in some cases will cause
- * incorrect results.
- *
- * A few convenience typedefs are given:
- *
- * - ManhattanDistance
- * - EuclideanDistance
- * - SquaredEuclideanDistance
- *
- * @tparam Power Power of metric; i.e. Power = 1 gives the L1-norm (Manhattan
- * distance).
- * @tparam TakeRoot If true, the Power'th root of the result is taken before it
- * is returned. Setting this to false causes the metric to not satisfy the
- * Triangle Inequality (be careful!).
- */
-template<int Power, bool TakeRoot = true>
-class LMetric
-{
- public:
- /***
- * Default constructor does nothing, but is required to satisfy the Kernel
- * policy.
- */
- LMetric() { }
-
- /**
- * Computes the distance between two points.
- */
- template<typename VecType1, typename VecType2>
- static double Evaluate(const VecType1& a, const VecType2& b);
-};
-
-// Convenience typedefs.
-
-/***
- * The Manhattan (L1) distance.
- */
-typedef LMetric<1, false> ManhattanDistance;
-
-/***
- * The squared Euclidean (L2) distance.
- */
-typedef LMetric<2, false> SquaredEuclideanDistance;
-
-/***
- * The Euclidean (L2) distance.
- */
-typedef LMetric<2, true> EuclideanDistance;
-
-}; // namespace metric
-}; // namespace mlpack
-
-// Include implementation.
-#include "lmetric_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/metrics/lmetric.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,112 @@
+/**
+ * @file lmetric.hpp
+ * @author Ryan Curtin
+ *
+ * Generalized L-metric, allowing both squared distances to be returned as well
+ * as non-squared distances. The squared distances are faster to compute.
+ *
+ * This also gives several convenience typedefs for commonly used L-metrics.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_METRICS_LMETRIC_HPP
+#define __MLPACK_CORE_METRICS_LMETRIC_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace metric {
+
+/**
+ * The L_p metric for arbitrary integer p, with an option to take the root.
+ *
+ * This class implements the standard L_p metric for two arbitrary vectors @f$ x
+ * @f$ and @f$ y @f$ of dimensionality @f$ n @f$:
+ *
+ * @f[
+ * d(x, y) = \left( \sum_{i = 1}^{n} | x_i - y_i |^p \right)^{\frac{1}{p}}.
+ * @f]
+ *
+ * The value of p is given as a template parameter.
+ *
+ * In addition, the function @f$ d(x, y) @f$ can be simplified, neglecting the
+ * p-root calculation. This is done by specifying the TakeRoot template
+ * parameter to be false. Then,
+ *
+ * @f[
+ * d(x, y) = \sum_{i = 1}^{n} | x_i - y_i |^p
+ * @f]
+ *
+ * It is faster to compute that distance, so TakeRoot is by default off.
+ * However, when TakeRoot is false, the distance given is not actually a true
+ * metric -- it does not satisfy the triangle inequality. Some MLPACK methods
+ * do not require the triangle inequality to operate correctly (such as the
+ * BinarySpaceTree), but setting TakeRoot = false in some cases will cause
+ * incorrect results.
+ *
+ * A few convenience typedefs are given:
+ *
+ * - ManhattanDistance
+ * - EuclideanDistance
+ * - SquaredEuclideanDistance
+ *
+ * @tparam Power Power of metric; i.e. Power = 1 gives the L1-norm (Manhattan
+ * distance).
+ * @tparam TakeRoot If true, the Power'th root of the result is taken before it
+ * is returned. Setting this to false causes the metric to not satisfy the
+ * Triangle Inequality (be careful!).
+ */
+template<int Power, bool TakeRoot = true>
+class LMetric
+{
+ public:
+ /***
+ * Default constructor does nothing, but is required to satisfy the Kernel
+ * policy.
+ */
+ LMetric() { }
+
+ /**
+ * Computes the distance between two points.
+ */
+ template<typename VecType1, typename VecType2>
+ static double Evaluate(const VecType1& a, const VecType2& b);
+};
+
+// Convenience typedefs.
+
+/***
+ * The Manhattan (L1) distance.
+ */
+typedef LMetric<1, false> ManhattanDistance;
+
+/***
+ * The squared Euclidean (L2) distance.
+ */
+typedef LMetric<2, false> SquaredEuclideanDistance;
+
+/***
+ * The Euclidean (L2) distance.
+ */
+typedef LMetric<2, true> EuclideanDistance;
+
+}; // namespace metric
+}; // namespace mlpack
+
+// Include implementation.
+#include "lmetric_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/metrics/lmetric_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,99 +0,0 @@
-/**
- * @file lmetric_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of template specializations of LMetric class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_METRICS_LMETRIC_IMPL_HPP
-#define __MLPACK_CORE_METRICS_LMETRIC_IMPL_HPP
-
-// In case it hasn't been included.
-#include "lmetric.hpp"
-
-namespace mlpack {
-namespace metric {
-
-// Unspecialized implementation. This should almost never be used...
-template<int Power, bool TakeRoot>
-template<typename VecType1, typename VecType2>
-double LMetric<Power, TakeRoot>::Evaluate(const VecType1& a,
- const VecType2& b)
-{
- double sum = 0;
- for (size_t i = 0; i < a.n_elem; i++)
- sum += pow(fabs(a[i] - b[i]), Power);
-
- if (!TakeRoot) // The compiler should optimize this correctly at compile-time.
- return sum;
-
- return pow(sum, (1.0 / Power));
-}
-
-// L1-metric specializations; the root doesn't matter.
-template<>
-template<typename VecType1, typename VecType2>
-double LMetric<1, true>::Evaluate(const VecType1& a, const VecType2& b)
-{
- return accu(abs(a - b));
-}
-
-template<>
-template<typename VecType1, typename VecType2>
-double LMetric<1, false>::Evaluate(const VecType1& a, const VecType2& b)
-{
- return accu(abs(a - b));
-}
-
-// L2-metric specializations.
-template<>
-template<typename VecType1, typename VecType2>
-double LMetric<2, true>::Evaluate(const VecType1& a, const VecType2& b)
-{
- return sqrt(accu(square(a - b)));
-}
-
-template<>
-template<typename VecType1, typename VecType2>
-double LMetric<2, false>::Evaluate(const VecType1& a, const VecType2& b)
-{
- return accu(square(a - b));
-}
-
-// L3-metric specialization (not very likely to be used, but just in case).
-template<>
-template<typename VecType1, typename VecType2>
-double LMetric<3, true>::Evaluate(const VecType1& a, const VecType2& b)
-{
- double sum = 0;
- for (size_t i = 0; i < a.n_elem; i++)
- sum += pow(fabs(a[i] - b[i]), 3.0);
-
- return pow(accu(pow(abs(a - b), 3.0)), 1.0 / 3.0);
-}
-
-template<>
-template<typename VecType1, typename VecType2>
-double LMetric<3, false>::Evaluate(const VecType1& a, const VecType2& b)
-{
- return accu(pow(abs(a - b), 3.0));
-}
-
-}; // namespace metric
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/metrics/lmetric_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/lmetric_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,99 @@
+/**
+ * @file lmetric_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of template specializations of LMetric class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_METRICS_LMETRIC_IMPL_HPP
+#define __MLPACK_CORE_METRICS_LMETRIC_IMPL_HPP
+
+// In case it hasn't been included.
+#include "lmetric.hpp"
+
+namespace mlpack {
+namespace metric {
+
+// Unspecialized implementation. This should almost never be used...
+template<int Power, bool TakeRoot>
+template<typename VecType1, typename VecType2>
+double LMetric<Power, TakeRoot>::Evaluate(const VecType1& a,
+ const VecType2& b)
+{
+ double sum = 0;
+ for (size_t i = 0; i < a.n_elem; i++)
+ sum += pow(fabs(a[i] - b[i]), Power);
+
+ if (!TakeRoot) // The compiler should optimize this correctly at compile-time.
+ return sum;
+
+ return pow(sum, (1.0 / Power));
+}
+
+// L1-metric specializations; the root doesn't matter.
+template<>
+template<typename VecType1, typename VecType2>
+double LMetric<1, true>::Evaluate(const VecType1& a, const VecType2& b)
+{
+ return accu(abs(a - b));
+}
+
+template<>
+template<typename VecType1, typename VecType2>
+double LMetric<1, false>::Evaluate(const VecType1& a, const VecType2& b)
+{
+ return accu(abs(a - b));
+}
+
+// L2-metric specializations.
+template<>
+template<typename VecType1, typename VecType2>
+double LMetric<2, true>::Evaluate(const VecType1& a, const VecType2& b)
+{
+ return sqrt(accu(square(a - b)));
+}
+
+template<>
+template<typename VecType1, typename VecType2>
+double LMetric<2, false>::Evaluate(const VecType1& a, const VecType2& b)
+{
+ return accu(square(a - b));
+}
+
+// L3-metric specialization (not very likely to be used, but just in case).
+template<>
+template<typename VecType1, typename VecType2>
+double LMetric<3, true>::Evaluate(const VecType1& a, const VecType2& b)
+{
+ double sum = 0;
+ for (size_t i = 0; i < a.n_elem; i++)
+ sum += pow(fabs(a[i] - b[i]), 3.0);
+
+ return pow(accu(pow(abs(a - b), 3.0)), 1.0 / 3.0);
+}
+
+template<>
+template<typename VecType1, typename VecType2>
+double LMetric<3, false>::Evaluate(const VecType1& a, const VecType2& b)
+{
+ return accu(pow(abs(a - b), 3.0));
+}
+
+}; // namespace metric
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/metrics/mahalanobis_distance.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,123 +0,0 @@
-/***
- * @file mahalanobis_dstance.h
- * @author Ryan Curtin
- *
- * The Mahalanobis distance.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_METRICS_MAHALANOBIS_DISTANCE_HPP
-#define __MLPACK_CORE_METRICS_MAHALANOBIS_DISTANCE_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace metric {
-
-/**
- * The Mahalanobis distance, which is essentially a stretched Euclidean
- * distance. Given a square covariance matrix @f$ Q @f$ of size @f$ d @f$ x
- * @f$ d @f$, where @f$ d @f$ is the dimensionality of the points it will be
- * evaluating, and given two vectors @f$ x @f$ and @f$ y @f$ also of
- * dimensionality @f$ d @f$,
- *
- * @f[
- * d(x, y) = \sqrt{(x - y)^T Q (x - y)}
- * @f]
- *
- * where Q is the covariance matrix.
- *
- * Because each evaluation multiplies (x_1 - x_2) by the covariance matrix, it
- * may be much quicker to use an LMetric and simply stretch the actual dataset
- * itself before performing any evaluations. However, this class is provided
- * for convenience.
- *
- * Similar to the LMetric class, this offers a template parameter t_take_root
- * which, when set to false, will instead evaluate the distance
- *
- * @f[
- * d(x, y) = (x - y)^T Q (x - y)
- * @f]
- *
- * which is faster to evaluate.
- *
- * @tparam t_take_root If true, takes the root of the output. It is slightly
- * faster to leave this at the default of false.
- */
-template<bool t_take_root = false>
-class MahalanobisDistance
-{
- public:
- /**
- * Initialize the Mahalanobis distance with the empty matrix as covariance.
- * Don't call Evaluate() until you set the covariance with Covariance()!
- */
- MahalanobisDistance() { }
-
- /**
- * Initialize the Mahalanobis distance with the identity matrix of the given
- * dimensionality.
- *
- * @param dimensionality Dimesnsionality of the covariance matrix.
- */
- MahalanobisDistance(const size_t dimensionality) :
- covariance(arma::eye<arma::mat>(dimensionality, dimensionality)) { }
-
- /**
- * Initialize the Mahalanobis distance with the given covariance matrix. The
- * given covariance matrix will be copied (this is not optimal).
- *
- * @param covariance The covariance matrix to use for this distance.
- */
- MahalanobisDistance(const arma::mat& covariance) : covariance(covariance) { }
-
- /**
- * Evaluate the distance between the two given points using this Mahalanobis
- * distance. If the covariance matrix has not been set (i.e. if you used the
- * empty constructor and did not later modify the covariance matrix), calling
- * this method will probably result in a crash.
- *
- * @param a First vector.
- * @param b Second vector.
- */
- template<typename VecType1, typename VecType2>
- double Evaluate(const VecType1& a, const VecType2& b);
-
- /**
- * Access the covariance matrix.
- *
- * @return Constant reference to the covariance matrix.
- */
- const arma::mat& Covariance() const { return covariance; }
-
- /**
- * Modify the covariance matrix.
- *
- * @return Reference to the covariance matrix.
- */
- arma::mat& Covariance() { return covariance; }
-
- private:
- //! The covariance matrix associated with this distance.
- arma::mat covariance;
-};
-
-}; // namespace distance
-}; // namespace mlpack
-
-#include "mahalanobis_distance_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/metrics/mahalanobis_distance.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,123 @@
+/***
+ * @file mahalanobis_distance.hpp
+ * @author Ryan Curtin
+ *
+ * The Mahalanobis distance.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_METRICS_MAHALANOBIS_DISTANCE_HPP
+#define __MLPACK_CORE_METRICS_MAHALANOBIS_DISTANCE_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace metric {
+
+/**
+ * The Mahalanobis distance, which is essentially a stretched Euclidean
+ * distance. Given a square covariance matrix @f$ Q @f$ of size @f$ d @f$ x
+ * @f$ d @f$, where @f$ d @f$ is the dimensionality of the points it will be
+ * evaluating, and given two vectors @f$ x @f$ and @f$ y @f$ also of
+ * dimensionality @f$ d @f$,
+ *
+ * @f[
+ * d(x, y) = \sqrt{(x - y)^T Q (x - y)}
+ * @f]
+ *
+ * where Q is the covariance matrix.
+ *
+ * Because each evaluation multiplies (x_1 - x_2) by the covariance matrix, it
+ * may be much quicker to use an LMetric and simply stretch the actual dataset
+ * itself before performing any evaluations. However, this class is provided
+ * for convenience.
+ *
+ * Similar to the LMetric class, this offers a template parameter t_take_root
+ * which, when set to false, will instead evaluate the distance
+ *
+ * @f[
+ * d(x, y) = (x - y)^T Q (x - y)
+ * @f]
+ *
+ * which is faster to evaluate.
+ *
+ * @tparam t_take_root If true, takes the root of the output. It is slightly
+ * faster to leave this at the default of false.
+ */
+template<bool t_take_root = false>
+class MahalanobisDistance
+{
+ public:
+ /**
+ * Initialize the Mahalanobis distance with the empty matrix as covariance.
+ * Don't call Evaluate() until you set the covariance with Covariance()!
+ */
+ MahalanobisDistance() { }
+
+ /**
+ * Initialize the Mahalanobis distance with the identity matrix of the given
+ * dimensionality.
+ *
+ * @param dimensionality Dimensionality of the covariance matrix.
+ */
+ MahalanobisDistance(const size_t dimensionality) :
+ covariance(arma::eye<arma::mat>(dimensionality, dimensionality)) { }
+
+ /**
+ * Initialize the Mahalanobis distance with the given covariance matrix. The
+ * given covariance matrix will be copied (this is not optimal).
+ *
+ * @param covariance The covariance matrix to use for this distance.
+ */
+ MahalanobisDistance(const arma::mat& covariance) : covariance(covariance) { }
+
+ /**
+ * Evaluate the distance between the two given points using this Mahalanobis
+ * distance. If the covariance matrix has not been set (i.e. if you used the
+ * empty constructor and did not later modify the covariance matrix), calling
+ * this method will probably result in a crash.
+ *
+ * @param a First vector.
+ * @param b Second vector.
+ */
+ template<typename VecType1, typename VecType2>
+ double Evaluate(const VecType1& a, const VecType2& b);
+
+ /**
+ * Access the covariance matrix.
+ *
+ * @return Constant reference to the covariance matrix.
+ */
+ const arma::mat& Covariance() const { return covariance; }
+
+ /**
+ * Modify the covariance matrix.
+ *
+ * @return Reference to the covariance matrix.
+ */
+ arma::mat& Covariance() { return covariance; }
+
+ private:
+ //! The covariance matrix associated with this distance.
+ arma::mat covariance;
+};
+
+}; // namespace metric
+}; // namespace mlpack
+
+#include "mahalanobis_distance_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/metrics/mahalanobis_distance_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,64 +0,0 @@
-/***
- * @file mahalanobis_distance.cc
- * @author Ryan Curtin
- *
- * Implementation of the Mahalanobis distance.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_METRICS_MAHALANOBIS_DISTANCE_IMPL_HPP
-#define __MLPACK_CORE_METRICS_MAHALANOBIS_DISTANCE_IMPL_HPP
-
-#include "mahalanobis_distance.hpp"
-
-namespace mlpack {
-namespace metric {
-
-/**
- * Specialization for non-rooted case.
- */
-template<>
-template<typename VecType1, typename VecType2>
-double MahalanobisDistance<false>::Evaluate(const VecType1& a,
- const VecType2& b)
-{
- arma::vec m = (a - b);
- arma::mat out = trans(m) * covariance * m; // 1x1
- return out[0];
-}
-
-/**
- * Specialization for rooted case. This requires one extra evaluation of
- * sqrt().
- */
-template<>
-template<typename VecType1, typename VecType2>
-double MahalanobisDistance<true>::Evaluate(const VecType1& a,
- const VecType2& b)
-{
- // Check if covariance matrix has been initialized.
- if (covariance.n_rows == 0)
- covariance = arma::eye<arma::mat>(a.n_elem, a.n_elem);
-
- arma::vec m = (a - b);
- arma::mat out = trans(m) * covariance * m; // 1x1;
- return sqrt(out[0]);
-}
-
-}; // namespace metric
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/metrics/mahalanobis_distance_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/metrics/mahalanobis_distance_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,64 @@
+/***
+ * @file mahalanobis_distance_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the Mahalanobis distance.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_METRICS_MAHALANOBIS_DISTANCE_IMPL_HPP
+#define __MLPACK_CORE_METRICS_MAHALANOBIS_DISTANCE_IMPL_HPP
+
+#include "mahalanobis_distance.hpp"
+
+namespace mlpack {
+namespace metric {
+
+/**
+ * Specialization for non-rooted case.
+ */
+template<>
+template<typename VecType1, typename VecType2>
+double MahalanobisDistance<false>::Evaluate(const VecType1& a,
+ const VecType2& b)
+{
+ arma::vec m = (a - b);
+ arma::mat out = trans(m) * covariance * m; // 1x1
+ return out[0];
+}
+
+/**
+ * Specialization for rooted case. This requires one extra evaluation of
+ * sqrt().
+ */
+template<>
+template<typename VecType1, typename VecType2>
+double MahalanobisDistance<true>::Evaluate(const VecType1& a,
+ const VecType2& b)
+{
+ // Check if covariance matrix has been initialized.
+ if (covariance.n_rows == 0)
+ covariance = arma::eye<arma::mat>(a.n_elem, a.n_elem);
+
+ arma::vec m = (a - b);
+ arma::mat out = trans(m) * covariance * m; // 1x1;
+ return sqrt(out[0]);
+}
+
+}; // namespace metric
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,157 +0,0 @@
-/**
- * @file aug_lagrangian.hpp
- * @author Ryan Curtin
- *
- * Definition of AugLagrangian class, which implements the Augmented Lagrangian
- * optimization method (also called the 'method of multipliers'. This class
- * uses the L-BFGS optimizer.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_HPP
-#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
-
-#include "aug_lagrangian_function.hpp"
-
-namespace mlpack {
-namespace optimization {
-
-/**
- * The AugLagrangian class implements the Augmented Lagrangian method of
- * optimization. In this scheme, a penalty term is added to the Lagrangian.
- * This method is also called the "method of multipliers".
- *
- * The template class LagrangianFunction must implement the following five
- * methods:
- *
- * - double Evaluate(const arma::mat& coordinates);
- * - void Gradient(const arma::mat& coordinates, arma::mat& gradient);
- * - size_t NumConstraints();
- * - double EvaluateConstraint(size_t index, const arma::mat& coordinates);
- * - double GradientConstraint(size_t index, const arma::mat& coordinates,
- * arma::mat& gradient);
- *
- * The number of constraints must be greater than or equal to 0, and
- * EvaluateConstraint() should evaluate the constraint at the given index for
- * the given coordinates. Evaluate() should provide the objective function
- * value for the given coordinates.
- *
- * @tparam LagrangianFunction Function which can be optimized by this class.
- */
-template<typename LagrangianFunction>
-class AugLagrangian
-{
- public:
- //! Shorthand for the type of the L-BFGS optimizer we'll be using.
- typedef L_BFGS<AugLagrangianFunction<LagrangianFunction> >
- L_BFGSType;
-
- /**
- * Initialize the Augmented Lagrangian with the default L-BFGS optimizer. We
- * limit the number of L-BFGS iterations to 1000, rather than the unlimited
- * default L-BFGS.
- *
- * @param function The function to be optimized.
- */
- AugLagrangian(LagrangianFunction& function);
-
- /**
- * Initialize the Augmented Lagrangian with a custom L-BFGS optimizer.
- *
- * @param function The function to be optimized. This must be a pre-created
- * utility AugLagrangianFunction.
- * @param lbfgs The custom L-BFGS optimizer to be used. This should have
- * already been initialized with the given AugLagrangianFunction.
- */
- AugLagrangian(AugLagrangianFunction<LagrangianFunction>& augfunc,
- L_BFGSType& lbfgs);
-
- /**
- * Optimize the function. The value '1' is used for the initial value of each
- * Lagrange multiplier. To set the Lagrange multipliers yourself, use the
- * other overload of Optimize().
- *
- * @param coordinates Output matrix to store the optimized coordinates in.
- * @param maxIterations Maximum number of iterations of the Augmented
- * Lagrangian algorithm. 0 indicates no maximum.
- * @param sigma Initial penalty parameter.
- */
- bool Optimize(arma::mat& coordinates,
- const size_t maxIterations = 1000);
-
- /**
- * Optimize the function, giving initial estimates for the Lagrange
- * multipliers. The vector of Lagrange multipliers will be modified to
- * contain the Lagrange multipliers of the final solution (if one is found).
- *
- * @param coordinates Output matrix to store the optimized coordinates in.
- * @param initLambda Vector of initial Lagrange multipliers. Should have
- * length equal to the number of constraints.
- * @param initSigma Initial penalty parameter.
- * @param maxIterations Maximum number of iterations of the Augmented
- * Lagrangian algorithm. 0 indicates no maximum.
- */
- bool Optimize(arma::mat& coordinates,
- const arma::vec& initLambda,
- const double initSigma,
- const size_t maxIterations = 1000);
-
- //! Get the LagrangianFunction.
- const LagrangianFunction& Function() const { return function; }
- //! Modify the LagrangianFunction.
- LagrangianFunction& Function() { return function; }
-
- //! Get the L-BFGS object used for the actual optimization.
- const L_BFGSType& LBFGS() const { return lbfgs; }
- //! Modify the L-BFGS object used for the actual optimization.
- L_BFGSType& LBFGS() { return lbfgs; }
-
- //! Get the Lagrange multipliers.
- const arma::vec& Lambda() const { return augfunc.Lambda(); }
- //! Modify the Lagrange multipliers (i.e. set them before optimization).
- arma::vec& Lambda() { return augfunc.Lambda(); }
-
- //! Get the penalty parameter.
- double Sigma() const { return augfunc.Sigma(); }
- //! Modify the penalty parameter.
- double& Sigma() { return augfunc.Sigma(); }
-
- private:
- //! Function to be optimized.
- LagrangianFunction& function;
-
- //! Internally used AugLagrangianFunction which holds the function we are
- //! optimizing. This isn't publically accessible, but we provide ways to get
- //! to the Lagrange multipliers and the penalty parameter sigma.
- AugLagrangianFunction<LagrangianFunction> augfunc;
-
- //! If the user did not pass an L_BFGS object, we'll use our own internal one.
- L_BFGSType lbfgsInternal;
-
- //! The L-BFGS optimizer that we will use.
- L_BFGSType& lbfgs;
-};
-
-}; // namespace optimization
-}; // namespace mlpack
-
-#include "aug_lagrangian_impl.hpp"
-
-#endif // __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,157 @@
+/**
+ * @file aug_lagrangian.hpp
+ * @author Ryan Curtin
+ *
+ * Definition of AugLagrangian class, which implements the Augmented Lagrangian
+ * optimization method (also called the 'method of multipliers'). This class
+ * uses the L-BFGS optimizer.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_HPP
+#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+
+#include "aug_lagrangian_function.hpp"
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * The AugLagrangian class implements the Augmented Lagrangian method of
+ * optimization. In this scheme, a penalty term is added to the Lagrangian.
+ * This method is also called the "method of multipliers".
+ *
+ * The template class LagrangianFunction must implement the following five
+ * methods:
+ *
+ * - double Evaluate(const arma::mat& coordinates);
+ * - void Gradient(const arma::mat& coordinates, arma::mat& gradient);
+ * - size_t NumConstraints();
+ * - double EvaluateConstraint(size_t index, const arma::mat& coordinates);
+ * - double GradientConstraint(size_t index, const arma::mat& coordinates,
+ * arma::mat& gradient);
+ *
+ * The number of constraints must be greater than or equal to 0, and
+ * EvaluateConstraint() should evaluate the constraint at the given index for
+ * the given coordinates. Evaluate() should provide the objective function
+ * value for the given coordinates.
+ *
+ * @tparam LagrangianFunction Function which can be optimized by this class.
+ */
+template<typename LagrangianFunction>
+class AugLagrangian
+{
+ public:
+ //! Shorthand for the type of the L-BFGS optimizer we'll be using.
+ typedef L_BFGS<AugLagrangianFunction<LagrangianFunction> >
+ L_BFGSType;
+
+ /**
+ * Initialize the Augmented Lagrangian with the default L-BFGS optimizer. We
+ * limit the number of L-BFGS iterations to 1000, rather than the unlimited
+ * default L-BFGS.
+ *
+ * @param function The function to be optimized.
+ */
+ AugLagrangian(LagrangianFunction& function);
+
+ /**
+ * Initialize the Augmented Lagrangian with a custom L-BFGS optimizer.
+ *
+ * @param function The function to be optimized. This must be a pre-created
+ * utility AugLagrangianFunction.
+ * @param lbfgs The custom L-BFGS optimizer to be used. This should have
+ * already been initialized with the given AugLagrangianFunction.
+ */
+ AugLagrangian(AugLagrangianFunction<LagrangianFunction>& augfunc,
+ L_BFGSType& lbfgs);
+
+ /**
+ * Optimize the function. The value '1' is used for the initial value of each
+ * Lagrange multiplier. To set the Lagrange multipliers yourself, use the
+ * other overload of Optimize().
+ *
+ * @param coordinates Output matrix to store the optimized coordinates in.
+ * @param maxIterations Maximum number of iterations of the Augmented
+ * Lagrangian algorithm. 0 indicates no maximum.
+ * @param sigma Initial penalty parameter.
+ */
+ bool Optimize(arma::mat& coordinates,
+ const size_t maxIterations = 1000);
+
+ /**
+ * Optimize the function, giving initial estimates for the Lagrange
+ * multipliers. The vector of Lagrange multipliers will be modified to
+ * contain the Lagrange multipliers of the final solution (if one is found).
+ *
+ * @param coordinates Output matrix to store the optimized coordinates in.
+ * @param initLambda Vector of initial Lagrange multipliers. Should have
+ * length equal to the number of constraints.
+ * @param initSigma Initial penalty parameter.
+ * @param maxIterations Maximum number of iterations of the Augmented
+ * Lagrangian algorithm. 0 indicates no maximum.
+ */
+ bool Optimize(arma::mat& coordinates,
+ const arma::vec& initLambda,
+ const double initSigma,
+ const size_t maxIterations = 1000);
+
+ //! Get the LagrangianFunction.
+ const LagrangianFunction& Function() const { return function; }
+ //! Modify the LagrangianFunction.
+ LagrangianFunction& Function() { return function; }
+
+ //! Get the L-BFGS object used for the actual optimization.
+ const L_BFGSType& LBFGS() const { return lbfgs; }
+ //! Modify the L-BFGS object used for the actual optimization.
+ L_BFGSType& LBFGS() { return lbfgs; }
+
+ //! Get the Lagrange multipliers.
+ const arma::vec& Lambda() const { return augfunc.Lambda(); }
+ //! Modify the Lagrange multipliers (i.e. set them before optimization).
+ arma::vec& Lambda() { return augfunc.Lambda(); }
+
+ //! Get the penalty parameter.
+ double Sigma() const { return augfunc.Sigma(); }
+ //! Modify the penalty parameter.
+ double& Sigma() { return augfunc.Sigma(); }
+
+ private:
+ //! Function to be optimized.
+ LagrangianFunction& function;
+
+ //! Internally used AugLagrangianFunction which holds the function we are
+ //! optimizing. This isn't publicly accessible, but we provide ways to get
+ //! to the Lagrange multipliers and the penalty parameter sigma.
+ AugLagrangianFunction<LagrangianFunction> augfunc;
+
+ //! If the user did not pass an L_BFGS object, we'll use our own internal one.
+ L_BFGSType lbfgsInternal;
+
+ //! The L-BFGS optimizer that we will use.
+ L_BFGSType& lbfgs;
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#include "aug_lagrangian_impl.hpp"
+
+#endif // __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,128 +0,0 @@
-/**
- * @file aug_lagrangian_function.hpp
- * @author Ryan Curtin
- *
- * Contains a utility class for AugLagrangian.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_HPP
-#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace optimization {
-
-/**
- * This is a utility class used by AugLagrangian, meant to wrap a
- * LagrangianFunction into a function usable by a simple optimizer like L-BFGS.
- * Given a LagrangianFunction which follows the format outlined in the
- * documentation for AugLagrangian, this class provides Evaluate(), Gradient(),
- * and GetInitialPoint() functions which allow this class to be used with a
- * simple optimizer like L-BFGS.
- *
- * This class can be specialized for your particular implementation -- commonly,
- * a faster method for computing the overall objective and gradient of the
- * augmented Lagrangian function can be implemented than the naive, default
- * implementation given. Use class template specialization and re-implement all
- * of the methods (unfortunately, C++ specialization rules mean you have to
- * re-implement everything).
- *
- * @tparam LagrangianFunction Lagrangian function to be used.
- */
-template<typename LagrangianFunction>
-class AugLagrangianFunction
-{
- public:
- /**
- * Initialize the AugLagrangianFunction, but don't set the Lagrange
- * multipliers or penalty parameters yet. Make sure you set the Lagrange
- * multipliers before you use this...
- *
- * @param function Lagrangian function.
- */
- AugLagrangianFunction(LagrangianFunction& function);
-
- /**
- * Initialize the AugLagrangianFunction with the given LagrangianFunction,
- * Lagrange multipliers, and initial penalty parameter.
- *
- * @param function Lagrangian function.
- * @param lambda Initial Lagrange multipliers.
- * @param sigma Initial penalty parameter.
- */
- AugLagrangianFunction(LagrangianFunction& function,
- const arma::vec& lambda,
- const double sigma);
- /**
- * Evaluate the objective function of the Augmented Lagrangian function, which
- * is the standard Lagrangian function evaluation plus a penalty term, which
- * penalizes unsatisfied constraints.
- *
- * @param coordinates Coordinates to evaluate function at.
- * @return Objective function.
- */
- double Evaluate(const arma::mat& coordinates) const;
-
- /**
- * Evaluate the gradient of the Augmented Lagrangian function.
- *
- * @param coordinates Coordinates to evaluate gradient at.
- * @param gradient Matrix to store gradient into.
- */
- void Gradient(const arma::mat& coordinates, arma::mat& gradient) const;
-
- /**
- * Get the initial point of the optimization (supplied by the
- * LagrangianFunction).
- *
- * @return Initial point.
- */
- const arma::mat& GetInitialPoint() const;
-
- //! Get the Lagrange multipliers.
- const arma::vec& Lambda() const { return lambda; }
- //! Modify the Lagrange multipliers.
- arma::vec& Lambda() { return lambda; }
-
- //! Get sigma (the penalty parameter).
- double Sigma() const { return sigma; }
- //! Modify sigma (the penalty parameter).
- double& Sigma() { return sigma; }
-
- //! Get the Lagrangian function.
- const LagrangianFunction& Function() const { return function; }
- //! Modify the Lagrangian function.
- LagrangianFunction& Function() { return function; }
-
- private:
- //! Instantiation of the function to be optimized.
- LagrangianFunction& function;
-
- //! The Lagrange multipliers.
- arma::vec lambda;
- //! The penalty parameter.
- double sigma;
-};
-
-}; // namespace optimization
-}; // namespace mlpack
-
-// Include basic implementation.
-#include "aug_lagrangian_function_impl.hpp"
-
-#endif // __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,128 @@
+/**
+ * @file aug_lagrangian_function.hpp
+ * @author Ryan Curtin
+ *
+ * Contains a utility class for AugLagrangian.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_HPP
+#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * This is a utility class used by AugLagrangian, meant to wrap a
+ * LagrangianFunction into a function usable by a simple optimizer like L-BFGS.
+ * Given a LagrangianFunction which follows the format outlined in the
+ * documentation for AugLagrangian, this class provides Evaluate(), Gradient(),
+ * and GetInitialPoint() functions which allow this class to be used with a
+ * simple optimizer like L-BFGS.
+ *
+ * This class can be specialized for your particular implementation -- commonly,
+ * a faster method for computing the overall objective and gradient of the
+ * augmented Lagrangian function can be implemented than the naive, default
+ * implementation given. Use class template specialization and re-implement all
+ * of the methods (unfortunately, C++ specialization rules mean you have to
+ * re-implement everything).
+ *
+ * @tparam LagrangianFunction Lagrangian function to be used.
+ */
+template<typename LagrangianFunction>
+class AugLagrangianFunction
+{
+ public:
+ /**
+ * Initialize the AugLagrangianFunction, but don't set the Lagrange
+ * multipliers or penalty parameters yet. Make sure you set the Lagrange
+ * multipliers before you use this...
+ *
+ * @param function Lagrangian function.
+ */
+ AugLagrangianFunction(LagrangianFunction& function);
+
+ /**
+ * Initialize the AugLagrangianFunction with the given LagrangianFunction,
+ * Lagrange multipliers, and initial penalty parameter.
+ *
+ * @param function Lagrangian function.
+ * @param lambda Initial Lagrange multipliers.
+ * @param sigma Initial penalty parameter.
+ */
+ AugLagrangianFunction(LagrangianFunction& function,
+ const arma::vec& lambda,
+ const double sigma);
+ /**
+ * Evaluate the objective function of the Augmented Lagrangian function, which
+ * is the standard Lagrangian function evaluation plus a penalty term, which
+ * penalizes unsatisfied constraints.
+ *
+ * @param coordinates Coordinates to evaluate function at.
+ * @return Objective function.
+ */
+ double Evaluate(const arma::mat& coordinates) const;
+
+ /**
+ * Evaluate the gradient of the Augmented Lagrangian function.
+ *
+ * @param coordinates Coordinates to evaluate gradient at.
+ * @param gradient Matrix to store gradient into.
+ */
+ void Gradient(const arma::mat& coordinates, arma::mat& gradient) const;
+
+ /**
+ * Get the initial point of the optimization (supplied by the
+ * LagrangianFunction).
+ *
+ * @return Initial point.
+ */
+ const arma::mat& GetInitialPoint() const;
+
+ //! Get the Lagrange multipliers.
+ const arma::vec& Lambda() const { return lambda; }
+ //! Modify the Lagrange multipliers.
+ arma::vec& Lambda() { return lambda; }
+
+ //! Get sigma (the penalty parameter).
+ double Sigma() const { return sigma; }
+ //! Modify sigma (the penalty parameter).
+ double& Sigma() { return sigma; }
+
+ //! Get the Lagrangian function.
+ const LagrangianFunction& Function() const { return function; }
+ //! Modify the Lagrangian function.
+ LagrangianFunction& Function() { return function; }
+
+ private:
+ //! Instantiation of the function to be optimized.
+ LagrangianFunction& function;
+
+ //! The Lagrange multipliers.
+ arma::vec lambda;
+ //! The penalty parameter.
+ double sigma;
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+// Include basic implementation.
+#include "aug_lagrangian_function_impl.hpp"
+
+#endif // __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,115 +0,0 @@
-/**
- * @file aug_lagrangian_function_impl.hpp
- * @author Ryan Curtin
- *
- * Simple, naive implementation of AugLagrangianFunction. Better
- * specializations can probably be given in many cases, but this is the most
- * general case.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_IMPL_HPP
-#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_IMPL_HPP
-
-// In case it hasn't been included.
-#include "aug_lagrangian_function.hpp"
-
-namespace mlpack {
-namespace optimization {
-
-// Initialize the AugLagrangianFunction.
-template<typename LagrangianFunction>
-AugLagrangianFunction<LagrangianFunction>::AugLagrangianFunction(
- LagrangianFunction& function) :
- function(function),
- lambda(function.NumConstraints()),
- sigma(10)
-{
- // Nothing else to do.
-}
-
-// Initialize the AugLagrangianFunction.
-template<typename LagrangianFunction>
-AugLagrangianFunction<LagrangianFunction>::AugLagrangianFunction(
- LagrangianFunction& function,
- const arma::vec& lambda,
- const double sigma) :
- lambda(lambda),
- sigma(sigma),
- function(function)
-{
- // Nothing else to do.
-}
-
-// Evaluate the AugLagrangianFunction at the given coordinates.
-template<typename LagrangianFunction>
-double AugLagrangianFunction<LagrangianFunction>::Evaluate(
- const arma::mat& coordinates) const
-{
- // The augmented Lagrangian is evaluated as
- // f(x) + {-lambda_i * c_i(x) + (sigma / 2) c_i(x)^2} for all constraints
-
- // First get the function's objective value.
- double objective = function.Evaluate(coordinates);
-
- // Now loop for each constraint.
- for (size_t i = 0; i < function.NumConstraints(); ++i)
- {
- double constraint = function.EvaluateConstraint(i, coordinates);
-
- objective += (-lambda[i] * constraint) +
- sigma * std::pow(constraint, 2) / 2;
- }
-
- return objective;
-}
-
-// Evaluate the gradient of the AugLagrangianFunction at the given coordinates.
-template<typename LagrangianFunction>
-void AugLagrangianFunction<LagrangianFunction>::Gradient(
- const arma::mat& coordinates,
- arma::mat& gradient) const
-{
- // The augmented Lagrangian's gradient is evaluted as
- // f'(x) + {(-lambda_i + sigma * c_i(x)) * c'_i(x)} for all constraints
- gradient.zeros();
- function.Gradient(coordinates, gradient);
-
- arma::mat constraintGradient; // Temporary for constraint gradients.
- for (size_t i = 0; i < function.NumConstraints(); i++)
- {
- function.GradientConstraint(i, coordinates, constraintGradient);
-
- // Now calculate scaling factor and add to existing gradient.
- arma::mat tmpGradient;
- tmpGradient = (-lambda[i] + sigma *
- function.EvaluateConstraint(i, coordinates)) * constraintGradient;
- gradient += tmpGradient;
- }
-}
-
-// Get the initial point.
-template<typename LagrangianFunction>
-const arma::mat& AugLagrangianFunction<LagrangianFunction>::GetInitialPoint()
- const
-{
- return function.GetInitialPoint();
-}
-
-}; // namespace optimization
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,115 @@
+/**
+ * @file aug_lagrangian_function_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Simple, naive implementation of AugLagrangianFunction. Better
+ * specializations can probably be given in many cases, but this is the most
+ * general case.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_IMPL_HPP
+
+// In case it hasn't been included.
+#include "aug_lagrangian_function.hpp"
+
+namespace mlpack {
+namespace optimization {
+
+// Initialize the AugLagrangianFunction.
+template<typename LagrangianFunction>
+AugLagrangianFunction<LagrangianFunction>::AugLagrangianFunction(
+ LagrangianFunction& function) :
+ function(function),
+ lambda(function.NumConstraints()),
+ sigma(10)
+{
+ // Nothing else to do.
+}
+
+// Initialize the AugLagrangianFunction.
+template<typename LagrangianFunction>
+AugLagrangianFunction<LagrangianFunction>::AugLagrangianFunction(
+ LagrangianFunction& function,
+ const arma::vec& lambda,
+ const double sigma) :
+ lambda(lambda),
+ sigma(sigma),
+ function(function)
+{
+ // Nothing else to do.
+}
+
+// Evaluate the AugLagrangianFunction at the given coordinates.
+template<typename LagrangianFunction>
+double AugLagrangianFunction<LagrangianFunction>::Evaluate(
+ const arma::mat& coordinates) const
+{
+ // The augmented Lagrangian is evaluated as
+ // f(x) + {-lambda_i * c_i(x) + (sigma / 2) c_i(x)^2} for all constraints
+
+ // First get the function's objective value.
+ double objective = function.Evaluate(coordinates);
+
+ // Now loop for each constraint.
+ for (size_t i = 0; i < function.NumConstraints(); ++i)
+ {
+ double constraint = function.EvaluateConstraint(i, coordinates);
+
+ objective += (-lambda[i] * constraint) +
+ sigma * std::pow(constraint, 2) / 2;
+ }
+
+ return objective;
+}
+
+// Evaluate the gradient of the AugLagrangianFunction at the given coordinates.
+template<typename LagrangianFunction>
+void AugLagrangianFunction<LagrangianFunction>::Gradient(
+ const arma::mat& coordinates,
+ arma::mat& gradient) const
+{
+ // The augmented Lagrangian's gradient is evaluted as
+ // f'(x) + {(-lambda_i + sigma * c_i(x)) * c'_i(x)} for all constraints
+ gradient.zeros();
+ function.Gradient(coordinates, gradient);
+
+ arma::mat constraintGradient; // Temporary for constraint gradients.
+ for (size_t i = 0; i < function.NumConstraints(); i++)
+ {
+ function.GradientConstraint(i, coordinates, constraintGradient);
+
+ // Now calculate scaling factor and add to existing gradient.
+ arma::mat tmpGradient;
+ tmpGradient = (-lambda[i] + sigma *
+ function.EvaluateConstraint(i, coordinates)) * constraintGradient;
+ gradient += tmpGradient;
+ }
+}
+
+// Get the initial point.
+template<typename LagrangianFunction>
+const arma::mat& AugLagrangianFunction<LagrangianFunction>::GetInitialPoint()
+ const
+{
+ return function.GetInitialPoint();
+}
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,163 +0,0 @@
-/**
- * @file aug_lagrangian_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of AugLagrangian class (Augmented Lagrangian optimization
- * method).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_IMPL_HPP
-#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_IMPL_HPP
-
-#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
-#include "aug_lagrangian_function.hpp"
-
-namespace mlpack {
-namespace optimization {
-
-template<typename LagrangianFunction>
-AugLagrangian<LagrangianFunction>::AugLagrangian(LagrangianFunction& function) :
- function(function),
- augfunc(function),
- lbfgsInternal(augfunc),
- lbfgs(lbfgsInternal)
-{
- lbfgs.MaxIterations() = 1000;
-}
-
-template<typename LagrangianFunction>
-AugLagrangian<LagrangianFunction>::AugLagrangian(
- AugLagrangianFunction<LagrangianFunction>& augfunc,
- L_BFGSType& lbfgs) :
- function(augfunc.Function()),
- augfunc(augfunc),
- lbfgs(lbfgs)
-{
- // Nothing to do. lbfgsInternal isn't used in this case.
-}
-
-// This overload just sets the lambda and sigma and calls the other overload.
-template<typename LagrangianFunction>
-bool AugLagrangian<LagrangianFunction>::Optimize(arma::mat& coordinates,
- const arma::vec& initLambda,
- const double initSigma,
- const size_t maxIterations)
-{
- augfunc.Lambda() = initLambda;
- augfunc.Sigma() = initSigma;
-
- return Optimize(coordinates, maxIterations);
-}
-
-template<typename LagrangianFunction>
-bool AugLagrangian<LagrangianFunction>::Optimize(arma::mat& coordinates,
- const size_t maxIterations)
-{
- // Ensure that we update lambda immediately.
- double penaltyThreshold = DBL_MAX;
-
- // Track the last objective to compare for convergence.
- double lastObjective = function.Evaluate(coordinates);
-
- // Then, calculate the current penalty.
- double penalty = 0;
- for (size_t i = 0; i < function.NumConstraints(); i++)
- penalty += std::pow(function.EvaluateConstraint(i, coordinates), 2);
-
- Log::Debug << "Penalty is " << penalty << " (threshold " << penaltyThreshold
- << ")." << std::endl;
-
- // The odd comparison allows user to pass maxIterations = 0 (i.e. no limit on
- // number of iterations).
- size_t it;
- for (it = 0; it != (maxIterations - 1); it++)
- {
- Log::Warn << "AugLagrangian on iteration " << it
- << ", starting with objective " << lastObjective << "." << std::endl;
-
- // Log::Warn << coordinates << std::endl;
-
-// Log::Warn << trans(coordinates) * coordinates << std::endl;
-
- if (!lbfgs.Optimize(coordinates))
- Log::Warn << "L-BFGS reported an error during optimization."
- << std::endl;
-
- // Check if we are done with the entire optimization (the threshold we are
- // comparing with is arbitrary).
- if (std::abs(lastObjective - function.Evaluate(coordinates)) < 1e-10 &&
- augfunc.Sigma() > 500000)
- return true;
-
- lastObjective = function.Evaluate(coordinates);
-
- // Assuming that the optimization has converged to a new set of coordinates,
- // we now update either lambda or sigma. We update sigma if the penalty
- // term is too high, and we update lambda otherwise.
-
- // First, calculate the current penalty.
- double penalty = 0;
- for (size_t i = 0; i < function.NumConstraints(); i++)
- {
- penalty += std::pow(function.EvaluateConstraint(i, coordinates), 2);
-// Log::Debug << "Constraint " << i << " is " <<
-// function.EvaluateConstraint(i, coordinates) << std::endl;
- }
-
- Log::Warn << "Penalty is " << penalty << " (threshold "
- << penaltyThreshold << ")." << std::endl;
-
- for (size_t i = 0; i < function.NumConstraints(); ++i)
- {
-// arma::mat tmpgrad;
-// function.GradientConstraint(i, coordinates, tmpgrad);
-// Log::Debug << "Gradient of constraint " << i << " is " << std::endl;
-// Log::Debug << tmpgrad << std::endl;
- }
-
- if (penalty < penaltyThreshold) // We update lambda.
- {
- // We use the update: lambda_{k + 1} = lambda_k - sigma * c(coordinates),
- // but we have to write a loop to do this for each constraint.
- for (size_t i = 0; i < function.NumConstraints(); i++)
- augfunc.Lambda()[i] -= augfunc.Sigma() *
- function.EvaluateConstraint(i, coordinates);
-
- // We also update the penalty threshold to be a factor of the current
- // penalty. TODO: this factor should be a parameter (from CLI). The
- // value of 0.25 is taken from Burer and Monteiro (2002).
- penaltyThreshold = 0.25 * penalty;
- Log::Warn << "Lagrange multiplier estimates updated." << std::endl;
- }
- else
- {
- // We multiply sigma by a constant value. TODO: this factor should be a
- // parameter (from CLI). The value of 10 is taken from Burer and Monteiro
- // (2002).
- augfunc.Sigma() *= 10;
- Log::Warn << "Updated sigma to " << augfunc.Sigma() << "." << std::endl;
- }
- }
-
- return false;
-}
-
-}; // namespace optimization
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_IMPL_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,163 @@
+/**
+ * @file aug_lagrangian_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of AugLagrangian class (Augmented Lagrangian optimization
+ * method).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_IMPL_HPP
+
+#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+#include "aug_lagrangian_function.hpp"
+
+namespace mlpack {
+namespace optimization {
+
+template<typename LagrangianFunction>
+AugLagrangian<LagrangianFunction>::AugLagrangian(LagrangianFunction& function) :
+ function(function),
+ augfunc(function),
+ lbfgsInternal(augfunc),
+ lbfgs(lbfgsInternal)
+{
+ lbfgs.MaxIterations() = 1000;
+}
+
+template<typename LagrangianFunction>
+AugLagrangian<LagrangianFunction>::AugLagrangian(
+ AugLagrangianFunction<LagrangianFunction>& augfunc,
+ L_BFGSType& lbfgs) :
+ function(augfunc.Function()),
+ augfunc(augfunc),
+ lbfgs(lbfgs)
+{
+ // Nothing to do. lbfgsInternal isn't used in this case.
+}
+
+// This overload just sets the lambda and sigma and calls the other overload.
+template<typename LagrangianFunction>
+bool AugLagrangian<LagrangianFunction>::Optimize(arma::mat& coordinates,
+ const arma::vec& initLambda,
+ const double initSigma,
+ const size_t maxIterations)
+{
+ augfunc.Lambda() = initLambda;
+ augfunc.Sigma() = initSigma;
+
+ return Optimize(coordinates, maxIterations);
+}
+
+template<typename LagrangianFunction>
+bool AugLagrangian<LagrangianFunction>::Optimize(arma::mat& coordinates,
+ const size_t maxIterations)
+{
+ // Ensure that we update lambda immediately.
+ double penaltyThreshold = DBL_MAX;
+
+ // Track the last objective to compare for convergence.
+ double lastObjective = function.Evaluate(coordinates);
+
+ // Then, calculate the current penalty.
+ double penalty = 0;
+ for (size_t i = 0; i < function.NumConstraints(); i++)
+ penalty += std::pow(function.EvaluateConstraint(i, coordinates), 2);
+
+ Log::Debug << "Penalty is " << penalty << " (threshold " << penaltyThreshold
+ << ")." << std::endl;
+
+ // The odd comparison allows user to pass maxIterations = 0 (i.e. no limit on
+ // number of iterations).
+ size_t it;
+ for (it = 0; it != (maxIterations - 1); it++)
+ {
+ Log::Warn << "AugLagrangian on iteration " << it
+ << ", starting with objective " << lastObjective << "." << std::endl;
+
+ // Log::Warn << coordinates << std::endl;
+
+// Log::Warn << trans(coordinates) * coordinates << std::endl;
+
+ if (!lbfgs.Optimize(coordinates))
+ Log::Warn << "L-BFGS reported an error during optimization."
+ << std::endl;
+
+ // Check if we are done with the entire optimization (the threshold we are
+ // comparing with is arbitrary).
+ if (std::abs(lastObjective - function.Evaluate(coordinates)) < 1e-10 &&
+ augfunc.Sigma() > 500000)
+ return true;
+
+ lastObjective = function.Evaluate(coordinates);
+
+ // Assuming that the optimization has converged to a new set of coordinates,
+ // we now update either lambda or sigma. We update sigma if the penalty
+ // term is too high, and we update lambda otherwise.
+
+ // First, calculate the current penalty.
+ double penalty = 0;
+ for (size_t i = 0; i < function.NumConstraints(); i++)
+ {
+ penalty += std::pow(function.EvaluateConstraint(i, coordinates), 2);
+// Log::Debug << "Constraint " << i << " is " <<
+// function.EvaluateConstraint(i, coordinates) << std::endl;
+ }
+
+ Log::Warn << "Penalty is " << penalty << " (threshold "
+ << penaltyThreshold << ")." << std::endl;
+
+ for (size_t i = 0; i < function.NumConstraints(); ++i)
+ {
+// arma::mat tmpgrad;
+// function.GradientConstraint(i, coordinates, tmpgrad);
+// Log::Debug << "Gradient of constraint " << i << " is " << std::endl;
+// Log::Debug << tmpgrad << std::endl;
+ }
+
+ if (penalty < penaltyThreshold) // We update lambda.
+ {
+ // We use the update: lambda_{k + 1} = lambda_k - sigma * c(coordinates),
+ // but we have to write a loop to do this for each constraint.
+ for (size_t i = 0; i < function.NumConstraints(); i++)
+ augfunc.Lambda()[i] -= augfunc.Sigma() *
+ function.EvaluateConstraint(i, coordinates);
+
+ // We also update the penalty threshold to be a factor of the current
+ // penalty. TODO: this factor should be a parameter (from CLI). The
+ // value of 0.25 is taken from Burer and Monteiro (2002).
+ penaltyThreshold = 0.25 * penalty;
+ Log::Warn << "Lagrange multiplier estimates updated." << std::endl;
+ }
+ else
+ {
+ // We multiply sigma by a constant value. TODO: this factor should be a
+ // parameter (from CLI). The value of 10 is taken from Burer and Monteiro
+ // (2002).
+ augfunc.Sigma() *= 10;
+ Log::Warn << "Updated sigma to " << augfunc.Sigma() << "." << std::endl;
+ }
+ }
+
+ return false;
+}
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_IMPL_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,402 +0,0 @@
-/**
- * @file aug_lagrangian_test_functions.cpp
- * @author Ryan Curtin
- *
- * Implementation of AugLagrangianTestFunction class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include "aug_lagrangian_test_functions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::optimization;
-
-//
-// AugLagrangianTestFunction
-//
-AugLagrangianTestFunction::AugLagrangianTestFunction()
-{
- // Set the initial point to be (0, 0).
- initialPoint.zeros(2, 1);
-}
-
-AugLagrangianTestFunction::AugLagrangianTestFunction(
- const arma::mat& initialPoint) :
- initialPoint(initialPoint)
-{
- // Nothing to do.
-}
-
-double AugLagrangianTestFunction::Evaluate(const arma::mat& coordinates)
-{
- // f(x) = 6 x_1^2 + 4 x_1 x_2 + 3 x_2^2
- return ((6 * std::pow(coordinates[0], 2)) +
- (4 * (coordinates[0] * coordinates[1])) +
- (3 * std::pow(coordinates[1], 2)));
-}
-
-void AugLagrangianTestFunction::Gradient(const arma::mat& coordinates,
- arma::mat& gradient)
-{
- // f'_x1(x) = 12 x_1 + 4 x_2
- // f'_x2(x) = 4 x_1 + 6 x_2
- gradient.set_size(2, 1);
-
- gradient[0] = 12 * coordinates[0] + 4 * coordinates[1];
- gradient[1] = 4 * coordinates[0] + 6 * coordinates[1];
-}
-
-double AugLagrangianTestFunction::EvaluateConstraint(const size_t index,
- const arma::mat& coordinates)
-{
- // We return 0 if the index is wrong (not 0).
- if (index != 0)
- return 0;
-
- // c(x) = x_1 + x_2 - 5
- return (coordinates[0] + coordinates[1] - 5);
-}
-
-void AugLagrangianTestFunction::GradientConstraint(const size_t index,
- const arma::mat& /* coordinates */,
- arma::mat& gradient)
-{
- // If the user passed an invalid index (not 0), we will return a zero
- // gradient.
- gradient.zeros(2, 1);
-
- if (index == 0)
- {
- // c'_x1(x) = 1
- // c'_x2(x) = 1
- gradient.ones(2, 1); // Use a shortcut instead of assigning individually.
- }
-}
-
-//
-// GockenbachFunction
-//
-GockenbachFunction::GockenbachFunction()
-{
- // Set the initial point to (0, 0, 1).
- initialPoint.zeros(3, 1);
- initialPoint[2] = 1;
-}
-
-GockenbachFunction::GockenbachFunction(const arma::mat& initialPoint) :
- initialPoint(initialPoint)
-{
- // Nothing to do.
-}
-
-double GockenbachFunction::Evaluate(const arma::mat& coordinates)
-{
- // f(x) = (x_1 - 1)^2 + 2 (x_2 + 2)^2 + 3(x_3 + 3)^2
- return ((std::pow(coordinates[0] - 1, 2)) +
- (2 * std::pow(coordinates[1] + 2, 2)) +
- (3 * std::pow(coordinates[2] + 3, 2)));
-}
-
-void GockenbachFunction::Gradient(const arma::mat& coordinates,
- arma::mat& gradient)
-{
- // f'_x1(x) = 2 (x_1 - 1)
- // f'_x2(x) = 4 (x_2 + 2)
- // f'_x3(x) = 6 (x_3 + 3)
- gradient.set_size(3, 1);
-
- gradient[0] = 2 * (coordinates[0] - 1);
- gradient[1] = 4 * (coordinates[1] + 2);
- gradient[2] = 6 * (coordinates[2] + 3);
-}
-
-double GockenbachFunction::EvaluateConstraint(const size_t index,
- const arma::mat& coordinates)
-{
- double constraint = 0;
-
- switch (index)
- {
- case 0: // g(x) = (x_3 - x_2 - x_1 - 1) = 0
- constraint = (coordinates[2] - coordinates[1] - coordinates[0] - 1);
- break;
-
- case 1: // h(x) = (x_3 - x_1^2) >= 0
- // To deal with the inequality, the constraint will simply evaluate to 0
- // when h(x) >= 0.
- constraint = std::min(0.0,
- (coordinates[2] - std::pow(coordinates[0], 2)));
- break;
- }
-
- // 0 will be returned for an invalid index (but this is okay).
- return constraint;
-}
-
-void GockenbachFunction::GradientConstraint(const size_t index,
- const arma::mat& coordinates,
- arma::mat& gradient)
-{
- gradient.zeros(3, 1);
-
- switch (index)
- {
- case 0:
- // g'_x1(x) = -1
- // g'_x2(x) = -1
- // g'_x3(x) = 1
- gradient[0] = -1;
- gradient[1] = -1;
- gradient[2] = 1;
- break;
-
- case 1:
- // h'_x1(x) = -2 x_1
- // h'_x2(x) = 0
- // h'_x3(x) = 1
- gradient[0] = -2 * coordinates[0];
- gradient[2] = 1;
- break;
- }
-}
-
-//
-// LovaszThetaSDP
-//
-LovaszThetaSDP::LovaszThetaSDP() : edges(0), vertices(0), initialPoint(0, 0)
-{ }
-
-LovaszThetaSDP::LovaszThetaSDP(const arma::mat& edges) : edges(edges),
- initialPoint(0, 0)
-{
- // Calculate V by finding the maximum index in the edges matrix.
- vertices = max(max(edges)) + 1;
-// Log::Debug << vertices << " vertices in graph." << std::endl;
-}
-
-double LovaszThetaSDP::Evaluate(const arma::mat& coordinates)
-{
- // The objective is equal to -Tr(ones * X) = -Tr(ones * (R^T * R)).
- // This can be simplified into the negative sum of (R^T * R).
-// Log::Debug << "Evaluting objective function with coordinates:" << std::endl;
-// std::cout << coordinates << std::endl;
-// Log::Debug << "trans(coord) * coord:" << std::endl;
-// std::cout << (trans(coordinates) * coordinates) << std::endl;
-
-
- arma::mat x = trans(coordinates) * coordinates;
- double obj = -accu(x);
-
-// double obj = 0;
-// for (size_t i = 0; i < coordinates.n_cols; i++)
-// obj -= dot(coordinates.col(i), coordinates.col(i));
-
-// Log::Debug << "Objective function is " << obj << "." << std::endl;
-
- return obj;
-}
-
-void LovaszThetaSDP::Gradient(const arma::mat& coordinates,
- arma::mat& gradient)
-{
-
- // The gradient is equal to (2 S' R^T)^T, with R being coordinates.
- // S' = C - sum_{i = 1}^{m} [ y_i - sigma (Tr(A_i * (R^T R)) - b_i)] * A_i
- // We will calculate it in a not very smart way, but it should work.
- // Log::Warn << "Using stupid specialization for gradient calculation!"
- // << std::endl;
-
- // Initialize S' piece by piece. It is of size n x n.
- const size_t n = coordinates.n_cols;
- arma::mat s(n, n);
- s.ones();
- s *= -1; // C = -ones().
-
- for (size_t i = 0; i < NumConstraints(); ++i)
- {
- // Calculate [ y_i - sigma (Tr(A_i * (R^T R)) - b_i) ] * A_i.
- // Result will be a matrix; inner result is a scalar.
- if (i == 0)
- {
- // A_0 = I_n. Hooray! That's easy! b_0 = 1.
- double inner = -1 * double(n) - 0.5 *
- (trace(trans(coordinates) * coordinates) - 1);
-
- arma::mat zz = (inner * arma::eye<arma::mat>(n, n));
-
-// Log::Debug << "Constraint " << i << " matrix to add is " << std::endl;
-// Log::Debug << zz << std::endl;
-
- s -= zz;
- }
- else
- {
- // Get edge so we can construct constraint A_i matrix. b_i = 0.
- arma::vec edge = edges.col(i - 1);
-
- arma::mat a;
- a.zeros(n, n);
-
- // Only two nonzero entries.
- a(edge[0], edge[1]) = 1;
- a(edge[1], edge[0]) = 1;
-
- double inner = (-1) - 0.5 *
- (trace(a * (trans(coordinates) * coordinates)));
-
- arma::mat zz = (inner * a);
-
-// Log::Debug << "Constraint " << i << " matrix to add is " << std::endl;
-// Log::Debug << zz << std::endl;
-
- s -= zz;
- }
- }
-
-// Log::Warn << "Calculated S is: " << std::endl << s << std::endl;
-
- gradient = trans(2 * s * trans(coordinates));
-
-// Log::Warn << "Calculated gradient is: " << std::endl << gradient << std::endl;
-
-
-// Log::Debug << "Evaluating gradient. " << std::endl;
-
- // The gradient of -Tr(ones * X) is equal to -2 * ones * R
-// arma::mat ones;
-// ones.ones(coordinates.n_rows, coordinates.n_rows);
-// gradient = -2 * ones * coordinates;
-
-// Log::Debug << "Done with gradient." << std::endl;
-// std::cout << gradient;
-}
-
-size_t LovaszThetaSDP::NumConstraints() const
-{
- // Each edge is a constraint, and we have the constraint Tr(X) = 1.
- return edges.n_cols + 1;
-}
-
-double LovaszThetaSDP::EvaluateConstraint(const size_t index,
- const arma::mat& coordinates)
-{
- if (index == 0) // This is the constraint Tr(X) = 1.
- {
- double sum = -1; // Tr(X) - 1 = 0, so we prefix the subtraction.
- for (size_t i = 0; i < coordinates.n_cols; i++)
- sum += std::abs(dot(coordinates.col(i), coordinates.col(i)));
-
-// Log::Debug << "Constraint " << index << " evaluates to " << sum << std::endl;
- return sum;
- }
-
- size_t i = edges(0, index - 1);
- size_t j = edges(1, index - 1);
-
-// Log::Debug << "Constraint " << index << " evaluates to " <<
-// dot(coordinates.col(i), coordinates.col(j)) << "." << std::endl;
-
- // The constraint itself is X_ij, or (R^T R)_ij.
- return std::abs(dot(coordinates.col(i), coordinates.col(j)));
-}
-
-void LovaszThetaSDP::GradientConstraint(const size_t index,
- const arma::mat& coordinates,
- arma::mat& gradient)
-{
-// Log::Debug << "Gradient of constraint " << index << " is " << std::endl;
- if (index == 0) // This is the constraint Tr(X) = 1.
- {
- gradient = 2 * coordinates; // d/dR (Tr(R R^T)) = 2 R.
-// std::cout << gradient;
- return;
- }
-
-// Log::Debug << "Evaluating gradient of constraint " << index << " with ";
- size_t i = edges(0, index - 1);
- size_t j = edges(1, index - 1);
-// Log::Debug << "i = " << i << " and j = " << j << "." << std::endl;
-
- // Since the constraint is (R^T R)_ij, the gradient for (x, y) will be (I
- // derived this for one of the MVU constraints):
- // 0 , y != i, y != j
- // 2 R_xj, y = i, y != j
- // 2 R_xi, y != i, y = j
- // 4 R_xy, y = i, y = j
- // This results in the gradient matrix having two nonzero rows; for row
- // i, the elements are R_nj, where n is the row; for column j, the elements
- // are R_ni.
- gradient.zeros(coordinates.n_rows, coordinates.n_cols);
-
- gradient.col(i) = coordinates.col(j);
- gradient.col(j) += coordinates.col(i); // In case j = i (shouldn't happen).
-
-// std::cout << gradient;
-}
-
-const arma::mat& LovaszThetaSDP::GetInitialPoint()
-{
- if (initialPoint.n_rows != 0 && initialPoint.n_cols != 0)
- return initialPoint; // It has already been calculated.
-
-// Log::Debug << "Calculating initial point." << std::endl;
-
- // First, we must calculate the correct value of r. The matrix we return, R,
- // will be r x V, because X = R^T R is of dimension V x V.
- // The rule for calculating r (from Monteiro and Burer, eq. 5) is
- // r = max(r >= 0 : r (r + 1) / 2 <= m }
- // where m is equal to the number of constraints plus one.
- //
- // Solved, this is
- // 0.5 r^2 + 0.5 r - m = 0
- // which becomes
- // r = (-0.5 [+/-] sqrt((-0.5)^2 - 4 * -0.5 * m)) / -1
- // r = 0.5 [+/-] sqrt(0.25 + 2 m)
- // and because m is always positive,
- // r = 0.5 + sqrt(0.25 + 2m)
- float m = NumConstraints();
- float r = 0.5 + sqrt(0.25 + 2 * m);
- if (ceil(r) > vertices)
- r = vertices; // An upper bound on the dimension.
-
- Log::Debug << "Dimension will be " << ceil(r) << " x " << vertices << "."
- << std::endl;
-
- initialPoint.set_size(ceil(r), vertices);
-
- // Now we set the entries of the initial matrix according to the formula given
- // in Section 4 of Monteiro and Burer.
- for (size_t i = 0; i < r; i++)
- {
- for (size_t j = 0; j < (size_t) vertices; j++)
- {
- if (i == j)
- initialPoint(i, j) = sqrt(1.0 / r) + sqrt(1.0 / (vertices * m));
- else
- initialPoint(i, j) = sqrt(1.0 / (vertices * m));
- }
- }
-
- Log::Debug << "Initial matrix " << std::endl << initialPoint << std::endl;
-
- Log::Warn << "X " << std::endl << trans(initialPoint) * initialPoint
- << std::endl;
-
- Log::Warn << "accu " << accu(trans(initialPoint) * initialPoint) << std::endl;
-
- return initialPoint;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,402 @@
+/**
+ * @file aug_lagrangian_test_functions.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of AugLagrangianTestFunction class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "aug_lagrangian_test_functions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::optimization;
+
+//
+// AugLagrangianTestFunction
+//
+AugLagrangianTestFunction::AugLagrangianTestFunction()
+{
+ // Set the initial point to be (0, 0).
+ initialPoint.zeros(2, 1);
+}
+
+AugLagrangianTestFunction::AugLagrangianTestFunction(
+ const arma::mat& initialPoint) :
+ initialPoint(initialPoint)
+{
+ // Nothing to do.
+}
+
+double AugLagrangianTestFunction::Evaluate(const arma::mat& coordinates)
+{
+ // f(x) = 6 x_1^2 + 4 x_1 x_2 + 3 x_2^2
+ return ((6 * std::pow(coordinates[0], 2)) +
+ (4 * (coordinates[0] * coordinates[1])) +
+ (3 * std::pow(coordinates[1], 2)));
+}
+
+void AugLagrangianTestFunction::Gradient(const arma::mat& coordinates,
+ arma::mat& gradient)
+{
+ // f'_x1(x) = 12 x_1 + 4 x_2
+ // f'_x2(x) = 4 x_1 + 6 x_2
+ gradient.set_size(2, 1);
+
+ gradient[0] = 12 * coordinates[0] + 4 * coordinates[1];
+ gradient[1] = 4 * coordinates[0] + 6 * coordinates[1];
+}
+
+double AugLagrangianTestFunction::EvaluateConstraint(const size_t index,
+ const arma::mat& coordinates)
+{
+ // We return 0 if the index is wrong (not 0).
+ if (index != 0)
+ return 0;
+
+ // c(x) = x_1 + x_2 - 5
+ return (coordinates[0] + coordinates[1] - 5);
+}
+
+void AugLagrangianTestFunction::GradientConstraint(const size_t index,
+ const arma::mat& /* coordinates */,
+ arma::mat& gradient)
+{
+ // If the user passed an invalid index (not 0), we will return a zero
+ // gradient.
+ gradient.zeros(2, 1);
+
+ if (index == 0)
+ {
+ // c'_x1(x) = 1
+ // c'_x2(x) = 1
+ gradient.ones(2, 1); // Use a shortcut instead of assigning individually.
+ }
+}
+
+//
+// GockenbachFunction
+//
+GockenbachFunction::GockenbachFunction()
+{
+ // Set the initial point to (0, 0, 1).
+ initialPoint.zeros(3, 1);
+ initialPoint[2] = 1;
+}
+
+GockenbachFunction::GockenbachFunction(const arma::mat& initialPoint) :
+ initialPoint(initialPoint)
+{
+ // Nothing to do.
+}
+
+double GockenbachFunction::Evaluate(const arma::mat& coordinates)
+{
+ // f(x) = (x_1 - 1)^2 + 2 (x_2 + 2)^2 + 3(x_3 + 3)^2
+ return ((std::pow(coordinates[0] - 1, 2)) +
+ (2 * std::pow(coordinates[1] + 2, 2)) +
+ (3 * std::pow(coordinates[2] + 3, 2)));
+}
+
+void GockenbachFunction::Gradient(const arma::mat& coordinates,
+ arma::mat& gradient)
+{
+ // f'_x1(x) = 2 (x_1 - 1)
+ // f'_x2(x) = 4 (x_2 + 2)
+ // f'_x3(x) = 6 (x_3 + 3)
+ gradient.set_size(3, 1);
+
+ gradient[0] = 2 * (coordinates[0] - 1);
+ gradient[1] = 4 * (coordinates[1] + 2);
+ gradient[2] = 6 * (coordinates[2] + 3);
+}
+
+double GockenbachFunction::EvaluateConstraint(const size_t index,
+ const arma::mat& coordinates)
+{
+ double constraint = 0;
+
+ switch (index)
+ {
+ case 0: // g(x) = (x_3 - x_2 - x_1 - 1) = 0
+ constraint = (coordinates[2] - coordinates[1] - coordinates[0] - 1);
+ break;
+
+ case 1: // h(x) = (x_3 - x_1^2) >= 0
+ // To deal with the inequality, the constraint will simply evaluate to 0
+ // when h(x) >= 0.
+ constraint = std::min(0.0,
+ (coordinates[2] - std::pow(coordinates[0], 2)));
+ break;
+ }
+
+ // 0 will be returned for an invalid index (but this is okay).
+ return constraint;
+}
+
+void GockenbachFunction::GradientConstraint(const size_t index,
+ const arma::mat& coordinates,
+ arma::mat& gradient)
+{
+ gradient.zeros(3, 1);
+
+ switch (index)
+ {
+ case 0:
+ // g'_x1(x) = -1
+ // g'_x2(x) = -1
+ // g'_x3(x) = 1
+ gradient[0] = -1;
+ gradient[1] = -1;
+ gradient[2] = 1;
+ break;
+
+ case 1:
+ // h'_x1(x) = -2 x_1
+ // h'_x2(x) = 0
+ // h'_x3(x) = 1
+ gradient[0] = -2 * coordinates[0];
+ gradient[2] = 1;
+ break;
+ }
+}
+
+//
+// LovaszThetaSDP
+//
+LovaszThetaSDP::LovaszThetaSDP() : edges(0), vertices(0), initialPoint(0, 0)
+{ }
+
+LovaszThetaSDP::LovaszThetaSDP(const arma::mat& edges) : edges(edges),
+ initialPoint(0, 0)
+{
+ // Calculate V by finding the maximum index in the edges matrix.
+ vertices = max(max(edges)) + 1;
+// Log::Debug << vertices << " vertices in graph." << std::endl;
+}
+
+double LovaszThetaSDP::Evaluate(const arma::mat& coordinates)
+{
+ // The objective is equal to -Tr(ones * X) = -Tr(ones * (R^T * R)).
+ // This can be simplified into the negative sum of (R^T * R).
+// Log::Debug << "Evaluting objective function with coordinates:" << std::endl;
+// std::cout << coordinates << std::endl;
+// Log::Debug << "trans(coord) * coord:" << std::endl;
+// std::cout << (trans(coordinates) * coordinates) << std::endl;
+
+
+ arma::mat x = trans(coordinates) * coordinates;
+ double obj = -accu(x);
+
+// double obj = 0;
+// for (size_t i = 0; i < coordinates.n_cols; i++)
+// obj -= dot(coordinates.col(i), coordinates.col(i));
+
+// Log::Debug << "Objective function is " << obj << "." << std::endl;
+
+ return obj;
+}
+
+void LovaszThetaSDP::Gradient(const arma::mat& coordinates,
+ arma::mat& gradient)
+{
+
+ // The gradient is equal to (2 S' R^T)^T, with R being coordinates.
+ // S' = C - sum_{i = 1}^{m} [ y_i - sigma (Tr(A_i * (R^T R)) - b_i)] * A_i
+ // We will calculate it in a not very smart way, but it should work.
+ // Log::Warn << "Using stupid specialization for gradient calculation!"
+ // << std::endl;
+
+ // Initialize S' piece by piece. It is of size n x n.
+ const size_t n = coordinates.n_cols;
+ arma::mat s(n, n);
+ s.ones();
+ s *= -1; // C = -ones().
+
+ for (size_t i = 0; i < NumConstraints(); ++i)
+ {
+ // Calculate [ y_i - sigma (Tr(A_i * (R^T R)) - b_i) ] * A_i.
+ // Result will be a matrix; inner result is a scalar.
+ if (i == 0)
+ {
+ // A_0 = I_n. Hooray! That's easy! b_0 = 1.
+ double inner = -1 * double(n) - 0.5 *
+ (trace(trans(coordinates) * coordinates) - 1);
+
+ arma::mat zz = (inner * arma::eye<arma::mat>(n, n));
+
+// Log::Debug << "Constraint " << i << " matrix to add is " << std::endl;
+// Log::Debug << zz << std::endl;
+
+ s -= zz;
+ }
+ else
+ {
+ // Get edge so we can construct constraint A_i matrix. b_i = 0.
+ arma::vec edge = edges.col(i - 1);
+
+ arma::mat a;
+ a.zeros(n, n);
+
+ // Only two nonzero entries.
+ a(edge[0], edge[1]) = 1;
+ a(edge[1], edge[0]) = 1;
+
+ double inner = (-1) - 0.5 *
+ (trace(a * (trans(coordinates) * coordinates)));
+
+ arma::mat zz = (inner * a);
+
+// Log::Debug << "Constraint " << i << " matrix to add is " << std::endl;
+// Log::Debug << zz << std::endl;
+
+ s -= zz;
+ }
+ }
+
+// Log::Warn << "Calculated S is: " << std::endl << s << std::endl;
+
+ gradient = trans(2 * s * trans(coordinates));
+
+// Log::Warn << "Calculated gradient is: " << std::endl << gradient << std::endl;
+
+
+// Log::Debug << "Evaluating gradient. " << std::endl;
+
+ // The gradient of -Tr(ones * X) is equal to -2 * ones * R
+// arma::mat ones;
+// ones.ones(coordinates.n_rows, coordinates.n_rows);
+// gradient = -2 * ones * coordinates;
+
+// Log::Debug << "Done with gradient." << std::endl;
+// std::cout << gradient;
+}
+
+size_t LovaszThetaSDP::NumConstraints() const
+{
+ // Each edge is a constraint, and we have the constraint Tr(X) = 1.
+ return edges.n_cols + 1;
+}
+
+double LovaszThetaSDP::EvaluateConstraint(const size_t index,
+ const arma::mat& coordinates)
+{
+ if (index == 0) // This is the constraint Tr(X) = 1.
+ {
+ double sum = -1; // Tr(X) - 1 = 0, so we prefix the subtraction.
+ for (size_t i = 0; i < coordinates.n_cols; i++)
+ sum += std::abs(dot(coordinates.col(i), coordinates.col(i)));
+
+// Log::Debug << "Constraint " << index << " evaluates to " << sum << std::endl;
+ return sum;
+ }
+
+ size_t i = edges(0, index - 1);
+ size_t j = edges(1, index - 1);
+
+// Log::Debug << "Constraint " << index << " evaluates to " <<
+// dot(coordinates.col(i), coordinates.col(j)) << "." << std::endl;
+
+ // The constraint itself is X_ij, or (R^T R)_ij.
+ return std::abs(dot(coordinates.col(i), coordinates.col(j)));
+}
+
+void LovaszThetaSDP::GradientConstraint(const size_t index,
+ const arma::mat& coordinates,
+ arma::mat& gradient)
+{
+// Log::Debug << "Gradient of constraint " << index << " is " << std::endl;
+ if (index == 0) // This is the constraint Tr(X) = 1.
+ {
+ gradient = 2 * coordinates; // d/dR (Tr(R R^T)) = 2 R.
+// std::cout << gradient;
+ return;
+ }
+
+// Log::Debug << "Evaluating gradient of constraint " << index << " with ";
+ size_t i = edges(0, index - 1);
+ size_t j = edges(1, index - 1);
+// Log::Debug << "i = " << i << " and j = " << j << "." << std::endl;
+
+ // Since the constraint is (R^T R)_ij, the gradient for (x, y) will be (I
+ // derived this for one of the MVU constraints):
+ // 0 , y != i, y != j
+ // 2 R_xj, y = i, y != j
+ // 2 R_xi, y != i, y = j
+ // 4 R_xy, y = i, y = j
+ // This results in the gradient matrix having two nonzero rows; for row
+ // i, the elements are R_nj, where n is the row; for column j, the elements
+ // are R_ni.
+ gradient.zeros(coordinates.n_rows, coordinates.n_cols);
+
+ gradient.col(i) = coordinates.col(j);
+ gradient.col(j) += coordinates.col(i); // In case j = i (shouldn't happen).
+
+// std::cout << gradient;
+}
+
+const arma::mat& LovaszThetaSDP::GetInitialPoint()
+{
+ if (initialPoint.n_rows != 0 && initialPoint.n_cols != 0)
+ return initialPoint; // It has already been calculated.
+
+// Log::Debug << "Calculating initial point." << std::endl;
+
+ // First, we must calculate the correct value of r. The matrix we return, R,
+ // will be r x V, because X = R^T R is of dimension V x V.
+ // The rule for calculating r (from Monteiro and Burer, eq. 5) is
+ // r = max(r >= 0 : r (r + 1) / 2 <= m }
+ // where m is equal to the number of constraints plus one.
+ //
+ // Solved, this is
+ // 0.5 r^2 + 0.5 r - m = 0
+ // which becomes
+ // r = (-0.5 [+/-] sqrt((-0.5)^2 - 4 * -0.5 * m)) / -1
+ // r = 0.5 [+/-] sqrt(0.25 + 2 m)
+ // and because m is always positive,
+ // r = 0.5 + sqrt(0.25 + 2m)
+ float m = NumConstraints();
+ float r = 0.5 + sqrt(0.25 + 2 * m);
+ if (ceil(r) > vertices)
+ r = vertices; // An upper bound on the dimension.
+
+ Log::Debug << "Dimension will be " << ceil(r) << " x " << vertices << "."
+ << std::endl;
+
+ initialPoint.set_size(ceil(r), vertices);
+
+ // Now we set the entries of the initial matrix according to the formula given
+ // in Section 4 of Monteiro and Burer.
+ for (size_t i = 0; i < r; i++)
+ {
+ for (size_t j = 0; j < (size_t) vertices; j++)
+ {
+ if (i == j)
+ initialPoint(i, j) = sqrt(1.0 / r) + sqrt(1.0 / (vertices * m));
+ else
+ initialPoint(i, j) = sqrt(1.0 / (vertices * m));
+ }
+ }
+
+ Log::Debug << "Initial matrix " << std::endl << initialPoint << std::endl;
+
+ Log::Warn << "X " << std::endl << trans(initialPoint) * initialPoint
+ << std::endl;
+
+ Log::Warn << "accu " << accu(trans(initialPoint) * initialPoint) << std::endl;
+
+ return initialPoint;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,155 +0,0 @@
-/**
- * @file aug_lagrangian_test_functions.hpp
- * @author Ryan Curtin
- *
- * Define test functions for the augmented Lagrangian method.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_TEST_FUNCTIONS_HPP
-#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_TEST_FUNCTIONS_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace optimization {
-
-/**
- * This function is taken from "Practical Mathematical Optimization" (Snyman),
- * section 5.3.8 ("Application of the Augmented Lagrangian Method"). It has
- * only one constraint.
- *
- * The minimum that satisfies the constraint is x = [1, 4], with an objective
- * value of 70.
- */
-class AugLagrangianTestFunction
-{
- public:
- AugLagrangianTestFunction();
- AugLagrangianTestFunction(const arma::mat& initial_point);
-
- double Evaluate(const arma::mat& coordinates);
- void Gradient(const arma::mat& coordinates, arma::mat& gradient);
-
- size_t NumConstraints() const { return 1; }
-
- double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
- void GradientConstraint(const size_t index,
- const arma::mat& coordinates,
- arma::mat& gradient);
-
- const arma::mat& GetInitialPoint() const { return initialPoint; }
-
- private:
- arma::mat initialPoint;
-};
-
-/**
- * This function is taken from M. Gockenbach's lectures on general nonlinear
- * programs, found at:
- * http://www.math.mtu.edu/~msgocken/ma5630spring2003/lectures/nlp/nlp.pdf
- *
- * The program we are using is example 2.5 from this document.
- * I have arbitrarily decided that this will be called the Gockenbach function.
- *
- * The minimum that satisfies the two constraints is given as
- * x = [0.12288, -1.1078, 0.015100], with an objective value of about 29.634.
- */
-class GockenbachFunction
-{
- public:
- GockenbachFunction();
- GockenbachFunction(const arma::mat& initial_point);
-
- double Evaluate(const arma::mat& coordinates);
- void Gradient(const arma::mat& coordinates, arma::mat& gradient);
-
- size_t NumConstraints() const { return 2; };
-
- double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
- void GradientConstraint(const size_t index,
- const arma::mat& coordinates,
- arma::mat& gradient);
-
- const arma::mat& GetInitialPoint() const { return initialPoint; }
-
- private:
- arma::mat initialPoint;
-};
-
-
-
-/**
- * This function is the Lovasz-Theta semidefinite program, as implemented in the
- * following paper:
- *
- * S. Burer, R. Monteiro
- * "A nonlinear programming algorithm for solving semidefinite programs via
- * low-rank factorization."
- * Journal of Mathematical Programming, 2004
- *
- * Given a simple, undirected graph G = (V, E), the Lovasz-Theta SDP is defined
- * by:
- *
- * min_X{Tr(-(e e^T)^T X) : Tr(X) = 1, X_ij = 0 for all (i, j) in E, X >= 0}
- *
- * where e is the vector of all ones and X has dimension |V| x |V|.
- *
- * In the Monteiro-Burer formulation, we take X = R * R^T, where R is the
- * coordinates given to the Evaluate(), Gradient(), EvaluateConstraint(), and
- * GradientConstraint() functions.
- */
-class LovaszThetaSDP
-{
- public:
- LovaszThetaSDP();
-
- /**
- * Initialize the Lovasz-Theta SDP with the given set of edges. The edge
- * matrix should consist of rows of two dimensions, where dimension 0 is the
- * first vertex of the edge and dimension 1 is the second edge (or vice versa,
- * as it doesn't make a difference).
- *
- * @param edges Matrix of edges.
- */
- LovaszThetaSDP(const arma::mat& edges);
-
- double Evaluate(const arma::mat& coordinates);
- void Gradient(const arma::mat& coordinates, arma::mat& gradient);
-
- size_t NumConstraints() const;
-
- double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
- void GradientConstraint(const size_t index,
- const arma::mat& coordinates,
- arma::mat& gradient);
-
- const arma::mat& GetInitialPoint();
-
- const arma::mat& Edges() const { return edges; }
- arma::mat& Edges() { return edges; }
-
- private:
- arma::mat edges;
- size_t vertices;
-
- arma::mat initialPoint;
-};
-
-}; // namespace optimization
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_TEST_FUNCTIONS_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,155 @@
+/**
+ * @file aug_lagrangian_test_functions.hpp
+ * @author Ryan Curtin
+ *
+ * Define test functions for the augmented Lagrangian method.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_TEST_FUNCTIONS_HPP
+#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_TEST_FUNCTIONS_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * This function is taken from "Practical Mathematical Optimization" (Snyman),
+ * section 5.3.8 ("Application of the Augmented Lagrangian Method"). It has
+ * only one constraint.
+ *
+ * The minimum that satisfies the constraint is x = [1, 4], with an objective
+ * value of 70.
+ */
+class AugLagrangianTestFunction
+{
+ public:
+ AugLagrangianTestFunction();
+ AugLagrangianTestFunction(const arma::mat& initial_point);
+
+ double Evaluate(const arma::mat& coordinates);
+ void Gradient(const arma::mat& coordinates, arma::mat& gradient);
+
+ size_t NumConstraints() const { return 1; }
+
+ double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
+ void GradientConstraint(const size_t index,
+ const arma::mat& coordinates,
+ arma::mat& gradient);
+
+ const arma::mat& GetInitialPoint() const { return initialPoint; }
+
+ private:
+ arma::mat initialPoint;
+};
+
+/**
+ * This function is taken from M. Gockenbach's lectures on general nonlinear
+ * programs, found at:
+ * http://www.math.mtu.edu/~msgocken/ma5630spring2003/lectures/nlp/nlp.pdf
+ *
+ * The program we are using is example 2.5 from this document.
+ * I have arbitrarily decided that this will be called the Gockenbach function.
+ *
+ * The minimum that satisfies the two constraints is given as
+ * x = [0.12288, -1.1078, 0.015100], with an objective value of about 29.634.
+ */
+class GockenbachFunction
+{
+ public:
+ GockenbachFunction();
+ GockenbachFunction(const arma::mat& initial_point);
+
+ double Evaluate(const arma::mat& coordinates);
+ void Gradient(const arma::mat& coordinates, arma::mat& gradient);
+
+ size_t NumConstraints() const { return 2; };
+
+ double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
+ void GradientConstraint(const size_t index,
+ const arma::mat& coordinates,
+ arma::mat& gradient);
+
+ const arma::mat& GetInitialPoint() const { return initialPoint; }
+
+ private:
+ arma::mat initialPoint;
+};
+
+
+
+/**
+ * This function is the Lovasz-Theta semidefinite program, as implemented in the
+ * following paper:
+ *
+ * S. Burer, R. Monteiro
+ * "A nonlinear programming algorithm for solving semidefinite programs via
+ * low-rank factorization."
+ * Journal of Mathematical Programming, 2004
+ *
+ * Given a simple, undirected graph G = (V, E), the Lovasz-Theta SDP is defined
+ * by:
+ *
+ * min_X{Tr(-(e e^T)^T X) : Tr(X) = 1, X_ij = 0 for all (i, j) in E, X >= 0}
+ *
+ * where e is the vector of all ones and X has dimension |V| x |V|.
+ *
+ * In the Monteiro-Burer formulation, we take X = R * R^T, where R is the
+ * coordinates given to the Evaluate(), Gradient(), EvaluateConstraint(), and
+ * GradientConstraint() functions.
+ */
+class LovaszThetaSDP
+{
+ public:
+ LovaszThetaSDP();
+
+ /**
+ * Initialize the Lovasz-Theta SDP with the given set of edges. The edge
+ * matrix should consist of rows of two dimensions, where dimension 0 is the
+ * first vertex of the edge and dimension 1 is the second edge (or vice versa,
+ * as it doesn't make a difference).
+ *
+ * @param edges Matrix of edges.
+ */
+ LovaszThetaSDP(const arma::mat& edges);
+
+ double Evaluate(const arma::mat& coordinates);
+ void Gradient(const arma::mat& coordinates, arma::mat& gradient);
+
+ size_t NumConstraints() const;
+
+ double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
+ void GradientConstraint(const size_t index,
+ const arma::mat& coordinates,
+ arma::mat& gradient);
+
+ const arma::mat& GetInitialPoint();
+
+ const arma::mat& Edges() const { return edges; }
+ arma::mat& Edges() { return edges; }
+
+ private:
+ arma::mat edges;
+ size_t vertices;
+
+ arma::mat initialPoint;
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_TEST_FUNCTIONS_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,263 +0,0 @@
-/**
- * @file lbfgs.hpp
- * @author Dongryeol Lee
- * @author Ryan Curtin
- *
- * The generic L-BFGS optimizer.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_HPP
-#define __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace optimization {
-
-/**
- * The generic L-BFGS optimizer, which uses a back-tracking line search
- * algorithm to minimize a function. The parameters for the algorithm (number
- * of memory points, maximum step size, and so forth) are all configurable via
- * either the constructor or standalone modifier functions. A function which
- * can be optimized by this class must implement the following methods:
- *
- * - a default constructor
- * - double Evaluate(const arma::mat& coordinates);
- * - void Gradient(const arma::mat& coordinates, arma::mat& gradient);
- * - arma::mat& GetInitialPoint();
- */
-template<typename FunctionType>
-class L_BFGS
-{
- public:
- /**
- * Initialize the L-BFGS object. Store a reference to the function we will be
- * optimizing and set the size of the memory for the algorithm. There are
- * many parameters that can be set for the optimization, but default values
- * are given for each of them.
- *
- * @param function Instance of function to be optimized.
- * @param numBasis Number of memory points to be stored (default 5).
- * @param maxIterations Maximum number of iterations for the optimization
- * (default 0 -- may run indefinitely).
- * @param armijoConstant Controls the accuracy of the line search routine for
- * determining the Armijo condition.
- * @param wolfe Parameter for detecting the Wolfe condition.
- * @param minGradientNorm Minimum gradient norm required to continue the
- * optimization.
- * @param maxLineSearchTrials The maximum number of trials for the line search
- * (before giving up).
- * @param minStep The minimum step of the line search.
- * @param maxStep The maximum step of the line search.
- */
- L_BFGS(FunctionType& function,
- const size_t numBasis = 5, /* entirely arbitrary */
- const size_t maxIterations = 0, /* run forever */
- const double armijoConstant = 1e-4,
- const double wolfe = 0.9,
- const double minGradientNorm = 1e-10,
- const size_t maxLineSearchTrials = 50,
- const double minStep = 1e-20,
- const double maxStep = 1e20);
-
- /**
- * Return the point where the lowest function value has been found.
- *
- * @return arma::vec representing the point and a double with the function
- * value at that point.
- */
- const std::pair<arma::mat, double>& MinPointIterate() const;
-
- /**
- * Use L-BFGS to optimize the given function, starting at the given iterate
- * point and finding the minimum. The maximum number of iterations is set in
- * the constructor (or with MaxIterations()). Alternately, another overload
- * is provided which takes a maximum number of iterations as a parameter. The
- * given starting point will be modified to store the finishing point of the
- * algorithm, and the final objective value is returned.
- *
- * @param iterate Starting point (will be modified).
- * @return Objective value of the final point.
- */
- double Optimize(arma::mat& iterate);
-
- /**
- * Use L-BFGS to optimize (minimize) the given function, starting at the given
- * iterate point, and performing no more than the given maximum number of
- * iterations (the class variable maxIterations is ignored for this run, but
- * not modified). The given starting point will be modified to store the
- * finishing point of the algorithm, and the final objective value is
- * returned.
- *
- * @param iterate Starting point (will be modified).
- * @param maxIterations Maximum number of iterations (0 specifies no limit).
- * @return Objective value of the final point.
- */
- double Optimize(arma::mat& iterate, const size_t maxIterations);
-
- //! Get the memory size.
- size_t NumBasis() const { return numBasis; }
- //! Modify the memory size.
- size_t& NumBasis() { return numBasis; }
-
- //! Get the maximum number of iterations.
- size_t MaxIterations() const { return maxIterations; }
- //! Modify the maximum number of iterations.
- size_t& MaxIterations() { return maxIterations; }
-
- //! Get the Armijo condition constant.
- double ArmijoConstant() const { return armijoConstant; }
- //! Modify the Armijo condition constant.
- double& ArmijoConstant() { return armijoConstant; }
-
- //! Get the Wolfe parameter.
- double Wolfe() const { return wolfe; }
- //! Modify the Wolfe parameter.
- double& Wolfe() { return wolfe; }
-
- //! Get the minimum gradient norm.
- double MinGradientNorm() const { return minGradientNorm; }
- //! Modify the minimum gradient norm.
- double& MinGradientNorm() { return minGradientNorm; }
-
- //! Get the maximum number of line search trials.
- size_t MaxLineSearchTrials() const { return maxLineSearchTrials; }
- //! Modify the maximum number of line search trials.
- size_t& MaxLineSearchTrials() { return maxLineSearchTrials; }
-
- //! Return the minimum line search step size.
- double MinStep() const { return minStep; }
- //! Modify the minimum line search step size.
- double& MinStep() { return minStep; }
-
- //! Return the maximum line search step size.
- double MaxStep() const { return maxStep; }
- //! Modify the maximum line search step size.
- double& MaxStep() { return maxStep; }
-
- private:
- //! Internal reference to the function we are optimizing.
- FunctionType& function;
-
- //! Position of the new iterate.
- arma::mat newIterateTmp;
- //! Stores all the s matrices in memory.
- arma::cube s;
- //! Stores all the y matrices in memory.
- arma::cube y;
-
- //! Size of memory for this L-BFGS optimizer.
- size_t numBasis;
- //! Maximum number of iterations.
- size_t maxIterations;
- //! Parameter for determining the Armijo condition.
- double armijoConstant;
- //! Parameter for detecting the Wolfe condition.
- double wolfe;
- //! Minimum gradient norm required to continue the optimization.
- double minGradientNorm;
- //! Maximum number of trials for the line search.
- size_t maxLineSearchTrials;
- //! Minimum step of the line search.
- double minStep;
- //! Maximum step of the line search.
- double maxStep;
-
- //! Best point found so far.
- std::pair<arma::mat, double> minPointIterate;
-
- /**
- * Evaluate the function at the given iterate point and store the result if it
- * is a new minimum.
- *
- * @return The value of the function.
- */
- double Evaluate(const arma::mat& iterate);
-
- /**
- * Calculate the scaling factor, gamma, which is used to scale the Hessian
- * approximation matrix. See method M3 in Section 4 of Liu and Nocedal
- * (1989).
- *
- * @return The calculated scaling factor.
- */
- double ChooseScalingFactor(const size_t iterationNum,
- const arma::mat& gradient);
-
- /**
- * Check to make sure that the norm of the gradient is not smaller than 1e-5.
- * Currently that value is not configurable.
- *
- * @return (norm < minGradientNorm).
- */
- bool GradientNormTooSmall(const arma::mat& gradient);
-
- /**
- * Perform a back-tracking line search along the search direction to
- * calculate a step size satisfying the Wolfe conditions. The parameter
- * iterate will be modified if the method is successful.
- *
- * @param functionValue Value of the function at the initial point
- * @param iterate The initial point to begin the line search from
- * @param gradient The gradient at the initial point
- * @param searchDirection A vector specifying the search direction
- * @param stepSize Variable the calculated step size will be stored in
- *
- * @return false if no step size is suitable, true otherwise.
- */
- bool LineSearch(double& functionValue,
- arma::mat& iterate,
- arma::mat& gradient,
- const arma::mat& searchDirection);
-
- /**
- * Find the L-BFGS search direction.
- *
- * @param gradient The gradient at the current point
- * @param iteration_num The iteration number
- * @param scaling_factor Scaling factor to use (see ChooseScalingFactor_())
- * @param search_direction Vector to store search direction in
- */
- void SearchDirection(const arma::mat& gradient,
- const size_t iterationNum,
- const double scalingFactor,
- arma::mat& searchDirection);
-
- /**
- * Update the y and s matrices, which store the differences
- * between the iterate and old iterate and the differences between the
- * gradient and the old gradient, respectively.
- *
- * @param iterationNum Iteration number
- * @param iterate Current point
- * @param oldIterate Point at last iteration
- * @param gradient Gradient at current point (iterate)
- * @param oldGradient Gradient at last iteration point (oldIterate)
- */
- void UpdateBasisSet(const size_t iterationNum,
- const arma::mat& iterate,
- const arma::mat& oldIterate,
- const arma::mat& gradient,
- const arma::mat& oldGradient);
-};
-
-}; // namespace optimization
-}; // namespace mlpack
-
-#include "lbfgs_impl.hpp"
-
-#endif // __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,263 @@
+/**
+ * @file lbfgs.hpp
+ * @author Dongryeol Lee
+ * @author Ryan Curtin
+ *
+ * The generic L-BFGS optimizer.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_HPP
+#define __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * The generic L-BFGS optimizer, which uses a back-tracking line search
+ * algorithm to minimize a function. The parameters for the algorithm (number
+ * of memory points, maximum step size, and so forth) are all configurable via
+ * either the constructor or standalone modifier functions. A function which
+ * can be optimized by this class must implement the following methods:
+ *
+ * - a default constructor
+ * - double Evaluate(const arma::mat& coordinates);
+ * - void Gradient(const arma::mat& coordinates, arma::mat& gradient);
+ * - arma::mat& GetInitialPoint();
+ */
+template<typename FunctionType>
+class L_BFGS
+{
+ public:
+ /**
+ * Initialize the L-BFGS object. Store a reference to the function we will be
+ * optimizing and set the size of the memory for the algorithm. There are
+ * many parameters that can be set for the optimization, but default values
+ * are given for each of them.
+ *
+ * @param function Instance of function to be optimized.
+ * @param numBasis Number of memory points to be stored (default 5).
+ * @param maxIterations Maximum number of iterations for the optimization
+ * (default 0 -- may run indefinitely).
+ * @param armijoConstant Controls the accuracy of the line search routine for
+ * determining the Armijo condition.
+ * @param wolfe Parameter for detecting the Wolfe condition.
+ * @param minGradientNorm Minimum gradient norm required to continue the
+ * optimization.
+ * @param maxLineSearchTrials The maximum number of trials for the line search
+ * (before giving up).
+ * @param minStep The minimum step of the line search.
+ * @param maxStep The maximum step of the line search.
+ */
+ L_BFGS(FunctionType& function,
+ const size_t numBasis = 5, /* entirely arbitrary */
+ const size_t maxIterations = 0, /* run forever */
+ const double armijoConstant = 1e-4,
+ const double wolfe = 0.9,
+ const double minGradientNorm = 1e-10,
+ const size_t maxLineSearchTrials = 50,
+ const double minStep = 1e-20,
+ const double maxStep = 1e20);
+
+ /**
+ * Return the point where the lowest function value has been found.
+ *
+ * @return arma::vec representing the point and a double with the function
+ * value at that point.
+ */
+ const std::pair<arma::mat, double>& MinPointIterate() const;
+
+ /**
+ * Use L-BFGS to optimize the given function, starting at the given iterate
+ * point and finding the minimum. The maximum number of iterations is set in
+ * the constructor (or with MaxIterations()). Alternately, another overload
+ * is provided which takes a maximum number of iterations as a parameter. The
+ * given starting point will be modified to store the finishing point of the
+ * algorithm, and the final objective value is returned.
+ *
+ * @param iterate Starting point (will be modified).
+ * @return Objective value of the final point.
+ */
+ double Optimize(arma::mat& iterate);
+
+ /**
+ * Use L-BFGS to optimize (minimize) the given function, starting at the given
+ * iterate point, and performing no more than the given maximum number of
+ * iterations (the class variable maxIterations is ignored for this run, but
+ * not modified). The given starting point will be modified to store the
+ * finishing point of the algorithm, and the final objective value is
+ * returned.
+ *
+ * @param iterate Starting point (will be modified).
+ * @param maxIterations Maximum number of iterations (0 specifies no limit).
+ * @return Objective value of the final point.
+ */
+ double Optimize(arma::mat& iterate, const size_t maxIterations);
+
+ //! Get the memory size.
+ size_t NumBasis() const { return numBasis; }
+ //! Modify the memory size.
+ size_t& NumBasis() { return numBasis; }
+
+ //! Get the maximum number of iterations.
+ size_t MaxIterations() const { return maxIterations; }
+ //! Modify the maximum number of iterations.
+ size_t& MaxIterations() { return maxIterations; }
+
+ //! Get the Armijo condition constant.
+ double ArmijoConstant() const { return armijoConstant; }
+ //! Modify the Armijo condition constant.
+ double& ArmijoConstant() { return armijoConstant; }
+
+ //! Get the Wolfe parameter.
+ double Wolfe() const { return wolfe; }
+ //! Modify the Wolfe parameter.
+ double& Wolfe() { return wolfe; }
+
+ //! Get the minimum gradient norm.
+ double MinGradientNorm() const { return minGradientNorm; }
+ //! Modify the minimum gradient norm.
+ double& MinGradientNorm() { return minGradientNorm; }
+
+ //! Get the maximum number of line search trials.
+ size_t MaxLineSearchTrials() const { return maxLineSearchTrials; }
+ //! Modify the maximum number of line search trials.
+ size_t& MaxLineSearchTrials() { return maxLineSearchTrials; }
+
+ //! Return the minimum line search step size.
+ double MinStep() const { return minStep; }
+ //! Modify the minimum line search step size.
+ double& MinStep() { return minStep; }
+
+ //! Return the maximum line search step size.
+ double MaxStep() const { return maxStep; }
+ //! Modify the maximum line search step size.
+ double& MaxStep() { return maxStep; }
+
+ private:
+ //! Internal reference to the function we are optimizing.
+ FunctionType& function;
+
+ //! Position of the new iterate.
+ arma::mat newIterateTmp;
+ //! Stores all the s matrices in memory.
+ arma::cube s;
+ //! Stores all the y matrices in memory.
+ arma::cube y;
+
+ //! Size of memory for this L-BFGS optimizer.
+ size_t numBasis;
+ //! Maximum number of iterations.
+ size_t maxIterations;
+ //! Parameter for determining the Armijo condition.
+ double armijoConstant;
+ //! Parameter for detecting the Wolfe condition.
+ double wolfe;
+ //! Minimum gradient norm required to continue the optimization.
+ double minGradientNorm;
+ //! Maximum number of trials for the line search.
+ size_t maxLineSearchTrials;
+ //! Minimum step of the line search.
+ double minStep;
+ //! Maximum step of the line search.
+ double maxStep;
+
+ //! Best point found so far.
+ std::pair<arma::mat, double> minPointIterate;
+
+ /**
+ * Evaluate the function at the given iterate point and store the result if it
+ * is a new minimum.
+ *
+ * @return The value of the function.
+ */
+ double Evaluate(const arma::mat& iterate);
+
+ /**
+ * Calculate the scaling factor, gamma, which is used to scale the Hessian
+ * approximation matrix. See method M3 in Section 4 of Liu and Nocedal
+ * (1989).
+ *
+ * @return The calculated scaling factor.
+ */
+ double ChooseScalingFactor(const size_t iterationNum,
+ const arma::mat& gradient);
+
+ /**
+ * Check to make sure that the norm of the gradient is not smaller than 1e-5.
+ * Currently that value is not configurable.
+ *
+ * @return (norm < minGradientNorm).
+ */
+ bool GradientNormTooSmall(const arma::mat& gradient);
+
+ /**
+ * Perform a back-tracking line search along the search direction to
+ * calculate a step size satisfying the Wolfe conditions. The parameter
+ * iterate will be modified if the method is successful.
+ *
+ * @param functionValue Value of the function at the initial point
+ * @param iterate The initial point to begin the line search from
+ * @param gradient The gradient at the initial point
+ * @param searchDirection A vector specifying the search direction
+ * @param stepSize Variable the calculated step size will be stored in
+ *
+ * @return false if no step size is suitable, true otherwise.
+ */
+ bool LineSearch(double& functionValue,
+ arma::mat& iterate,
+ arma::mat& gradient,
+ const arma::mat& searchDirection);
+
+ /**
+ * Find the L-BFGS search direction.
+ *
+ * @param gradient The gradient at the current point
+ * @param iteration_num The iteration number
+ * @param scaling_factor Scaling factor to use (see ChooseScalingFactor_())
+ * @param search_direction Vector to store search direction in
+ */
+ void SearchDirection(const arma::mat& gradient,
+ const size_t iterationNum,
+ const double scalingFactor,
+ arma::mat& searchDirection);
+
+ /**
+ * Update the y and s matrices, which store the differences
+ * between the iterate and old iterate and the differences between the
+ * gradient and the old gradient, respectively.
+ *
+ * @param iterationNum Iteration number
+ * @param iterate Current point
+ * @param oldIterate Point at last iteration
+ * @param gradient Gradient at current point (iterate)
+ * @param oldGradient Gradient at last iteration point (oldIterate)
+ */
+ void UpdateBasisSet(const size_t iterationNum,
+ const arma::mat& iterate,
+ const arma::mat& oldIterate,
+ const arma::mat& gradient,
+ const arma::mat& oldGradient);
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#include "lbfgs_impl.hpp"
+
+#endif // __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,431 +0,0 @@
-/**
- * @file lbfgs_impl.hpp
- * @author Dongryeol Lee (dongryel at cc.gatech.edu)
- * @author Ryan Curtin
- *
- * The implementation of the L_BFGS optimizer.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_IMPL_HPP
-#define __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_IMPL_HPP
-
-namespace mlpack {
-namespace optimization {
-
-/**
- * Evaluate the function at the given iterate point and store the result if
- * it is a new minimum.
- *
- * @return The value of the function
- */
-template<typename FunctionType>
-double L_BFGS<FunctionType>::Evaluate(const arma::mat& iterate)
-{
- // Evaluate the function and keep track of the minimum function
- // value encountered during the optimization.
- double functionValue = function.Evaluate(iterate);
-
- if (functionValue < minPointIterate.second)
- {
- minPointIterate.first = iterate;
- minPointIterate.second = functionValue;
- }
-
- return functionValue;
-}
-
-/**
- * Calculate the scaling factor gamma which is used to scale the Hessian
- * approximation matrix. See method M3 in Section 4 of Liu and Nocedal (1989).
- *
- * @return The calculated scaling factor
- */
-template<typename FunctionType>
-double L_BFGS<FunctionType>::ChooseScalingFactor(const size_t iterationNum,
- const arma::mat& gradient)
-{
- double scalingFactor = 1.0;
- if (iterationNum > 0)
- {
- int previousPos = (iterationNum - 1) % numBasis;
- // Get s and y matrices once instead of multiple times.
- arma::mat& sMat = s.slice(previousPos);
- arma::mat& yMat = y.slice(previousPos);
- scalingFactor = dot(sMat, yMat) / dot(yMat, yMat);
- }
- else
- {
- scalingFactor = 1.0 / sqrt(dot(gradient, gradient));
- }
-
- return scalingFactor;
-}
-
-/**
- * Check to make sure that the norm of the gradient is not smaller than 1e-10.
- * Currently that value is not configurable.
- *
- * @return (norm < minGradientNorm)
- */
-template<typename FunctionType>
-bool L_BFGS<FunctionType>::GradientNormTooSmall(const arma::mat& gradient)
-{
- double norm = arma::norm(gradient, 2);
-
- return (norm < minGradientNorm);
-}
-
-/**
- * Perform a back-tracking line search along the search direction to calculate a
- * step size satisfying the Wolfe conditions.
- *
- * @param functionValue Value of the function at the initial point
- * @param iterate The initial point to begin the line search from
- * @param gradient The gradient at the initial point
- * @param searchDirection A vector specifying the search direction
- * @param stepSize Variable the calculated step size will be stored in
- *
- * @return false if no step size is suitable, true otherwise.
- */
-template<typename FunctionType>
-bool L_BFGS<FunctionType>::LineSearch(double& functionValue,
- arma::mat& iterate,
- arma::mat& gradient,
- const arma::mat& searchDirection)
-{
- // Default first step size of 1.0.
- double stepSize = 1.0;
-
- // The initial linear term approximation in the direction of the
- // search direction.
- double initialSearchDirectionDotGradient =
- arma::dot(gradient, searchDirection);
-
- // If it is not a descent direction, just report failure.
- if (initialSearchDirectionDotGradient > 0.0)
- {
- Log::Warn << "L-BFGS line search direction is not a descent direction "
- << "(terminating)!" << std::endl;
- return false;
- }
-
- // Save the initial function value.
- double initialFunctionValue = functionValue;
-
- // Unit linear approximation to the decrease in function value.
- double linearApproxFunctionValueDecrease = armijoConstant *
- initialSearchDirectionDotGradient;
-
- // The number of iteration in the search.
- size_t numIterations = 0;
-
- // Armijo step size scaling factor for increase and decrease.
- const double inc = 2.1;
- const double dec = 0.5;
- double width = 0;
-
- while (true)
- {
- // Perform a step and evaluate the gradient and the function values at that
- // point.
- newIterateTmp = iterate;
- newIterateTmp += stepSize * searchDirection;
- functionValue = Evaluate(newIterateTmp);
- function.Gradient(newIterateTmp, gradient);
- numIterations++;
-
- if (functionValue > initialFunctionValue + stepSize *
- linearApproxFunctionValueDecrease)
- {
- width = dec;
- }
- else
- {
- // Check Wolfe's condition.
- double searchDirectionDotGradient = arma::dot(gradient, searchDirection);
-
- if (searchDirectionDotGradient < wolfe *
- initialSearchDirectionDotGradient)
- {
- width = inc;
- }
- else
- {
- if (searchDirectionDotGradient > -wolfe *
- initialSearchDirectionDotGradient)
- {
- width = dec;
- }
- else
- {
- break;
- }
- }
- }
-
- // Terminate when the step size gets too small or too big or it
- // exceeds the max number of iterations.
- if ((stepSize < minStep) || (stepSize > maxStep) ||
- (numIterations >= maxLineSearchTrials))
- {
- return false;
- }
-
- // Scale the step size.
- stepSize *= width;
- }
-
- // Move to the new iterate.
- iterate = newIterateTmp;
- return true;
-}
-
-/**
- * Find the L_BFGS search direction.
- *
- * @param gradient The gradient at the current point
- * @param iterationNum The iteration number
- * @param scalingFactor Scaling factor to use (see ChooseScalingFactor_())
- * @param searchDirection Vector to store search direction in
- */
-template<typename FunctionType>
-void L_BFGS<FunctionType>::SearchDirection(const arma::mat& gradient,
- const size_t iterationNum,
- const double scalingFactor,
- arma::mat& searchDirection)
-{
- // Start from this point.
- searchDirection = gradient;
-
- // See "A Recursive Formula to Compute H * g" in "Updating quasi-Newton
- // matrices with limited storage" (Nocedal, 1980).
-
- // Temporary variables.
- arma::vec rho(numBasis);
- arma::vec alpha(numBasis);
-
- size_t limit = (numBasis > iterationNum) ? 0 : (iterationNum - numBasis);
- for (size_t i = iterationNum; i != limit; i--)
- {
- int translatedPosition = (i + (numBasis - 1)) % numBasis;
- rho[iterationNum - i] = 1.0 / arma::dot(y.slice(translatedPosition),
- s.slice(translatedPosition));
- alpha[iterationNum - i] = rho[iterationNum - i] *
- arma::dot(s.slice(translatedPosition), searchDirection);
- searchDirection -= alpha[iterationNum - i] * y.slice(translatedPosition);
- }
-
- searchDirection *= scalingFactor;
-
- for (size_t i = limit; i < iterationNum; i++)
- {
- int translatedPosition = i % numBasis;
- double beta = rho[iterationNum - i - 1] *
- arma::dot(y.slice(translatedPosition), searchDirection);
- searchDirection += (alpha[iterationNum - i - 1] - beta) *
- s.slice(translatedPosition);
- }
-
- // Negate the search direction so that it is a descent direction.
- searchDirection *= -1;
-}
-
-/**
- * Update the y and s matrices, which store the differences between
- * the iterate and old iterate and the differences between the gradient and the
- * old gradient, respectively.
- *
- * @param iterationNum Iteration number
- * @param iterate Current point
- * @param oldIterate Point at last iteration
- * @param gradient Gradient at current point (iterate)
- * @param oldGradient Gradient at last iteration point (oldIterate)
- */
-template<typename FunctionType>
-void L_BFGS<FunctionType>::UpdateBasisSet(const size_t iterationNum,
- const arma::mat& iterate,
- const arma::mat& oldIterate,
- const arma::mat& gradient,
- const arma::mat& oldGradient)
-{
- // Overwrite a certain position instead of pushing everything in the vector
- // back one position.
- int overwritePos = iterationNum % numBasis;
- s.slice(overwritePos) = iterate - oldIterate;
- y.slice(overwritePos) = gradient - oldGradient;
-}
-
-/**
- * Initialize the L_BFGS object. Copy the function we will be optimizing and
- * set the size of the memory for the algorithm.
- *
- * @param function Instance of function to be optimized
- * @param numBasis Number of memory points to be stored
- * @param armijoConstant Controls the accuracy of the line search routine for
- * determining the Armijo condition.
- * @param wolfe Parameter for detecting the Wolfe condition.
- * @param minGradientNorm Minimum gradient norm required to continue the
- * optimization.
- * @param maxLineSearchTrials The maximum number of trials for the line search
- * (before giving up).
- * @param minStep The minimum step of the line search.
- * @param maxStep The maximum step of the line search.
- */
-template<typename FunctionType>
-L_BFGS<FunctionType>::L_BFGS(FunctionType& function,
- const size_t numBasis,
- const size_t maxIterations,
- const double armijoConstant,
- const double wolfe,
- const double minGradientNorm,
- const size_t maxLineSearchTrials,
- const double minStep,
- const double maxStep) :
- function(function),
- numBasis(numBasis),
- maxIterations(maxIterations),
- armijoConstant(armijoConstant),
- wolfe(wolfe),
- minGradientNorm(minGradientNorm),
- maxLineSearchTrials(maxLineSearchTrials),
- minStep(minStep),
- maxStep(maxStep)
-{
- // Get the dimensions of the coordinates of the function; GetInitialPoint()
- // might return an arma::vec, but that's okay because then n_cols will simply
- // be 1.
- const size_t rows = function.GetInitialPoint().n_rows;
- const size_t cols = function.GetInitialPoint().n_cols;
-
- newIterateTmp.set_size(rows, cols);
- s.set_size(rows, cols, numBasis);
- y.set_size(rows, cols, numBasis);
-
- // Allocate the pair holding the min iterate information.
- minPointIterate.first.zeros(rows, cols);
- minPointIterate.second = std::numeric_limits<double>::max();
-}
-
-/**
- * Return the point where the lowest function value has been found.
- *
- * @return arma::vec representing the point and a double with the function
- * value at that point.
- */
-template<typename FunctionType>
-inline const std::pair<arma::mat, double>&
-L_BFGS<FunctionType>::MinPointIterate() const
-{
- return minPointIterate;
-}
-
-template<typename FunctionType>
-inline double L_BFGS<FunctionType>::Optimize(arma::mat& iterate)
-{
- return Optimize(iterate, maxIterations);
-}
-
-/**
- * Use L_BFGS to optimize the given function, starting at the given iterate
- * point and performing no more than the specified number of maximum iterations.
- * The given starting point will be modified to store the finishing point of the
- * algorithm.
- *
- * @param numIterations Maximum number of iterations to perform
- * @param iterate Starting point (will be modified)
- */
-template<typename FunctionType>
-double L_BFGS<FunctionType>::Optimize(arma::mat& iterate,
- const size_t maxIterations)
-{
- // The old iterate to be saved.
- arma::mat oldIterate;
- oldIterate.zeros(iterate.n_rows, iterate.n_cols);
-
- // Whether to optimize until convergence.
- bool optimizeUntilConvergence = (maxIterations == 0);
-
- // The initial function value.
- double functionValue = Evaluate(iterate);
-
- // The gradient: the current and the old.
- arma::mat gradient;
- arma::mat oldGradient;
- gradient.zeros(iterate.n_rows, iterate.n_cols);
- oldGradient.zeros(iterate.n_rows, iterate.n_cols);
-
- // The search direction.
- arma::mat searchDirection;
- searchDirection.zeros(iterate.n_rows, iterate.n_cols);
-
- // The initial gradient value.
- function.Gradient(iterate, gradient);
-
- // The main optimization loop.
- for (size_t itNum = 0; optimizeUntilConvergence || (itNum != maxIterations);
- ++itNum)
- {
- Log::Debug << "L-BFGS iteration " << itNum << "; objective " <<
- function.Evaluate(iterate) << "." << std::endl;
-
- // Break when the norm of the gradient becomes too small.
- if (GradientNormTooSmall(gradient))
- {
- Log::Debug << "L-BFGS gradient norm too small (terminating successfully)."
- << std::endl;
- break;
- }
-
- // Choose the scaling factor.
- double scalingFactor = ChooseScalingFactor(itNum, gradient);
-
- // Build an approximation to the Hessian and choose the search
- // direction for the current iteration.
- SearchDirection(gradient, itNum, scalingFactor, searchDirection);
-
- // Save the old iterate and the gradient before stepping.
- oldIterate = iterate;
- oldGradient = gradient;
-
- // Do a line search and take a step.
- if (!LineSearch(functionValue, iterate, gradient, searchDirection))
- {
- Log::Debug << "Line search failed. Stopping optimization." << std::endl;
- break; // The line search failed; nothing else to try.
- }
-
- // It is possible that the difference between the two coordinates is zero.
- // In this case we terminate successfully.
- if (accu(iterate != oldIterate) == 0)
- {
- Log::Debug << "L-BFGS step size of 0 (terminating successfully)."
- << std::endl;
- break;
- }
-
- // Overwrite an old basis set.
- UpdateBasisSet(itNum, iterate, oldIterate, gradient, oldGradient);
-
- } // End of the optimization loop.
-
- return function.Evaluate(iterate);
-}
-
-}; // namespace optimization
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_IMPL_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,431 @@
+/**
+ * @file lbfgs_impl.hpp
+ * @author Dongryeol Lee (dongryel at cc.gatech.edu)
+ * @author Ryan Curtin
+ *
+ * The implementation of the L_BFGS optimizer.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_IMPL_HPP
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * Evaluate the function at the given iterate point and store the result if
+ * it is a new minimum.
+ *
+ * @return The value of the function
+ */
+template<typename FunctionType>
+double L_BFGS<FunctionType>::Evaluate(const arma::mat& iterate)
+{
+ // Evaluate the function and keep track of the minimum function
+ // value encountered during the optimization.
+ double functionValue = function.Evaluate(iterate);
+
+ if (functionValue < minPointIterate.second)
+ {
+ minPointIterate.first = iterate;
+ minPointIterate.second = functionValue;
+ }
+
+ return functionValue;
+}
+
+/**
+ * Calculate the scaling factor gamma which is used to scale the Hessian
+ * approximation matrix. See method M3 in Section 4 of Liu and Nocedal (1989).
+ *
+ * @return The calculated scaling factor
+ */
+template<typename FunctionType>
+double L_BFGS<FunctionType>::ChooseScalingFactor(const size_t iterationNum,
+ const arma::mat& gradient)
+{
+ double scalingFactor = 1.0;
+ if (iterationNum > 0)
+ {
+ int previousPos = (iterationNum - 1) % numBasis;
+ // Get s and y matrices once instead of multiple times.
+ arma::mat& sMat = s.slice(previousPos);
+ arma::mat& yMat = y.slice(previousPos);
+ scalingFactor = dot(sMat, yMat) / dot(yMat, yMat);
+ }
+ else
+ {
+ scalingFactor = 1.0 / sqrt(dot(gradient, gradient));
+ }
+
+ return scalingFactor;
+}
+
+/**
+ * Check to make sure that the norm of the gradient is not smaller than 1e-10.
+ * Currently that value is not configurable.
+ *
+ * @return (norm < minGradientNorm)
+ */
+template<typename FunctionType>
+bool L_BFGS<FunctionType>::GradientNormTooSmall(const arma::mat& gradient)
+{
+ double norm = arma::norm(gradient, 2);
+
+ return (norm < minGradientNorm);
+}
+
+/**
+ * Perform a back-tracking line search along the search direction to calculate a
+ * step size satisfying the Wolfe conditions.
+ *
+ * @param functionValue Value of the function at the initial point
+ * @param iterate The initial point to begin the line search from
+ * @param gradient The gradient at the initial point
+ * @param searchDirection A vector specifying the search direction
+ * @param stepSize Variable the calculated step size will be stored in
+ *
+ * @return false if no step size is suitable, true otherwise.
+ */
+template<typename FunctionType>
+bool L_BFGS<FunctionType>::LineSearch(double& functionValue,
+ arma::mat& iterate,
+ arma::mat& gradient,
+ const arma::mat& searchDirection)
+{
+ // Default first step size of 1.0.
+ double stepSize = 1.0;
+
+ // The initial linear term approximation in the direction of the
+ // search direction.
+ double initialSearchDirectionDotGradient =
+ arma::dot(gradient, searchDirection);
+
+ // If it is not a descent direction, just report failure.
+ if (initialSearchDirectionDotGradient > 0.0)
+ {
+ Log::Warn << "L-BFGS line search direction is not a descent direction "
+ << "(terminating)!" << std::endl;
+ return false;
+ }
+
+ // Save the initial function value.
+ double initialFunctionValue = functionValue;
+
+ // Unit linear approximation to the decrease in function value.
+ double linearApproxFunctionValueDecrease = armijoConstant *
+ initialSearchDirectionDotGradient;
+
+ // The number of iteration in the search.
+ size_t numIterations = 0;
+
+ // Armijo step size scaling factor for increase and decrease.
+ const double inc = 2.1;
+ const double dec = 0.5;
+ double width = 0;
+
+ while (true)
+ {
+ // Perform a step and evaluate the gradient and the function values at that
+ // point.
+ newIterateTmp = iterate;
+ newIterateTmp += stepSize * searchDirection;
+ functionValue = Evaluate(newIterateTmp);
+ function.Gradient(newIterateTmp, gradient);
+ numIterations++;
+
+ if (functionValue > initialFunctionValue + stepSize *
+ linearApproxFunctionValueDecrease)
+ {
+ width = dec;
+ }
+ else
+ {
+ // Check Wolfe's condition.
+ double searchDirectionDotGradient = arma::dot(gradient, searchDirection);
+
+ if (searchDirectionDotGradient < wolfe *
+ initialSearchDirectionDotGradient)
+ {
+ width = inc;
+ }
+ else
+ {
+ if (searchDirectionDotGradient > -wolfe *
+ initialSearchDirectionDotGradient)
+ {
+ width = dec;
+ }
+ else
+ {
+ break;
+ }
+ }
+ }
+
+ // Terminate when the step size gets too small or too big or it
+ // exceeds the max number of iterations.
+ if ((stepSize < minStep) || (stepSize > maxStep) ||
+ (numIterations >= maxLineSearchTrials))
+ {
+ return false;
+ }
+
+ // Scale the step size.
+ stepSize *= width;
+ }
+
+ // Move to the new iterate.
+ iterate = newIterateTmp;
+ return true;
+}
+
+/**
+ * Find the L_BFGS search direction.
+ *
+ * @param gradient The gradient at the current point
+ * @param iterationNum The iteration number
+ * @param scalingFactor Scaling factor to use (see ChooseScalingFactor_())
+ * @param searchDirection Vector to store search direction in
+ */
+template<typename FunctionType>
+void L_BFGS<FunctionType>::SearchDirection(const arma::mat& gradient,
+ const size_t iterationNum,
+ const double scalingFactor,
+ arma::mat& searchDirection)
+{
+ // Start from this point.
+ searchDirection = gradient;
+
+ // See "A Recursive Formula to Compute H * g" in "Updating quasi-Newton
+ // matrices with limited storage" (Nocedal, 1980).
+
+ // Temporary variables.
+ arma::vec rho(numBasis);
+ arma::vec alpha(numBasis);
+
+ size_t limit = (numBasis > iterationNum) ? 0 : (iterationNum - numBasis);
+ for (size_t i = iterationNum; i != limit; i--)
+ {
+ int translatedPosition = (i + (numBasis - 1)) % numBasis;
+ rho[iterationNum - i] = 1.0 / arma::dot(y.slice(translatedPosition),
+ s.slice(translatedPosition));
+ alpha[iterationNum - i] = rho[iterationNum - i] *
+ arma::dot(s.slice(translatedPosition), searchDirection);
+ searchDirection -= alpha[iterationNum - i] * y.slice(translatedPosition);
+ }
+
+ searchDirection *= scalingFactor;
+
+ for (size_t i = limit; i < iterationNum; i++)
+ {
+ int translatedPosition = i % numBasis;
+ double beta = rho[iterationNum - i - 1] *
+ arma::dot(y.slice(translatedPosition), searchDirection);
+ searchDirection += (alpha[iterationNum - i - 1] - beta) *
+ s.slice(translatedPosition);
+ }
+
+ // Negate the search direction so that it is a descent direction.
+ searchDirection *= -1;
+}
+
+/**
+ * Update the y and s matrices, which store the differences between
+ * the iterate and old iterate and the differences between the gradient and the
+ * old gradient, respectively.
+ *
+ * @param iterationNum Iteration number
+ * @param iterate Current point
+ * @param oldIterate Point at last iteration
+ * @param gradient Gradient at current point (iterate)
+ * @param oldGradient Gradient at last iteration point (oldIterate)
+ */
+template<typename FunctionType>
+void L_BFGS<FunctionType>::UpdateBasisSet(const size_t iterationNum,
+ const arma::mat& iterate,
+ const arma::mat& oldIterate,
+ const arma::mat& gradient,
+ const arma::mat& oldGradient)
+{
+ // Overwrite a certain position instead of pushing everything in the vector
+ // back one position.
+ int overwritePos = iterationNum % numBasis;
+ s.slice(overwritePos) = iterate - oldIterate;
+ y.slice(overwritePos) = gradient - oldGradient;
+}
+
+/**
+ * Initialize the L_BFGS object. Copy the function we will be optimizing and
+ * set the size of the memory for the algorithm.
+ *
+ * @param function Instance of function to be optimized
+ * @param numBasis Number of memory points to be stored
+ * @param armijoConstant Controls the accuracy of the line search routine for
+ * determining the Armijo condition.
+ * @param wolfe Parameter for detecting the Wolfe condition.
+ * @param minGradientNorm Minimum gradient norm required to continue the
+ * optimization.
+ * @param maxLineSearchTrials The maximum number of trials for the line search
+ * (before giving up).
+ * @param minStep The minimum step of the line search.
+ * @param maxStep The maximum step of the line search.
+ */
+template<typename FunctionType>
+L_BFGS<FunctionType>::L_BFGS(FunctionType& function,
+ const size_t numBasis,
+ const size_t maxIterations,
+ const double armijoConstant,
+ const double wolfe,
+ const double minGradientNorm,
+ const size_t maxLineSearchTrials,
+ const double minStep,
+ const double maxStep) :
+ function(function),
+ numBasis(numBasis),
+ maxIterations(maxIterations),
+ armijoConstant(armijoConstant),
+ wolfe(wolfe),
+ minGradientNorm(minGradientNorm),
+ maxLineSearchTrials(maxLineSearchTrials),
+ minStep(minStep),
+ maxStep(maxStep)
+{
+ // Get the dimensions of the coordinates of the function; GetInitialPoint()
+ // might return an arma::vec, but that's okay because then n_cols will simply
+ // be 1.
+ const size_t rows = function.GetInitialPoint().n_rows;
+ const size_t cols = function.GetInitialPoint().n_cols;
+
+ newIterateTmp.set_size(rows, cols);
+ s.set_size(rows, cols, numBasis);
+ y.set_size(rows, cols, numBasis);
+
+ // Allocate the pair holding the min iterate information.
+ minPointIterate.first.zeros(rows, cols);
+ minPointIterate.second = std::numeric_limits<double>::max();
+}
+
+/**
+ * Return the point where the lowest function value has been found.
+ *
+ * @return arma::vec representing the point and a double with the function
+ * value at that point.
+ */
+template<typename FunctionType>
+inline const std::pair<arma::mat, double>&
+L_BFGS<FunctionType>::MinPointIterate() const
+{
+ return minPointIterate;
+}
+
+template<typename FunctionType>
+inline double L_BFGS<FunctionType>::Optimize(arma::mat& iterate)
+{
+ return Optimize(iterate, maxIterations);
+}
+
+/**
+ * Use L_BFGS to optimize the given function, starting at the given iterate
+ * point and performing no more than the specified number of maximum iterations.
+ * The given starting point will be modified to store the finishing point of the
+ * algorithm.
+ *
+ * @param numIterations Maximum number of iterations to perform
+ * @param iterate Starting point (will be modified)
+ */
+template<typename FunctionType>
+double L_BFGS<FunctionType>::Optimize(arma::mat& iterate,
+ const size_t maxIterations)
+{
+ // The old iterate to be saved.
+ arma::mat oldIterate;
+ oldIterate.zeros(iterate.n_rows, iterate.n_cols);
+
+ // Whether to optimize until convergence.
+ bool optimizeUntilConvergence = (maxIterations == 0);
+
+ // The initial function value.
+ double functionValue = Evaluate(iterate);
+
+ // The gradient: the current and the old.
+ arma::mat gradient;
+ arma::mat oldGradient;
+ gradient.zeros(iterate.n_rows, iterate.n_cols);
+ oldGradient.zeros(iterate.n_rows, iterate.n_cols);
+
+ // The search direction.
+ arma::mat searchDirection;
+ searchDirection.zeros(iterate.n_rows, iterate.n_cols);
+
+ // The initial gradient value.
+ function.Gradient(iterate, gradient);
+
+ // The main optimization loop.
+ for (size_t itNum = 0; optimizeUntilConvergence || (itNum != maxIterations);
+ ++itNum)
+ {
+ Log::Debug << "L-BFGS iteration " << itNum << "; objective " <<
+ function.Evaluate(iterate) << "." << std::endl;
+
+ // Break when the norm of the gradient becomes too small.
+ if (GradientNormTooSmall(gradient))
+ {
+ Log::Debug << "L-BFGS gradient norm too small (terminating successfully)."
+ << std::endl;
+ break;
+ }
+
+ // Choose the scaling factor.
+ double scalingFactor = ChooseScalingFactor(itNum, gradient);
+
+ // Build an approximation to the Hessian and choose the search
+ // direction for the current iteration.
+ SearchDirection(gradient, itNum, scalingFactor, searchDirection);
+
+ // Save the old iterate and the gradient before stepping.
+ oldIterate = iterate;
+ oldGradient = gradient;
+
+ // Do a line search and take a step.
+ if (!LineSearch(functionValue, iterate, gradient, searchDirection))
+ {
+ Log::Debug << "Line search failed. Stopping optimization." << std::endl;
+ break; // The line search failed; nothing else to try.
+ }
+
+ // It is possible that the difference between the two coordinates is zero.
+ // In this case we terminate successfully.
+ if (accu(iterate != oldIterate) == 0)
+ {
+ Log::Debug << "L-BFGS step size of 0 (terminating successfully)."
+ << std::endl;
+ break;
+ }
+
+ // Overwrite an old basis set.
+ UpdateBasisSet(itNum, iterate, oldIterate, gradient, oldGradient);
+
+ } // End of the optimization loop.
+
+ return function.Evaluate(iterate);
+}
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_OPTIMIZERS_LBFGS_LBFGS_IMPL_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lbfgs/test_functions.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,257 +0,0 @@
-/**
- * @file test_functions.cpp
- * @author Ryan Curtin
- *
- * Implementations of the test functions defined in test_functions.hpp.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "test_functions.hpp"
-
-using namespace mlpack::optimization::test;
-
-//
-// RosenbrockFunction implementation
-//
-
-RosenbrockFunction::RosenbrockFunction()
-{
- initialPoint.set_size(2, 1);
- initialPoint[0] = -1.2;
- initialPoint[1] = 1;
-}
-
-/**
- * Calculate the objective function.
- */
-double RosenbrockFunction::Evaluate(const arma::mat& coordinates)
-{
- double x1 = coordinates[0];
- double x2 = coordinates[1];
-
- double objective = /* f1(x) */ 100 * std::pow(x2 - std::pow(x1, 2), 2) +
- /* f2(x) */ std::pow(1 - x1, 2);
-
- return objective;
-}
-
-/**
- * Calculate the gradient.
- */
-void RosenbrockFunction::Gradient(const arma::mat& coordinates,
- arma::mat& gradient)
-{
- // f'_{x1}(x) = -2 (1 - x1) + 400 (x1^3 - (x2 x1))
- // f'_{x2}(x) = 200 (x2 - x1^2)
-
- double x1 = coordinates[0];
- double x2 = coordinates[1];
-
- gradient.set_size(2, 1);
- gradient[0] = -2 * (1 - x1) + 400 * (std::pow(x1, 3) - x2 * x1);
- gradient[1] = 200 * (x2 - std::pow(x1, 2));
-}
-
-const arma::mat& RosenbrockFunction::GetInitialPoint() const
-{
- return initialPoint;
-}
-
-//
-// WoodFunction implementation
-//
-
-WoodFunction::WoodFunction()
-{
- initialPoint.set_size(4, 1);
- initialPoint[0] = -3;
- initialPoint[1] = -1;
- initialPoint[2] = -3;
- initialPoint[3] = -1;
-}
-
-/**
- * Calculate the objective function.
- */
-double WoodFunction::Evaluate(const arma::mat& coordinates)
-{
- // For convenience; we assume these temporaries will be optimized out.
- double x1 = coordinates[0];
- double x2 = coordinates[1];
- double x3 = coordinates[2];
- double x4 = coordinates[3];
-
- double objective = /* f1(x) */ 100 * std::pow(x2 - std::pow(x1, 2), 2) +
- /* f2(x) */ std::pow(1 - x1, 2) +
- /* f3(x) */ 90 * std::pow(x4 - std::pow(x3, 2), 2) +
- /* f4(x) */ std::pow(1 - x3, 2) +
- /* f5(x) */ 10 * std::pow(x2 + x4 - 2, 2) +
- /* f6(x) */ (1 / 10) * std::pow(x2 - x4, 2);
-
- return objective;
-}
-
-/**
- * Calculate the gradient.
- */
-void WoodFunction::Gradient(const arma::mat& coordinates,
- arma::mat& gradient)
-{
- // For convenience; we assume these temporaries will be optimized out.
- double x1 = coordinates[0];
- double x2 = coordinates[1];
- double x3 = coordinates[2];
- double x4 = coordinates[3];
-
- // f'_{x1}(x) = 400 (x1^3 - x2 x1) - 2 (1 - x1)
- // f'_{x2}(x) = 200 (x2 - x1^2) + 20 (x2 + x4 - 2) + (1 / 5) (x2 - x4)
- // f'_{x3}(x) = 360 (x3^3 - x4 x3) - 2 (1 - x3)
- // f'_{x4}(x) = 180 (x4 - x3^2) + 20 (x2 + x4 - 2) - (1 / 5) (x2 - x4)
- gradient.set_size(4, 1);
- gradient[0] = 400 * (std::pow(x1, 3) - x2 * x1) - 2 * (1 - x1);
- gradient[1] = 200 * (x2 - std::pow(x1, 2)) + 20 * (x2 + x4 - 2) +
- (1 / 5) * (x2 - x4);
- gradient[2] = 360 * (std::pow(x3, 3) - x4 * x3) - 2 * (1 - x3);
- gradient[3] = 180 * (x4 - std::pow(x3, 2)) + 20 * (x2 + x4 - 2) -
- (1 / 5) * (x2 - x4);
-}
-
-const arma::mat& WoodFunction::GetInitialPoint() const
-{
- return initialPoint;
-}
-
-//
-// GeneralizedRosenbrockFunction implementation
-//
-
-GeneralizedRosenbrockFunction::GeneralizedRosenbrockFunction(int n) : n(n)
-{
- initialPoint.set_size(n, 1);
- for (int i = 0; i < n; i++) // Set to [-1.2 1 -1.2 1 ...].
- {
- if (i % 2 == 1)
- initialPoint[i] = -1.2;
- else
- initialPoint[i] = 1;
- }
-}
-
-/**
- * Calculate the objective function.
- */
-double GeneralizedRosenbrockFunction::Evaluate(const arma::mat& coordinates)
- const
-{
- double fval = 0;
- for (int i = 0; i < (n - 1); i++)
- {
- fval += 100 * std::pow(std::pow(coordinates[i], 2) -
- coordinates[i + 1], 2) + std::pow(1 - coordinates[i], 2);
- }
-
- return fval;
-}
-
-/**
- * Calculate the gradient.
- */
-void GeneralizedRosenbrockFunction::Gradient(const arma::mat& coordinates,
- arma::mat& gradient) const
-{
- gradient.set_size(n);
- for (int i = 0; i < (n - 1); i++)
- {
- gradient[i] = 400 * (std::pow(coordinates[i], 3) - coordinates[i] *
- coordinates[i + 1]) + 2 * (coordinates[i] - 1);
-
- if (i > 0)
- gradient[i] += 200 * (coordinates[i] - std::pow(coordinates[i - 1], 2));
- }
-
- gradient[n - 1] = 200 * (coordinates[n - 1] -
- std::pow(coordinates[n - 2], 2));
-}
-
-//! Calculate the objective function of one of the individual functions.
-double GeneralizedRosenbrockFunction::Evaluate(const arma::mat& coordinates,
- const size_t i) const
-{
- return 100 * std::pow((std::pow(coordinates[i], 2) - coordinates[i + 1]), 2) +
- std::pow(1 - coordinates[i], 2);
-}
-
-//! Calculate the gradient of one of the individual functions.
-void GeneralizedRosenbrockFunction::Gradient(const arma::mat& coordinates,
- const size_t i,
- arma::mat& gradient) const
-{
- gradient.zeros(n);
-
- gradient[i] = 400 * (std::pow(coordinates[i], 3) - coordinates[i] *
- coordinates[i + 1]) + 2 * (coordinates[i] - 1);
- gradient[i + 1] = 200 * (coordinates[i + 1] - std::pow(coordinates[i], 2));
-}
-
-const arma::mat& GeneralizedRosenbrockFunction::GetInitialPoint() const
-{
- return initialPoint;
-}
-
-//
-// RosenbrockWoodFunction implementation
-//
-
-RosenbrockWoodFunction::RosenbrockWoodFunction() : rf(4), wf()
-{
- initialPoint.set_size(4, 2);
- initialPoint.col(0) = rf.GetInitialPoint();
- initialPoint.col(1) = wf.GetInitialPoint();
-}
-
-/**
- * Calculate the objective function.
- */
-double RosenbrockWoodFunction::Evaluate(const arma::mat& coordinates)
-{
- double objective = rf.Evaluate(coordinates.col(0)) +
- wf.Evaluate(coordinates.col(1));
-
- return objective;
-}
-
-/***
- * Calculate the gradient.
- */
-void RosenbrockWoodFunction::Gradient(const arma::mat& coordinates,
- arma::mat& gradient)
-{
- gradient.set_size(4, 2);
-
- arma::vec grf(4);
- arma::vec gwf(4);
-
- rf.Gradient(coordinates.col(0), grf);
- wf.Gradient(coordinates.col(1), gwf);
-
- gradient.col(0) = grf;
- gradient.col(1) = gwf;
-}
-
-const arma::mat& RosenbrockWoodFunction::GetInitialPoint() const
-{
- return initialPoint;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lbfgs/test_functions.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,257 @@
+/**
+ * @file test_functions.cpp
+ * @author Ryan Curtin
+ *
+ * Implementations of the test functions defined in test_functions.hpp.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "test_functions.hpp"
+
+using namespace mlpack::optimization::test;
+
+//
+// RosenbrockFunction implementation
+//
+
+RosenbrockFunction::RosenbrockFunction()
+{
+ initialPoint.set_size(2, 1);
+ initialPoint[0] = -1.2;
+ initialPoint[1] = 1;
+}
+
+/**
+ * Calculate the objective function.
+ */
+double RosenbrockFunction::Evaluate(const arma::mat& coordinates)
+{
+ double x1 = coordinates[0];
+ double x2 = coordinates[1];
+
+ double objective = /* f1(x) */ 100 * std::pow(x2 - std::pow(x1, 2), 2) +
+ /* f2(x) */ std::pow(1 - x1, 2);
+
+ return objective;
+}
+
+/**
+ * Calculate the gradient.
+ */
+void RosenbrockFunction::Gradient(const arma::mat& coordinates,
+ arma::mat& gradient)
+{
+ // f'_{x1}(x) = -2 (1 - x1) + 400 (x1^3 - (x2 x1))
+ // f'_{x2}(x) = 200 (x2 - x1^2)
+
+ double x1 = coordinates[0];
+ double x2 = coordinates[1];
+
+ gradient.set_size(2, 1);
+ gradient[0] = -2 * (1 - x1) + 400 * (std::pow(x1, 3) - x2 * x1);
+ gradient[1] = 200 * (x2 - std::pow(x1, 2));
+}
+
+const arma::mat& RosenbrockFunction::GetInitialPoint() const
+{
+ return initialPoint;
+}
+
+//
+// WoodFunction implementation
+//
+
+WoodFunction::WoodFunction()
+{
+ initialPoint.set_size(4, 1);
+ initialPoint[0] = -3;
+ initialPoint[1] = -1;
+ initialPoint[2] = -3;
+ initialPoint[3] = -1;
+}
+
+/**
+ * Calculate the objective function.
+ */
+double WoodFunction::Evaluate(const arma::mat& coordinates)
+{
+ // For convenience; we assume these temporaries will be optimized out.
+ double x1 = coordinates[0];
+ double x2 = coordinates[1];
+ double x3 = coordinates[2];
+ double x4 = coordinates[3];
+
+ double objective = /* f1(x) */ 100 * std::pow(x2 - std::pow(x1, 2), 2) +
+ /* f2(x) */ std::pow(1 - x1, 2) +
+ /* f3(x) */ 90 * std::pow(x4 - std::pow(x3, 2), 2) +
+ /* f4(x) */ std::pow(1 - x3, 2) +
+ /* f5(x) */ 10 * std::pow(x2 + x4 - 2, 2) +
+ /* f6(x) */ (1 / 10) * std::pow(x2 - x4, 2);
+
+ return objective;
+}
+
+/**
+ * Calculate the gradient.
+ */
+void WoodFunction::Gradient(const arma::mat& coordinates,
+ arma::mat& gradient)
+{
+ // For convenience; we assume these temporaries will be optimized out.
+ double x1 = coordinates[0];
+ double x2 = coordinates[1];
+ double x3 = coordinates[2];
+ double x4 = coordinates[3];
+
+ // f'_{x1}(x) = 400 (x1^3 - x2 x1) - 2 (1 - x1)
+ // f'_{x2}(x) = 200 (x2 - x1^2) + 20 (x2 + x4 - 2) + (1 / 5) (x2 - x4)
+ // f'_{x3}(x) = 360 (x3^3 - x4 x3) - 2 (1 - x3)
+ // f'_{x4}(x) = 180 (x4 - x3^2) + 20 (x2 + x4 - 2) - (1 / 5) (x2 - x4)
+ gradient.set_size(4, 1);
+ gradient[0] = 400 * (std::pow(x1, 3) - x2 * x1) - 2 * (1 - x1);
+ gradient[1] = 200 * (x2 - std::pow(x1, 2)) + 20 * (x2 + x4 - 2) +
+ (1 / 5) * (x2 - x4);
+ gradient[2] = 360 * (std::pow(x3, 3) - x4 * x3) - 2 * (1 - x3);
+ gradient[3] = 180 * (x4 - std::pow(x3, 2)) + 20 * (x2 + x4 - 2) -
+ (1 / 5) * (x2 - x4);
+}
+
+const arma::mat& WoodFunction::GetInitialPoint() const
+{
+ return initialPoint;
+}
+
+//
+// GeneralizedRosenbrockFunction implementation
+//
+
+GeneralizedRosenbrockFunction::GeneralizedRosenbrockFunction(int n) : n(n)
+{
+ initialPoint.set_size(n, 1);
+ for (int i = 0; i < n; i++) // Set to [-1.2 1 -1.2 1 ...].
+ {
+ if (i % 2 == 1)
+ initialPoint[i] = -1.2;
+ else
+ initialPoint[i] = 1;
+ }
+}
+
+/**
+ * Calculate the objective function.
+ */
+double GeneralizedRosenbrockFunction::Evaluate(const arma::mat& coordinates)
+ const
+{
+ double fval = 0;
+ for (int i = 0; i < (n - 1); i++)
+ {
+ fval += 100 * std::pow(std::pow(coordinates[i], 2) -
+ coordinates[i + 1], 2) + std::pow(1 - coordinates[i], 2);
+ }
+
+ return fval;
+}
+
+/**
+ * Calculate the gradient.
+ */
+void GeneralizedRosenbrockFunction::Gradient(const arma::mat& coordinates,
+ arma::mat& gradient) const
+{
+ gradient.set_size(n);
+ for (int i = 0; i < (n - 1); i++)
+ {
+ gradient[i] = 400 * (std::pow(coordinates[i], 3) - coordinates[i] *
+ coordinates[i + 1]) + 2 * (coordinates[i] - 1);
+
+ if (i > 0)
+ gradient[i] += 200 * (coordinates[i] - std::pow(coordinates[i - 1], 2));
+ }
+
+ gradient[n - 1] = 200 * (coordinates[n - 1] -
+ std::pow(coordinates[n - 2], 2));
+}
+
+//! Calculate the objective function of one of the individual functions.
+double GeneralizedRosenbrockFunction::Evaluate(const arma::mat& coordinates,
+ const size_t i) const
+{
+ return 100 * std::pow((std::pow(coordinates[i], 2) - coordinates[i + 1]), 2) +
+ std::pow(1 - coordinates[i], 2);
+}
+
+//! Calculate the gradient of one of the individual functions.
+void GeneralizedRosenbrockFunction::Gradient(const arma::mat& coordinates,
+ const size_t i,
+ arma::mat& gradient) const
+{
+ gradient.zeros(n);
+
+ gradient[i] = 400 * (std::pow(coordinates[i], 3) - coordinates[i] *
+ coordinates[i + 1]) + 2 * (coordinates[i] - 1);
+ gradient[i + 1] = 200 * (coordinates[i + 1] - std::pow(coordinates[i], 2));
+}
+
+const arma::mat& GeneralizedRosenbrockFunction::GetInitialPoint() const
+{
+ return initialPoint;
+}
+
+//
+// RosenbrockWoodFunction implementation
+//
+
+RosenbrockWoodFunction::RosenbrockWoodFunction() : rf(4), wf()
+{
+ initialPoint.set_size(4, 2);
+ initialPoint.col(0) = rf.GetInitialPoint();
+ initialPoint.col(1) = wf.GetInitialPoint();
+}
+
+/**
+ * Calculate the objective function.
+ */
+double RosenbrockWoodFunction::Evaluate(const arma::mat& coordinates)
+{
+ double objective = rf.Evaluate(coordinates.col(0)) +
+ wf.Evaluate(coordinates.col(1));
+
+ return objective;
+}
+
+/***
+ * Calculate the gradient.
+ */
+void RosenbrockWoodFunction::Gradient(const arma::mat& coordinates,
+ arma::mat& gradient)
+{
+ gradient.set_size(4, 2);
+
+ arma::vec grf(4);
+ arma::vec gwf(4);
+
+ rf.Gradient(coordinates.col(0), grf);
+ wf.Gradient(coordinates.col(1), gwf);
+
+ gradient.col(0) = grf;
+ gradient.col(1) = gwf;
+}
+
+const arma::mat& RosenbrockWoodFunction::GetInitialPoint() const
+{
+ return initialPoint;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lbfgs/test_functions.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,174 +0,0 @@
-/**
- * @file test_functions.hpp
- * @author Ryan Curtin
- *
- * A collection of functions to test optimizers (in this case, L-BFGS). These
- * come from the following paper:
- *
- * "Testing Unconstrained Optimization Software"
- * Jorge J. Moré, Burton S. Garbow, and Kenneth E. Hillstrom. 1981.
- * ACM Trans. Math. Softw. 7, 1 (March 1981), 17-41.
- * http://portal.acm.org/citation.cfm?id=355934.355936
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_LBFGS_TEST_FUNCTIONS_HPP
-#define __MLPACK_CORE_OPTIMIZERS_LBFGS_TEST_FUNCTIONS_HPP
-
-#include <mlpack/core.hpp>
-
-// To fulfill the template policy class 'FunctionType', we must implement
-// the following:
-//
-// FunctionType(); // constructor
-// void Gradient(const arma::mat& coordinates, arma::mat& gradient);
-// double Evaluate(const arma::mat& coordinates);
-// const arma::mat& GetInitialPoint();
-//
-// Note that we are using an arma::mat instead of the more intuitive and
-// expected arma::vec. This is because L-BFGS will also optimize matrices.
-// However, remember that an arma::vec is simply an (n x 1) arma::mat. You can
-// use either internally but the L-BFGS method requires arma::mat& to be passed
-// (C++ does not allow implicit reference casting to subclasses).
-
-namespace mlpack {
-namespace optimization {
-namespace test {
-
-/**
- * The Rosenbrock function, defined by
- * f(x) = f1(x) + f2(x)
- * f1(x) = 100 (x2 - x1^2)^2
- * f2(x) = (1 - x1)^2
- * x_0 = [-1.2, 1]
- *
- * This should optimize to f(x) = 0, at x = [1, 1].
- *
- * "An automatic method for finding the greatest or least value of a function."
- * H.H. Rosenbrock. 1960. Comput. J. 3., 175-184.
- */
-class RosenbrockFunction
-{
- public:
- RosenbrockFunction(); // initialize initial point
-
- double Evaluate(const arma::mat& coordinates);
- void Gradient(const arma::mat& coordinates, arma::mat& gradient);
-
- const arma::mat& GetInitialPoint() const;
-
- private:
- arma::mat initialPoint;
-};
-
-/**
- * The Wood function, defined by
- * f(x) = f1(x) + f2(x) + f3(x) + f4(x) + f5(x) + f6(x)
- * f1(x) = 100 (x2 - x1^2)^2
- * f2(x) = (1 - x1)^2
- * f3(x) = 90 (x4 - x3^2)^2
- * f4(x) = (1 - x3)^2
- * f5(x) = 10 (x2 + x4 - 2)^2
- * f6(x) = (1 / 10) (x2 - x4)^2
- * x_0 = [-3, -1, -3, -1]
- *
- * This should optimize to f(x) = 0, at x = [1, 1, 1, 1].
- *
- * "A comparative study of nonlinear programming codes."
- * A.R. Colville. 1968. Rep. 320-2949, IBM N.Y. Scientific Center.
- */
-class WoodFunction
-{
- public:
- WoodFunction(); // initialize initial point
-
- double Evaluate(const arma::mat& coordinates);
- void Gradient(const arma::mat& coordinates, arma::mat& gradient);
-
- const arma::mat& GetInitialPoint() const;
-
- private:
- arma::mat initialPoint;
-};
-
-/**
- * The Generalized Rosenbrock function in n dimensions, defined by
- * f(x) = sum_i^{n - 1} (f(i)(x))
- * f_i(x) = 100 * (x_i^2 - x_{i + 1})^2 + (1 - x_i)^2
- * x_0 = [-1.2, 1, -1.2, 1, ...]
- *
- * This should optimize to f(x) = 0, at x = [1, 1, 1, 1, ...].
- *
- * This function can also be used for stochastic gradient descent (SGD) as a
- * decomposable function (DecomposableFunctionType), so there are other
- * overloads of Evaluate() and Gradient() implemented, as well as
- * NumFunctions().
- *
- * "An analysis of the behavior of a glass of genetic adaptive systems."
- * K.A. De Jong. Ph.D. thesis, University of Michigan, 1975.
- */
-class GeneralizedRosenbrockFunction
-{
- public:
- /***
- * Set the dimensionality of the extended Rosenbrock function.
- *
- * @param n Number of dimensions for the function.
- */
- GeneralizedRosenbrockFunction(int n);
-
- double Evaluate(const arma::mat& coordinates) const;
- void Gradient(const arma::mat& coordinates, arma::mat& gradient) const;
-
- size_t NumFunctions() const { return n - 1; }
- double Evaluate(const arma::mat& coordinates, const size_t i) const;
- void Gradient(const arma::mat& coordinates,
- const size_t i,
- arma::mat& gradient) const;
-
- const arma::mat& GetInitialPoint() const;
-
- private:
- arma::mat initialPoint;
- int n; // Dimensionality
-};
-
-/**
- * The Generalized Rosenbrock function in 4 dimensions with the Wood Function in
- * four dimensions. In this function we are actually optimizing a 2x4 matrix of
- * coordinates, not a vector.
- */
-class RosenbrockWoodFunction
-{
- public:
- RosenbrockWoodFunction(); // initialize initial point
-
- double Evaluate(const arma::mat& coordinates);
- void Gradient(const arma::mat& coordinates, arma::mat& gradient);
-
- const arma::mat& GetInitialPoint() const;
-
- private:
- arma::mat initialPoint;
- GeneralizedRosenbrockFunction rf;
- WoodFunction wf;
-};
-
-}; // namespace test
-}; // namespace optimization
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_OPTIMIZERS_LBFGS_TEST_FUNCTIONS_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lbfgs/test_functions.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lbfgs/test_functions.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,174 @@
+/**
+ * @file test_functions.hpp
+ * @author Ryan Curtin
+ *
+ * A collection of functions to test optimizers (in this case, L-BFGS). These
+ * come from the following paper:
+ *
+ * "Testing Unconstrained Optimization Software"
+ * Jorge J. Moré, Burton S. Garbow, and Kenneth E. Hillstrom. 1981.
+ * ACM Trans. Math. Softw. 7, 1 (March 1981), 17-41.
+ * http://portal.acm.org/citation.cfm?id=355934.355936
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_LBFGS_TEST_FUNCTIONS_HPP
+#define __MLPACK_CORE_OPTIMIZERS_LBFGS_TEST_FUNCTIONS_HPP
+
+#include <mlpack/core.hpp>
+
+// To fulfill the template policy class 'FunctionType', we must implement
+// the following:
+//
+// FunctionType(); // constructor
+// void Gradient(const arma::mat& coordinates, arma::mat& gradient);
+// double Evaluate(const arma::mat& coordinates);
+// const arma::mat& GetInitialPoint();
+//
+// Note that we are using an arma::mat instead of the more intuitive and
+// expected arma::vec. This is because L-BFGS will also optimize matrices.
+// However, remember that an arma::vec is simply an (n x 1) arma::mat. You can
+// use either internally but the L-BFGS method requires arma::mat& to be passed
+// (C++ does not allow implicit reference casting to subclasses).
+
+namespace mlpack {
+namespace optimization {
+namespace test {
+
+/**
+ * The Rosenbrock function, defined by
+ * f(x) = f1(x) + f2(x)
+ * f1(x) = 100 (x2 - x1^2)^2
+ * f2(x) = (1 - x1)^2
+ * x_0 = [-1.2, 1]
+ *
+ * This should optimize to f(x) = 0, at x = [1, 1].
+ *
+ * "An automatic method for finding the greatest or least value of a function."
+ * H.H. Rosenbrock. 1960. Comput. J. 3., 175-184.
+ */
+class RosenbrockFunction
+{
+ public:
+ RosenbrockFunction(); // initialize initial point
+
+ double Evaluate(const arma::mat& coordinates);
+ void Gradient(const arma::mat& coordinates, arma::mat& gradient);
+
+ const arma::mat& GetInitialPoint() const;
+
+ private:
+ arma::mat initialPoint;
+};
+
+/**
+ * The Wood function, defined by
+ * f(x) = f1(x) + f2(x) + f3(x) + f4(x) + f5(x) + f6(x)
+ * f1(x) = 100 (x2 - x1^2)^2
+ * f2(x) = (1 - x1)^2
+ * f3(x) = 90 (x4 - x3^2)^2
+ * f4(x) = (1 - x3)^2
+ * f5(x) = 10 (x2 + x4 - 2)^2
+ * f6(x) = (1 / 10) (x2 - x4)^2
+ * x_0 = [-3, -1, -3, -1]
+ *
+ * This should optimize to f(x) = 0, at x = [1, 1, 1, 1].
+ *
+ * "A comparative study of nonlinear programming codes."
+ * A.R. Colville. 1968. Rep. 320-2949, IBM N.Y. Scientific Center.
+ */
+class WoodFunction
+{
+ public:
+ WoodFunction(); // initialize initial point
+
+ double Evaluate(const arma::mat& coordinates);
+ void Gradient(const arma::mat& coordinates, arma::mat& gradient);
+
+ const arma::mat& GetInitialPoint() const;
+
+ private:
+ arma::mat initialPoint;
+};
+
+/**
+ * The Generalized Rosenbrock function in n dimensions, defined by
+ * f(x) = sum_i^{n - 1} (f(i)(x))
+ * f_i(x) = 100 * (x_i^2 - x_{i + 1})^2 + (1 - x_i)^2
+ * x_0 = [-1.2, 1, -1.2, 1, ...]
+ *
+ * This should optimize to f(x) = 0, at x = [1, 1, 1, 1, ...].
+ *
+ * This function can also be used for stochastic gradient descent (SGD) as a
+ * decomposable function (DecomposableFunctionType), so there are other
+ * overloads of Evaluate() and Gradient() implemented, as well as
+ * NumFunctions().
+ *
+ * "An analysis of the behavior of a class of genetic adaptive systems."
+ * K.A. De Jong. Ph.D. thesis, University of Michigan, 1975.
+ */
+class GeneralizedRosenbrockFunction
+{
+ public:
+ /***
+ * Set the dimensionality of the extended Rosenbrock function.
+ *
+ * @param n Number of dimensions for the function.
+ */
+ GeneralizedRosenbrockFunction(int n);
+
+ double Evaluate(const arma::mat& coordinates) const;
+ void Gradient(const arma::mat& coordinates, arma::mat& gradient) const;
+
+ size_t NumFunctions() const { return n - 1; }
+ double Evaluate(const arma::mat& coordinates, const size_t i) const;
+ void Gradient(const arma::mat& coordinates,
+ const size_t i,
+ arma::mat& gradient) const;
+
+ const arma::mat& GetInitialPoint() const;
+
+ private:
+ arma::mat initialPoint;
+ int n; // Dimensionality
+};
+
+/**
+ * The Generalized Rosenbrock function in 4 dimensions with the Wood Function in
+ * four dimensions. In this function we are actually optimizing a 2x4 matrix of
+ * coordinates, not a vector.
+ */
+class RosenbrockWoodFunction
+{
+ public:
+ RosenbrockWoodFunction(); // initialize initial point
+
+ double Evaluate(const arma::mat& coordinates);
+ void Gradient(const arma::mat& coordinates, arma::mat& gradient);
+
+ const arma::mat& GetInitialPoint() const;
+
+ private:
+ arma::mat initialPoint;
+ GeneralizedRosenbrockFunction rf;
+ WoodFunction wf;
+};
+
+}; // namespace test
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_OPTIMIZERS_LBFGS_TEST_FUNCTIONS_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,154 +0,0 @@
-/**
- * @file lrsdp.hpp
- * @author Ryan Curtin
- *
- * An implementation of Monteiro and Burer's formulation of low-rank
- * semidefinite programs (LR-SDP).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_LRSDP_LRSDP_HPP
-#define __MLPACK_CORE_OPTIMIZERS_LRSDP_LRSDP_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp>
-
-namespace mlpack {
-namespace optimization {
-
-class LRSDP
-{
- public:
- /**
- * Create an LRSDP to be optimized. The solution will end up being a matrix
- * of size (rank) x (rows). To construct each constraint and the objective
- * function, use the functions A(), B(), and C() to set them correctly.
- *
- * @param numConstraints Number of constraints in the problem.
- * @param initialPoint Initial point of the optimization, of size
- *     (rows) x (rank).
- */
- LRSDP(const size_t numConstraints,
- const arma::mat& initialPoint);
-
- /**
- * Create an LRSDP to be optimized, passing in an already-created
- * AugLagrangian object. The given initial point should be set to the size
- * (rows) x (rank), where (rank) is the reduced rank of the problem.
- *
- * @param numConstraints Number of constraints in the problem.
- * @param initialPoint Initial point of the optimization.
- * @param augLagrangian Pre-initialized AugLagrangian<LRSDP> object.
- */
- LRSDP(const size_t numConstraints,
- const arma::mat& initialPoint,
- AugLagrangian<LRSDP>& augLagrangian);
-
- /**
- * Optimize the LRSDP and return the final objective value. The given
- * coordinates will be modified to contain the final solution.
- *
- * @param coordinates Starting coordinates for the optimization.
- */
- double Optimize(arma::mat& coordinates);
-
- /**
- * Evaluate the objective function of the LRSDP (no constraints) at the given
- * coordinates. This is used by AugLagrangian<LRSDP>.
- */
- double Evaluate(const arma::mat& coordinates) const;
-
- /**
- * Evaluate the gradient of the LRSDP (no constraints) at the given
- * coordinates. This is used by AugLagrangian<LRSDP>.
- */
- void Gradient(const arma::mat& coordinates, arma::mat& gradient) const;
-
- /**
- * Evaluate a particular constraint of the LRSDP at the given coordinates.
- */
- double EvaluateConstraint(const size_t index,
- const arma::mat& coordinates) const;
-
- /**
- * Evaluate the gradient of a particular constraint of the LRSDP at the given
- * coordinates.
- */
- void GradientConstraint(const size_t index,
- const arma::mat& coordinates,
- arma::mat& gradient) const;
-
- //! Get the number of constraints in the LRSDP.
- size_t NumConstraints() const { return b.n_elem; }
-
- //! Get the initial point of the LRSDP.
- const arma::mat& GetInitialPoint();
-
- //! Return the objective function matrix (C).
- const arma::mat& C() const { return c; }
- //! Modify the objective function matrix (C).
- arma::mat& C() { return c; }
-
- //! Return the vector of A matrices (which correspond to the constraints).
- const std::vector<arma::mat>& A() const { return a; }
- //! Modify the vector of A matrices (which correspond to the constraints).
- std::vector<arma::mat>& A() { return a; }
-
- //! Return the vector of modes for the A matrices.
- const arma::uvec& AModes() const { return aModes; }
- //! Modify the vector of modes for the A matrices.
- arma::uvec& AModes() { return aModes; }
-
- //! Return the vector of B values.
- const arma::vec& B() const { return b; }
- //! Modify the vector of B values.
- arma::vec& B() { return b; }
-
- //! Return the augmented Lagrangian object.
- const AugLagrangian<LRSDP>& AugLag() const { return augLag; }
- //! Modify the augmented Lagrangian object.
- AugLagrangian<LRSDP>& AugLag() { return augLag; }
-
- private:
- // Should probably use sparse matrices for some of these.
-
- //! For objective function.
- arma::mat c;
- //! A_i for each constraint.
- std::vector<arma::mat> a;
- //! b_i for each constraint.
- arma::vec b;
-
- //! 1 if entries in matrix, 0 for normal.
- arma::uvec aModes;
-
- //! Initial point.
- arma::mat initialPoint;
-
- //! Internal AugLagrangian object, if one was not passed at construction time.
- AugLagrangian<LRSDP> augLagInternal;
-
- //! The AugLagrangian object which will be used for optimization.
- AugLagrangian<LRSDP>& augLag;
-};
-
-}; // namespace optimization
-}; // namespace mlpack
-
-// Include implementation.
-#include "lrsdp_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,154 @@
+/**
+ * @file lrsdp.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of Monteiro and Burer's formulation of low-rank
+ * semidefinite programs (LR-SDP).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_LRSDP_LRSDP_HPP
+#define __MLPACK_CORE_OPTIMIZERS_LRSDP_LRSDP_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp>
+
+namespace mlpack {
+namespace optimization {
+
+class LRSDP
+{
+ public:
+ /**
+ * Create an LRSDP to be optimized. The solution will end up being a matrix
+ * of size (rank) x (rows). To construct each constraint and the objective
+ * function, use the functions A(), B(), and C() to set them correctly.
+ *
+ * @param numConstraints Number of constraints in the problem.
+ * @param initialPoint Initial point of the optimization, of size
+ *     (rows) x (rank).
+ */
+ LRSDP(const size_t numConstraints,
+ const arma::mat& initialPoint);
+
+ /**
+ * Create an LRSDP to be optimized, passing in an already-created
+ * AugLagrangian object. The given initial point should be set to the size
+ * (rows) x (rank), where (rank) is the reduced rank of the problem.
+ *
+ * @param numConstraints Number of constraints in the problem.
+ * @param initialPoint Initial point of the optimization.
+ * @param augLagrangian Pre-initialized AugLagrangian<LRSDP> object.
+ */
+ LRSDP(const size_t numConstraints,
+ const arma::mat& initialPoint,
+ AugLagrangian<LRSDP>& augLagrangian);
+
+ /**
+ * Optimize the LRSDP and return the final objective value. The given
+ * coordinates will be modified to contain the final solution.
+ *
+ * @param coordinates Starting coordinates for the optimization.
+ */
+ double Optimize(arma::mat& coordinates);
+
+ /**
+ * Evaluate the objective function of the LRSDP (no constraints) at the given
+ * coordinates. This is used by AugLagrangian<LRSDP>.
+ */
+ double Evaluate(const arma::mat& coordinates) const;
+
+ /**
+ * Evaluate the gradient of the LRSDP (no constraints) at the given
+ * coordinates. This is used by AugLagrangian<LRSDP>.
+ */
+ void Gradient(const arma::mat& coordinates, arma::mat& gradient) const;
+
+ /**
+ * Evaluate a particular constraint of the LRSDP at the given coordinates.
+ */
+ double EvaluateConstraint(const size_t index,
+ const arma::mat& coordinates) const;
+
+ /**
+ * Evaluate the gradient of a particular constraint of the LRSDP at the given
+ * coordinates.
+ */
+ void GradientConstraint(const size_t index,
+ const arma::mat& coordinates,
+ arma::mat& gradient) const;
+
+ //! Get the number of constraints in the LRSDP.
+ size_t NumConstraints() const { return b.n_elem; }
+
+ //! Get the initial point of the LRSDP.
+ const arma::mat& GetInitialPoint();
+
+ //! Return the objective function matrix (C).
+ const arma::mat& C() const { return c; }
+ //! Modify the objective function matrix (C).
+ arma::mat& C() { return c; }
+
+ //! Return the vector of A matrices (which correspond to the constraints).
+ const std::vector<arma::mat>& A() const { return a; }
+ //! Modify the vector of A matrices (which correspond to the constraints).
+ std::vector<arma::mat>& A() { return a; }
+
+ //! Return the vector of modes for the A matrices.
+ const arma::uvec& AModes() const { return aModes; }
+ //! Modify the vector of modes for the A matrices.
+ arma::uvec& AModes() { return aModes; }
+
+ //! Return the vector of B values.
+ const arma::vec& B() const { return b; }
+ //! Modify the vector of B values.
+ arma::vec& B() { return b; }
+
+ //! Return the augmented Lagrangian object.
+ const AugLagrangian<LRSDP>& AugLag() const { return augLag; }
+ //! Modify the augmented Lagrangian object.
+ AugLagrangian<LRSDP>& AugLag() { return augLag; }
+
+ private:
+ // Should probably use sparse matrices for some of these.
+
+ //! For objective function.
+ arma::mat c;
+ //! A_i for each constraint.
+ std::vector<arma::mat> a;
+ //! b_i for each constraint.
+ arma::vec b;
+
+ //! 1 if entries in matrix, 0 for normal.
+ arma::uvec aModes;
+
+ //! Initial point.
+ arma::mat initialPoint;
+
+ //! Internal AugLagrangian object, if one was not passed at construction time.
+ AugLagrangian<LRSDP> augLagInternal;
+
+ //! The AugLagrangian object which will be used for optimization.
+ AugLagrangian<LRSDP>& augLag;
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+// Include implementation.
+#include "lrsdp_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lrsdp/lrsdp_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,192 +0,0 @@
-/**
- * @file lrsdp_impl.hpp
- * @author Ryan Curtin
- *
- * An implementation of Monteiro and Burer's formulation of low-rank
- * semidefinite programs (LR-SDP).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_LRSDP_LRSDP_IMPL_HPP
-#define __MLPACK_CORE_OPTIMIZERS_LRSDP_LRSDP_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "lrsdp.hpp"
-
-namespace mlpack {
-namespace optimization {
-
-LRSDP::LRSDP(const size_t numConstraints,
- const arma::mat& initialPoint) :
- a(numConstraints),
- b(numConstraints),
- aModes(numConstraints),
- initialPoint(initialPoint),
- augLagInternal(*this),
- augLag(augLagInternal)
-{ }
-
-LRSDP::LRSDP(const size_t numConstraints,
- const arma::mat& initialPoint,
- AugLagrangian<LRSDP>& augLag) :
- a(numConstraints),
- b(numConstraints),
- aModes(numConstraints),
- initialPoint(initialPoint),
- augLagInternal(*this),
- augLag(augLag)
-{ }
-
-double LRSDP::Optimize(arma::mat& coordinates)
-{
- augLag.Sigma() = 20;
- augLag.Optimize(coordinates, 1000);
-
- return Evaluate(coordinates);
-}
-
-double LRSDP::Evaluate(const arma::mat& coordinates) const
-{
- return -accu(coordinates * trans(coordinates));
-}
-
-void LRSDP::Gradient(const arma::mat& /*coordinates*/,
- arma::mat& /*gradient*/) const
-{
- Log::Fatal << "LRSDP::Gradient() called! Uh-oh..." << std::endl;
-}
-
-double LRSDP::EvaluateConstraint(const size_t index,
- const arma::mat& coordinates) const
-{
- arma::mat rrt = coordinates * trans(coordinates);
- if (aModes[index] == 0)
- return trace(a[index] * rrt) - b[index];
- else
- {
- double value = -b[index];
- for (size_t i = 0; i < a[index].n_cols; ++i)
- value += a[index](2, i) * rrt(a[index](0, i), a[index](1, i));
-
- return value;
- }
-}
-
-void LRSDP::GradientConstraint(const size_t /*index*/,
- const arma::mat& /*coordinates*/,
- arma::mat& /*gradient*/) const
-{
- Log::Fatal << "LRSDP::GradientConstraint() called! Uh-oh..." << std::endl;
-}
-
-const arma::mat& LRSDP::GetInitialPoint()
-{
- return initialPoint;
-}
-
-// Custom specializations of the AugmentedLagrangianFunction for the LRSDP case.
-template<>
-double AugLagrangianFunction<LRSDP>::Evaluate(const arma::mat& coordinates)
- const
-{
- // We can calculate the entire objective in a smart way.
- // L(R, y, s) = Tr(C * (R R^T)) -
- // sum_{i = 1}^{m} (y_i (Tr(A_i * (R R^T)) - b_i)) +
- // (sigma / 2) * sum_{i = 1}^{m} (Tr(A_i * (R R^T)) - b_i)^2
-
- // Let's start with the objective: Tr(C * (R R^T)).
- // Simple, possibly slow solution.
- arma::mat rrt = coordinates * trans(coordinates);
- double objective = trace(function.C() * rrt);
-
- // Now each constraint.
- for (size_t i = 0; i < function.B().n_elem; ++i)
- {
- // Take the trace subtracted by the b_i.
- double constraint = -function.B()[i];
-
- if (function.AModes()[i] == 0)
- {
- constraint += trace(function.A()[i] * rrt);
- }
- else
- {
- for (size_t j = 0; j < function.A()[i].n_cols; ++j)
- {
- constraint += function.A()[i](2, j) *
- rrt(function.A()[i](0, j), function.A()[i](1, j));
- }
- }
-
- objective -= (lambda[i] * constraint);
- objective += (sigma / 2) * std::pow(constraint, 2.0);
- }
-
- return objective;
-}
-
-template<>
-void AugLagrangianFunction<LRSDP>::Gradient(const arma::mat& coordinates,
- arma::mat& gradient) const
-{
- // We can calculate the gradient in a smart way.
- // L'(R, y, s) = 2 * S' * R
- // with
- // S' = C - sum_{i = 1}^{m} y'_i A_i
- // y'_i = y_i - sigma * (Trace(A_i * (R R^T)) - b_i)
- arma::mat rrt = coordinates * trans(coordinates);
- arma::mat s = function.C();
-
- for (size_t i = 0; i < function.B().n_elem; ++i)
- {
- double constraint = -function.B()[i];
-
- if (function.AModes()[i] == 0)
- {
- constraint += trace(function.A()[i] * rrt);
- }
- else
- {
- for (size_t j = 0; j < function.A()[i].n_cols; ++j)
- {
- constraint += function.A()[i](2, j) *
- rrt(function.A()[i](0, j), function.A()[i](1, j));
- }
- }
-
- double y = lambda[i] - sigma * constraint;
-
- if (function.AModes()[i] == 0)
- {
- s -= (y * function.A()[i]);
- }
- else
- {
- // We only need to subtract the entries which could be modified.
- for (size_t j = 0; j < function.A()[i].n_cols; ++j)
- {
- s(function.A()[i](0, j), function.A()[i](1, j)) -= y;
- }
- }
- }
-
- gradient = 2 * s * coordinates;
-}
-
-}; // namespace optimization
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/lrsdp/lrsdp_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/lrsdp/lrsdp_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,192 @@
+/**
+ * @file lrsdp_impl.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of Monteiro and Burer's formulation of low-rank
+ * semidefinite programs (LR-SDP).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_LRSDP_LRSDP_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_LRSDP_LRSDP_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "lrsdp.hpp"
+
+namespace mlpack {
+namespace optimization {
+
+LRSDP::LRSDP(const size_t numConstraints,
+ const arma::mat& initialPoint) :
+ a(numConstraints),
+ b(numConstraints),
+ aModes(numConstraints),
+ initialPoint(initialPoint),
+ augLagInternal(*this),
+ augLag(augLagInternal)
+{ }
+
+LRSDP::LRSDP(const size_t numConstraints,
+ const arma::mat& initialPoint,
+ AugLagrangian<LRSDP>& augLag) :
+ a(numConstraints),
+ b(numConstraints),
+ aModes(numConstraints),
+ initialPoint(initialPoint),
+ augLagInternal(*this),
+ augLag(augLag)
+{ }
+
+double LRSDP::Optimize(arma::mat& coordinates)
+{
+ augLag.Sigma() = 20;
+ augLag.Optimize(coordinates, 1000);
+
+ return Evaluate(coordinates);
+}
+
+double LRSDP::Evaluate(const arma::mat& coordinates) const
+{
+ return -accu(coordinates * trans(coordinates));
+}
+
+void LRSDP::Gradient(const arma::mat& /*coordinates*/,
+ arma::mat& /*gradient*/) const
+{
+ Log::Fatal << "LRSDP::Gradient() called! Uh-oh..." << std::endl;
+}
+
+double LRSDP::EvaluateConstraint(const size_t index,
+ const arma::mat& coordinates) const
+{
+ arma::mat rrt = coordinates * trans(coordinates);
+ if (aModes[index] == 0)
+ return trace(a[index] * rrt) - b[index];
+ else
+ {
+ double value = -b[index];
+ for (size_t i = 0; i < a[index].n_cols; ++i)
+ value += a[index](2, i) * rrt(a[index](0, i), a[index](1, i));
+
+ return value;
+ }
+}
+
+void LRSDP::GradientConstraint(const size_t /*index*/,
+ const arma::mat& /*coordinates*/,
+ arma::mat& /*gradient*/) const
+{
+ Log::Fatal << "LRSDP::GradientConstraint() called! Uh-oh..." << std::endl;
+}
+
+const arma::mat& LRSDP::GetInitialPoint()
+{
+ return initialPoint;
+}
+
+// Custom specializations of the AugmentedLagrangianFunction for the LRSDP case.
+template<>
+double AugLagrangianFunction<LRSDP>::Evaluate(const arma::mat& coordinates)
+ const
+{
+ // We can calculate the entire objective in a smart way.
+ // L(R, y, s) = Tr(C * (R R^T)) -
+ // sum_{i = 1}^{m} (y_i (Tr(A_i * (R R^T)) - b_i)) +
+ // (sigma / 2) * sum_{i = 1}^{m} (Tr(A_i * (R R^T)) - b_i)^2
+
+ // Let's start with the objective: Tr(C * (R R^T)).
+ // Simple, possibly slow solution.
+ arma::mat rrt = coordinates * trans(coordinates);
+ double objective = trace(function.C() * rrt);
+
+ // Now each constraint.
+ for (size_t i = 0; i < function.B().n_elem; ++i)
+ {
+ // Take the trace subtracted by the b_i.
+ double constraint = -function.B()[i];
+
+ if (function.AModes()[i] == 0)
+ {
+ constraint += trace(function.A()[i] * rrt);
+ }
+ else
+ {
+ for (size_t j = 0; j < function.A()[i].n_cols; ++j)
+ {
+ constraint += function.A()[i](2, j) *
+ rrt(function.A()[i](0, j), function.A()[i](1, j));
+ }
+ }
+
+ objective -= (lambda[i] * constraint);
+ objective += (sigma / 2) * std::pow(constraint, 2.0);
+ }
+
+ return objective;
+}
+
+template<>
+void AugLagrangianFunction<LRSDP>::Gradient(const arma::mat& coordinates,
+ arma::mat& gradient) const
+{
+ // We can calculate the gradient in a smart way.
+ // L'(R, y, s) = 2 * S' * R
+ // with
+ // S' = C - sum_{i = 1}^{m} y'_i A_i
+ // y'_i = y_i - sigma * (Trace(A_i * (R R^T)) - b_i)
+ arma::mat rrt = coordinates * trans(coordinates);
+ arma::mat s = function.C();
+
+ for (size_t i = 0; i < function.B().n_elem; ++i)
+ {
+ double constraint = -function.B()[i];
+
+ if (function.AModes()[i] == 0)
+ {
+ constraint += trace(function.A()[i] * rrt);
+ }
+ else
+ {
+ for (size_t j = 0; j < function.A()[i].n_cols; ++j)
+ {
+ constraint += function.A()[i](2, j) *
+ rrt(function.A()[i](0, j), function.A()[i](1, j));
+ }
+ }
+
+ double y = lambda[i] - sigma * constraint;
+
+ if (function.AModes()[i] == 0)
+ {
+ s -= (y * function.A()[i]);
+ }
+ else
+ {
+ // We only need to subtract the entries which could be modified.
+ for (size_t j = 0; j < function.A()[i].n_cols; ++j)
+ {
+ s(function.A()[i](0, j), function.A()[i](1, j)) -= y;
+ }
+ }
+ }
+
+ gradient = 2 * s * coordinates;
+}
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/sgd/sgd.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,163 +0,0 @@
-/**
- * @file sgd.hpp
- * @author Ryan Curtin
- *
- * Stochastic Gradient Descent (SGD).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_SGD_SGD_HPP
-#define __MLPACK_CORE_OPTIMIZERS_SGD_SGD_HPP
-
-namespace mlpack {
-namespace optimization {
-
-/**
- * Stochastic Gradient Descent is a technique for minimizing a function which
- * can be expressed as a sum of other functions. That is, suppose we have
- *
- * \f[
- * f(A) = \sum_{i = 0}^{n} f_i(A)
- * \f]
- *
- * and our task is to minimize \f$ A \f$. Stochastic gradient descent iterates
- * over each function \f$ f_i(A) \f$, producing the following update scheme:
- *
- * \f[
- * A_{j + 1} = A_j + \alpha \nabla f_i(A)
- * \f]
- *
- * where \f$ \alpha \f$ is a parameter which specifies the step size. \f$ i \f$
- * is chosen according to \f$ j \f$ (the iteration number). The SGD class
- * supports either scanning through each of the \f$ n \f$ functions \f$ f_i(A)
- * \f$ linearly, or in a random sequence. The algorithm continues until \f$ j
- * \f$ reaches the maximum number of iterations -- or when a full sequence of
- * updates through each of the \f$ n \f$ functions \f$ f_i(A) \f$ produces an
- * improvement within a certain tolerance \f$ \epsilon \f$. That is,
- *
- * \f[
- * | f(A_{j + n}) - f(A_j) | < \epsilon.
- * \f]
- *
- * The parameter \f$\epsilon\f$ is specified by the tolerance parameter to the
- * constructor; \f$n\f$ is specified by the maxIterations parameter.
- *
- * This class is useful for data-dependent functions whose objective function
- * can be expressed as a sum of objective functions operating on an individual
- * point. Then, SGD considers the gradient of the objective function operating
- * on an individual point in its update of \f$ A \f$.
- *
- * For SGD to work, a DecomposableFunctionType template parameter is required.
- * This class must implement the following function:
- *
- * size_t NumFunctions();
- * double Evaluate(const arma::mat& coordinates, const size_t i);
- * void Gradient(const arma::mat& coordinates,
- * const size_t i,
- * arma::mat& gradient);
- *
- * NumFunctions() should return the number of functions (\f$n\f$), and in the
- * other two functions, the parameter i refers to which individual function (or
- * gradient) is being evaluated. So, for the case of a data-dependent function,
- * such as NCA (see mlpack::nca::NCA), NumFunctions() should return the number
- * of points in the dataset, and Evaluate(coordinates, 0) will evaluate the
- * objective function on the first point in the dataset (presumably, the dataset
- * is held internally in the DecomposableFunctionType).
- *
- * @tparam DecomposableFunctionType Decomposable objective function type to be
- * minimized.
- */
-template<typename DecomposableFunctionType>
-class SGD
-{
- public:
- /**
- * Construct the SGD optimizer with the given function and parameters.
- *
- * @param function Function to be optimized (minimized).
- * @param stepSize Step size for each iteration.
- * @param maxIterations Maximum number of iterations allowed (0 means no
- * limit).
- * @param tolerance Maximum absolute tolerance to terminate algorithm.
- * @param shuffle If true, the function order is shuffled; otherwise, each
- * function is visited in linear order.
- */
- SGD(DecomposableFunctionType& function,
- const double stepSize = 0.01,
- const size_t maxIterations = 100000,
- const double tolerance = 1e-5,
- const bool shuffle = true);
-
- /**
- * Optimize the given function using stochastic gradient descent. The given
- * starting point will be modified to store the finishing point of the
- * algorithm, and the final objective value is returned.
- *
- * @param iterate Starting point (will be modified).
- * @return Objective value of the final point.
- */
- double Optimize(arma::mat& iterate);
-
- //! Get the instantiated function to be optimized.
- const DecomposableFunctionType& Function() const { return function; }
- //! Modify the instantiated function.
- DecomposableFunctionType& Function() { return function; }
-
- //! Get the step size.
- double StepSize() const { return stepSize; }
- //! Modify the step size.
- double& StepSize() { return stepSize; }
-
- //! Get the maximum number of iterations (0 indicates no limit).
- size_t MaxIterations() const { return maxIterations; }
- //! Modify the maximum number of iterations (0 indicates no limit).
- size_t& MaxIterations() { return maxIterations; }
-
- //! Get the tolerance for termination.
- double Tolerance() const { return tolerance; }
- //! Modify the tolerance for termination.
- double& Tolerance() { return tolerance; }
-
- //! Get whether or not the individual functions are shuffled.
- bool Shuffle() const { return shuffle; }
- //! Modify whether or not the individual functions are shuffled.
- bool& Shuffle() { return shuffle; }
-
- private:
- //! The instantiated function.
- DecomposableFunctionType& function;
-
- //! The step size for each example.
- double stepSize;
-
- //! The maximum number of allowed iterations.
- size_t maxIterations;
-
- //! The tolerance for termination.
- double tolerance;
-
- //! Controls whether or not the individual functions are shuffled when
- //! iterating.
- bool shuffle;
-};
-
-}; // namespace optimization
-}; // namespace mlpack
-
-// Include implementation.
-#include "sgd_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/sgd/sgd.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,163 @@
+/**
+ * @file sgd.hpp
+ * @author Ryan Curtin
+ *
+ * Stochastic Gradient Descent (SGD).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_SGD_SGD_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SGD_SGD_HPP
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * Stochastic Gradient Descent is a technique for minimizing a function which
+ * can be expressed as a sum of other functions. That is, suppose we have
+ *
+ * \f[
+ * f(A) = \sum_{i = 0}^{n} f_i(A)
+ * \f]
+ *
+ * and our task is to minimize \f$ A \f$. Stochastic gradient descent iterates
+ * over each function \f$ f_i(A) \f$, producing the following update scheme:
+ *
+ * \f[
+ * A_{j + 1} = A_j + \alpha \nabla f_i(A)
+ * \f]
+ *
+ * where \f$ \alpha \f$ is a parameter which specifies the step size. \f$ i \f$
+ * is chosen according to \f$ j \f$ (the iteration number). The SGD class
+ * supports either scanning through each of the \f$ n \f$ functions \f$ f_i(A)
+ * \f$ linearly, or in a random sequence. The algorithm continues until \f$ j
+ * \f$ reaches the maximum number of iterations -- or when a full sequence of
+ * updates through each of the \f$ n \f$ functions \f$ f_i(A) \f$ produces an
+ * improvement within a certain tolerance \f$ \epsilon \f$. That is,
+ *
+ * \f[
+ * | f(A_{j + n}) - f(A_j) | < \epsilon.
+ * \f]
+ *
+ * The parameter \f$\epsilon\f$ is specified by the tolerance parameter to the
+ * constructor; \f$n\f$ is specified by the maxIterations parameter.
+ *
+ * This class is useful for data-dependent functions whose objective function
+ * can be expressed as a sum of objective functions operating on an individual
+ * point. Then, SGD considers the gradient of the objective function operating
+ * on an individual point in its update of \f$ A \f$.
+ *
+ * For SGD to work, a DecomposableFunctionType template parameter is required.
+ * This class must implement the following function:
+ *
+ * size_t NumFunctions();
+ * double Evaluate(const arma::mat& coordinates, const size_t i);
+ * void Gradient(const arma::mat& coordinates,
+ * const size_t i,
+ * arma::mat& gradient);
+ *
+ * NumFunctions() should return the number of functions (\f$n\f$), and in the
+ * other two functions, the parameter i refers to which individual function (or
+ * gradient) is being evaluated. So, for the case of a data-dependent function,
+ * such as NCA (see mlpack::nca::NCA), NumFunctions() should return the number
+ * of points in the dataset, and Evaluate(coordinates, 0) will evaluate the
+ * objective function on the first point in the dataset (presumably, the dataset
+ * is held internally in the DecomposableFunctionType).
+ *
+ * @tparam DecomposableFunctionType Decomposable objective function type to be
+ * minimized.
+ */
+template<typename DecomposableFunctionType>
+class SGD
+{
+ public:
+ /**
+ * Construct the SGD optimizer with the given function and parameters.
+ *
+ * @param function Function to be optimized (minimized).
+ * @param stepSize Step size for each iteration.
+ * @param maxIterations Maximum number of iterations allowed (0 means no
+ * limit).
+ * @param tolerance Maximum absolute tolerance to terminate algorithm.
+ * @param shuffle If true, the function order is shuffled; otherwise, each
+ * function is visited in linear order.
+ */
+ SGD(DecomposableFunctionType& function,
+ const double stepSize = 0.01,
+ const size_t maxIterations = 100000,
+ const double tolerance = 1e-5,
+ const bool shuffle = true);
+
+ /**
+ * Optimize the given function using stochastic gradient descent. The given
+ * starting point will be modified to store the finishing point of the
+ * algorithm, and the final objective value is returned.
+ *
+ * @param iterate Starting point (will be modified).
+ * @return Objective value of the final point.
+ */
+ double Optimize(arma::mat& iterate);
+
+ //! Get the instantiated function to be optimized.
+ const DecomposableFunctionType& Function() const { return function; }
+ //! Modify the instantiated function.
+ DecomposableFunctionType& Function() { return function; }
+
+ //! Get the step size.
+ double StepSize() const { return stepSize; }
+ //! Modify the step size.
+ double& StepSize() { return stepSize; }
+
+ //! Get the maximum number of iterations (0 indicates no limit).
+ size_t MaxIterations() const { return maxIterations; }
+ //! Modify the maximum number of iterations (0 indicates no limit).
+ size_t& MaxIterations() { return maxIterations; }
+
+ //! Get the tolerance for termination.
+ double Tolerance() const { return tolerance; }
+ //! Modify the tolerance for termination.
+ double& Tolerance() { return tolerance; }
+
+ //! Get whether or not the individual functions are shuffled.
+ bool Shuffle() const { return shuffle; }
+ //! Modify whether or not the individual functions are shuffled.
+ bool& Shuffle() { return shuffle; }
+
+ private:
+ //! The instantiated function.
+ DecomposableFunctionType& function;
+
+ //! The step size for each example.
+ double stepSize;
+
+ //! The maximum number of allowed iterations.
+ size_t maxIterations;
+
+ //! The tolerance for termination.
+ double tolerance;
+
+ //! Controls whether or not the individual functions are shuffled when
+ //! iterating.
+ bool shuffle;
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+// Include implementation.
+#include "sgd_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/sgd/sgd_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,118 +0,0 @@
-/**
- * @file sgd_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of stochastic gradient descent.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_SGD_SGD_IMPL_HPP
-#define __MLPACK_CORE_OPTIMIZERS_SGD_SGD_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "sgd.hpp"
-
-namespace mlpack {
-namespace optimization {
-
-template<typename DecomposableFunctionType>
-SGD<DecomposableFunctionType>::SGD(DecomposableFunctionType& function,
- const double stepSize,
- const size_t maxIterations,
- const double tolerance,
- const bool shuffle) :
- function(function),
- stepSize(stepSize),
- maxIterations(maxIterations),
- tolerance(tolerance),
- shuffle(shuffle)
-{ /* Nothing to do. */ }
-
-//! Optimize the function (minimize).
-template<typename DecomposableFunctionType>
-double SGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
-{
- // Find the number of functions to use.
- const size_t numFunctions = function.NumFunctions();
-
- // This is used only if shuffle is true.
- arma::vec visitationOrder;
- if (shuffle)
- visitationOrder = arma::shuffle(arma::linspace(0, (numFunctions - 1),
- numFunctions));
-
- // To keep track of where we are and how things are going.
- size_t currentFunction = 0;
- double overallObjective = 0;
- double lastObjective = DBL_MAX;
-
- // Calculate the first objective function.
- for (size_t i = 0; i < numFunctions; ++i)
- overallObjective += function.Evaluate(iterate, i);
-
- // Now iterate!
- arma::mat gradient(iterate.n_rows, iterate.n_cols);
- for (size_t i = 1; i != maxIterations; ++i, ++currentFunction)
- {
- // Is this iteration the start of a sequence?
- if ((currentFunction % numFunctions) == 0)
- {
- // Output current objective function.
- Log::Info << "SGD: iteration " << i << ", objective " << overallObjective
- << "." << std::endl;
-
- if (overallObjective != overallObjective)
- {
- Log::Warn << "SGD: converged to " << overallObjective << "; terminating"
- << " with failure. Try a smaller step size?" << std::endl;
- return overallObjective;
- }
-
- if (std::abs(lastObjective - overallObjective) < tolerance)
- {
- Log::Info << "SGD: minimized within tolerance " << tolerance << "; "
- << "terminating optimization." << std::endl;
- return overallObjective;
- }
-
- // Reset the counter variables.
- lastObjective = overallObjective;
- overallObjective = 0;
- currentFunction = 0;
-
- if (shuffle) // Determine order of visitation.
- visitationOrder = arma::shuffle(visitationOrder);
- }
-
- // Evaluate the gradient for this iteration.
- function.Gradient(iterate, currentFunction, gradient);
-
- // And update the iterate.
- iterate -= stepSize * gradient;
-
- // Now add that to the overall objective function.
- overallObjective += function.Evaluate(iterate, currentFunction);
- }
-
- Log::Info << "SGD: maximum iterations (" << maxIterations << ") reached; "
- << "terminating optimization." << std::endl;
- return overallObjective;
-}
-
-}; // namespace optimization
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/sgd/sgd_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/sgd_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,118 @@
+/**
+ * @file sgd_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of stochastic gradient descent.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_SGD_SGD_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SGD_SGD_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "sgd.hpp"
+
+namespace mlpack {
+namespace optimization {
+
+template<typename DecomposableFunctionType>
+SGD<DecomposableFunctionType>::SGD(DecomposableFunctionType& function,
+ const double stepSize,
+ const size_t maxIterations,
+ const double tolerance,
+ const bool shuffle) :
+ function(function),
+ stepSize(stepSize),
+ maxIterations(maxIterations),
+ tolerance(tolerance),
+ shuffle(shuffle)
+{ /* Nothing to do. */ }
+
+//! Optimize the function (minimize).
+template<typename DecomposableFunctionType>
+double SGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
+{
+ // Find the number of functions to use.
+ const size_t numFunctions = function.NumFunctions();
+
+ // This is used only if shuffle is true.
+ arma::vec visitationOrder;
+ if (shuffle)
+ visitationOrder = arma::shuffle(arma::linspace(0, (numFunctions - 1),
+ numFunctions));
+
+ // To keep track of where we are and how things are going.
+ size_t currentFunction = 0;
+ double overallObjective = 0;
+ double lastObjective = DBL_MAX;
+
+ // Calculate the first objective function.
+ for (size_t i = 0; i < numFunctions; ++i)
+ overallObjective += function.Evaluate(iterate, i);
+
+ // Now iterate!
+ arma::mat gradient(iterate.n_rows, iterate.n_cols);
+ for (size_t i = 1; i != maxIterations; ++i, ++currentFunction)
+ {
+ // Is this iteration the start of a sequence?
+ if ((currentFunction % numFunctions) == 0)
+ {
+ // Output current objective function.
+ Log::Info << "SGD: iteration " << i << ", objective " << overallObjective
+ << "." << std::endl;
+
+ if (overallObjective != overallObjective)
+ {
+ Log::Warn << "SGD: converged to " << overallObjective << "; terminating"
+ << " with failure. Try a smaller step size?" << std::endl;
+ return overallObjective;
+ }
+
+ if (std::abs(lastObjective - overallObjective) < tolerance)
+ {
+ Log::Info << "SGD: minimized within tolerance " << tolerance << "; "
+ << "terminating optimization." << std::endl;
+ return overallObjective;
+ }
+
+ // Reset the counter variables.
+ lastObjective = overallObjective;
+ overallObjective = 0;
+ currentFunction = 0;
+
+ if (shuffle) // Determine order of visitation.
+ visitationOrder = arma::shuffle(visitationOrder);
+ }
+
+ // Evaluate the gradient for this iteration.
+ function.Gradient(iterate, currentFunction, gradient);
+
+ // And update the iterate.
+ iterate -= stepSize * gradient;
+
+ // Now add that to the overall objective function.
+ overallObjective += function.Evaluate(iterate, currentFunction);
+ }
+
+ Log::Info << "SGD: maximum iterations (" << maxIterations << ") reached; "
+ << "terminating optimization." << std::endl;
+ return overallObjective;
+}
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/sgd/test_function.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,70 +0,0 @@
-/**
- * @file test_function.cpp
- * @author Ryan Curtin
- *
- * Implementation of very simple test function for stochastic gradient descent
- * (SGD).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "test_function.hpp"
-
-using namespace mlpack;
-using namespace mlpack::optimization;
-using namespace mlpack::optimization::test;
-
-double SGDTestFunction::Evaluate(const arma::mat& coordinates, const size_t i)
- const
-{
- switch (i)
- {
- case 0:
- return -std::exp(-std::abs(coordinates[0]));
-
- case 1:
- return std::pow(coordinates[1], 2);
-
- case 2:
- return std::pow(coordinates[2], 4) + 3 * std::pow(coordinates[2], 2);
-
- default:
- return 0;
- }
-}
-
-void SGDTestFunction::Gradient(const arma::mat& coordinates,
- const size_t i,
- arma::mat& gradient) const
-{
- gradient.zeros(3);
- switch (i)
- {
- case 0:
- if (coordinates[0] >= 0)
- gradient[0] = std::exp(-coordinates[0]);
- else
- gradient[0] = -std::exp(coordinates[1]);
- break;
-
- case 1:
- gradient[1] = 2 * coordinates[1];
- break;
-
- case 2:
- gradient[2] = 4 * std::pow(coordinates[2], 3) + 6 * coordinates[2];
- break;
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/sgd/test_function.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,70 @@
+/**
+ * @file test_function.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of very simple test function for stochastic gradient descent
+ * (SGD).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "test_function.hpp"
+
+using namespace mlpack;
+using namespace mlpack::optimization;
+using namespace mlpack::optimization::test;
+
+double SGDTestFunction::Evaluate(const arma::mat& coordinates, const size_t i)
+ const
+{
+ switch (i)
+ {
+ case 0:
+ return -std::exp(-std::abs(coordinates[0]));
+
+ case 1:
+ return std::pow(coordinates[1], 2);
+
+ case 2:
+ return std::pow(coordinates[2], 4) + 3 * std::pow(coordinates[2], 2);
+
+ default:
+ return 0;
+ }
+}
+
+void SGDTestFunction::Gradient(const arma::mat& coordinates,
+ const size_t i,
+ arma::mat& gradient) const
+{
+ gradient.zeros(3);
+ switch (i)
+ {
+ case 0:
+ if (coordinates[0] >= 0)
+ gradient[0] = std::exp(-coordinates[0]);
+ else
+ gradient[0] = -std::exp(coordinates[1]);
+ break;
+
+ case 1:
+ gradient[1] = 2 * coordinates[1];
+ break;
+
+ case 2:
+ gradient[2] = 4 * std::pow(coordinates[2], 3) + 6 * coordinates[2];
+ break;
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/sgd/test_function.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,60 +0,0 @@
-/**
- * @file test_function.hpp
- * @author Ryan Curtin
- *
- * Very simple test function for SGD.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_SGD_TEST_FUNCTION_HPP
-#define __MLPACK_CORE_OPTIMIZERS_SGD_TEST_FUNCTION_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace optimization {
-namespace test {
-
-//! Very, very simple test function which is the composite of three other
-//! functions. It turns out that although this function is very simple,
-//! optimizing it fully can take a very long time. It seems to take in excess
-//! of 10 million iterations with a step size of 0.0005.
-class SGDTestFunction
-{
- public:
- //! Nothing to do for the constructor.
- SGDTestFunction() { }
-
- //! Return 3 (the number of functions).
- size_t NumFunctions() const { return 3; }
-
- //! Get the starting point.
- arma::mat GetInitialPoint() const { return arma::mat("6; -45.6; 6.2"); }
-
- //! Evaluate a function.
- double Evaluate(const arma::mat& coordinates, const size_t i) const;
-
- //! Evaluate the gradient of a function.
- void Gradient(const arma::mat& coordinates,
- const size_t i,
- arma::mat& gradient) const;
-};
-
-}; // namespace test
-}; // namespace optimization
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/optimizers/sgd/test_function.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/optimizers/sgd/test_function.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,60 @@
+/**
+ * @file test_function.hpp
+ * @author Ryan Curtin
+ *
+ * Very simple test function for SGD.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_SGD_TEST_FUNCTION_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SGD_TEST_FUNCTION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace optimization {
+namespace test {
+
+//! Very, very simple test function which is the composite of three other
+//! functions. It turns out that although this function is very simple,
+//! optimizing it fully can take a very long time. It seems to take in excess
+//! of 10 million iterations with a step size of 0.0005.
+class SGDTestFunction
+{
+ public:
+ //! Nothing to do for the constructor.
+ SGDTestFunction() { }
+
+ //! Return 3 (the number of functions).
+ size_t NumFunctions() const { return 3; }
+
+ //! Get the starting point.
+ arma::mat GetInitialPoint() const { return arma::mat("6; -45.6; 6.2"); }
+
+ //! Evaluate a function.
+ double Evaluate(const arma::mat& coordinates, const size_t i) const;
+
+ //! Evaluate the gradient of a function.
+ void Gradient(const arma::mat& coordinates,
+ const size_t i,
+ arma::mat& gradient) const;
+};
+
+}; // namespace test
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/ballbound.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,153 +0,0 @@
-/**
- * @file ballbound.hpp
- *
- * Bounds that are useful for binary space partitioning trees.
- * Interface to a ball bound that works in arbitrary metric spaces.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef __MLPACK_CORE_TREE_BALLBOUND_HPP
-#define __MLPACK_CORE_TREE_BALLBOUND_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-namespace mlpack {
-namespace bound {
-
-/**
- * Ball bound that works in the regular Euclidean metric space.
- *
- * @tparam VecType Type of vector (arma::vec or arma::spvec).
- */
-template<typename VecType = arma::vec>
-class BallBound
-{
- public:
- typedef VecType Vec;
-
- private:
- double radius;
- VecType center;
-
- public:
- BallBound() : radius(0) { }
-
- /**
- * Create the ball bound with the specified dimensionality.
- *
- * @param dimension Dimensionality of ball bound.
- */
- BallBound(const size_t dimension) : radius(0), center(dimension) { }
-
- /**
- * Create the ball bound with the specified radius and center.
- *
- * @param radius Radius of ball bound.
- * @param center Center of ball bound.
- */
- BallBound(const double radius, const VecType& center) :
- radius(radius), center(center) { }
-
- //! Get the radius of the ball.
- double Radius() const { return radius; }
- //! Modify the radius of the ball.
- double& Radius() { return radius; }
-
- //! Get the center point of the ball.
- const VecType& Center() const { return center; }
- //! Modify the center point of the ball.
- VecType& Center() { return center; }
-
- // Get the range in a certain dimension.
- math::Range operator[](const size_t i) const;
-
- /**
- * Determines if a point is within this bound.
- */
- bool Contains(const VecType& point) const;
-
- /**
- * Gets the center.
- *
- * Don't really use this directly. This is only here for consistency
- * with DHrectBound, so it can plug in more directly if a "centroid"
- * is needed.
- */
- void CalculateMidpoint(VecType& centroid) const;
-
- /**
- * Calculates minimum bound-to-point squared distance.
- */
- double MinDistance(const VecType& point) const;
-
- /**
- * Calculates minimum bound-to-bound squared distance.
- */
- double MinDistance(const BallBound& other) const;
-
- /**
- * Computes maximum distance.
- */
- double MaxDistance(const VecType& point) const;
-
- /**
- * Computes maximum distance.
- */
- double MaxDistance(const BallBound& other) const;
-
- /**
- * Calculates minimum and maximum bound-to-point distance.
- */
- math::Range RangeDistance(const VecType& other) const;
-
- /**
- * Calculates minimum and maximum bound-to-bound distance.
- *
- * Example: bound1.MinDistanceSq(other) for minimum distance.
- */
- math::Range RangeDistance(const BallBound& other) const;
-
- /**
- * Expand the bound to include the given node.
- */
- const BallBound& operator|=(const BallBound& other);
-
- /**
- * Expand the bound to include the given point. The centroid is recalculated
- * to be the center of all of the given points.
- *
- * @tparam MatType Type of matrix; could be arma::mat, arma::spmat, or a
- * vector.
- * @tparam data Data points to add.
- */
- template<typename MatType>
- const BallBound& operator|=(const MatType& data);
-
- /**
- * Returns a string representation of this object.
- */
- std::string ToString() const;
-
-};
-
-}; // namespace bound
-}; // namespace mlpack
-
-#include "ballbound_impl.hpp"
-
-#endif // __MLPACK_CORE_TREE_DBALLBOUND_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/ballbound.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,153 @@
+/**
+ * @file ballbound.hpp
+ *
+ * Bounds that are useful for binary space partitioning trees.
+ * Interface to a ball bound that works in arbitrary metric spaces.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef __MLPACK_CORE_TREE_BALLBOUND_HPP
+#define __MLPACK_CORE_TREE_BALLBOUND_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace bound {
+
+/**
+ * Ball bound that works in the regular Euclidean metric space.
+ *
+ * @tparam VecType Type of vector (arma::vec or arma::spvec).
+ */
+template<typename VecType = arma::vec>
+class BallBound
+{
+ public:
+ typedef VecType Vec;
+
+ private:
+ double radius;
+ VecType center;
+
+ public:
+ BallBound() : radius(0) { }
+
+ /**
+ * Create the ball bound with the specified dimensionality.
+ *
+ * @param dimension Dimensionality of ball bound.
+ */
+ BallBound(const size_t dimension) : radius(0), center(dimension) { }
+
+ /**
+ * Create the ball bound with the specified radius and center.
+ *
+ * @param radius Radius of ball bound.
+ * @param center Center of ball bound.
+ */
+ BallBound(const double radius, const VecType& center) :
+ radius(radius), center(center) { }
+
+ //! Get the radius of the ball.
+ double Radius() const { return radius; }
+ //! Modify the radius of the ball.
+ double& Radius() { return radius; }
+
+ //! Get the center point of the ball.
+ const VecType& Center() const { return center; }
+ //! Modify the center point of the ball.
+ VecType& Center() { return center; }
+
+ // Get the range in a certain dimension.
+ math::Range operator[](const size_t i) const;
+
+ /**
+ * Determines if a point is within this bound.
+ */
+ bool Contains(const VecType& point) const;
+
+ /**
+ * Gets the center.
+ *
+ * Don't really use this directly. This is only here for consistency
+ * with DHrectBound, so it can plug in more directly if a "centroid"
+ * is needed.
+ */
+ void CalculateMidpoint(VecType& centroid) const;
+
+ /**
+ * Calculates minimum bound-to-point squared distance.
+ */
+ double MinDistance(const VecType& point) const;
+
+ /**
+ * Calculates minimum bound-to-bound squared distance.
+ */
+ double MinDistance(const BallBound& other) const;
+
+ /**
+ * Computes maximum distance.
+ */
+ double MaxDistance(const VecType& point) const;
+
+ /**
+ * Computes maximum distance.
+ */
+ double MaxDistance(const BallBound& other) const;
+
+ /**
+ * Calculates minimum and maximum bound-to-point distance.
+ */
+ math::Range RangeDistance(const VecType& other) const;
+
+ /**
+ * Calculates minimum and maximum bound-to-bound distance.
+ *
+ * Example: bound1.MinDistanceSq(other) for minimum distance.
+ */
+ math::Range RangeDistance(const BallBound& other) const;
+
+ /**
+ * Expand the bound to include the given node.
+ */
+ const BallBound& operator|=(const BallBound& other);
+
+ /**
+ * Expand the bound to include the given point. The centroid is recalculated
+ * to be the center of all of the given points.
+ *
+ * @tparam MatType Type of matrix; could be arma::mat, arma::spmat, or a
+ * vector.
+ * @tparam data Data points to add.
+ */
+ template<typename MatType>
+ const BallBound& operator|=(const MatType& data);
+
+ /**
+ * Returns a string representation of this object.
+ */
+ std::string ToString() const;
+
+};
+
+}; // namespace bound
+}; // namespace mlpack
+
+#include "ballbound_impl.hpp"
+
+#endif // __MLPACK_CORE_TREE_DBALLBOUND_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/ballbound_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,224 +0,0 @@
-/**
- * @file ballbound_impl.hpp
- *
- * Bounds that are useful for binary space partitioning trees.
- * Implementation of BallBound ball bound metric policy class.
- *
- * @experimental
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
-#define __MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
-
-// In case it hasn't been included already.
-#include "ballbound.hpp"
-
-#include <string>
-
-namespace mlpack {
-namespace bound {
-
-//! Get the range in a certain dimension.
-template<typename VecType>
-math::Range BallBound<VecType>::operator[](const size_t i) const
-{
- if (radius < 0)
- return math::Range();
- else
- return math::Range(center[i] - radius, center[i] + radius);
-}
-
-/**
- * Determines if a point is within the bound.
- */
-template<typename VecType>
-bool BallBound<VecType>::Contains(const VecType& point) const
-{
- if (radius < 0)
- return false;
- else
- return metric::EuclideanDistance::Evaluate(center, point) <= radius;
-}
-
-/**
- * Gets the center.
- *
- * Don't really use this directly. This is only here for consistency
- * with DHrectBound, so it can plug in more directly if a "centroid"
- * is needed.
- */
-template<typename VecType>
-void BallBound<VecType>::CalculateMidpoint(VecType& centroid) const
-{
- centroid = center;
-}
-
-/**
- * Calculates minimum bound-to-point squared distance.
- */
-template<typename VecType>
-double BallBound<VecType>::MinDistance(const VecType& point) const
-{
- if (radius < 0)
- return DBL_MAX;
- else
- return math::ClampNonNegative(metric::EuclideanDistance::Evaluate(point,
- center) - radius);
-}
-
-/**
- * Calculates minimum bound-to-bound squared distance.
- */
-template<typename VecType>
-double BallBound<VecType>::MinDistance(const BallBound& other) const
-{
- if (radius < 0)
- return DBL_MAX;
- else
- {
- double delta = metric::EuclideanDistance::Evaluate(center, other.center)
- - radius - other.radius;
- return math::ClampNonNegative(delta);
- }
-}
-
-/**
- * Computes maximum distance.
- */
-template<typename VecType>
-double BallBound<VecType>::MaxDistance(const VecType& point) const
-{
- if (radius < 0)
- return DBL_MAX;
- else
- return metric::EuclideanDistance::Evaluate(point, center) + radius;
-}
-
-/**
- * Computes maximum distance.
- */
-template<typename VecType>
-double BallBound<VecType>::MaxDistance(const BallBound& other) const
-{
- if (radius < 0)
- return DBL_MAX;
- else
- return metric::EuclideanDistance::Evaluate(other.center, center) + radius
- + other.radius;
-}
-
-/**
- * Calculates minimum and maximum bound-to-bound squared distance.
- *
- * Example: bound1.MinDistanceSq(other) for minimum squared distance.
- */
-template<typename VecType>
-math::Range BallBound<VecType>::RangeDistance(const VecType& point)
- const
-{
- if (radius < 0)
- return math::Range(DBL_MAX, DBL_MAX);
- else
- {
- double dist = metric::EuclideanDistance::Evaluate(center, point);
- return math::Range(math::ClampNonNegative(dist - radius),
- dist + radius);
- }
-}
-
-template<typename VecType>
-math::Range BallBound<VecType>::RangeDistance(
- const BallBound& other) const
-{
- if (radius < 0)
- return math::Range(DBL_MAX, DBL_MAX);
- else
- {
- double dist = metric::EuclideanDistance::Evaluate(center, other.center);
- double sumradius = radius + other.radius;
- return math::Range(math::ClampNonNegative(dist - sumradius),
- dist + sumradius);
- }
-}
-
-/**
- * Expand the bound to include the given bound.
- *
-template<typename VecType>
-const BallBound<VecType>&
-BallBound<VecType>::operator|=(
- const BallBound<VecType>& other)
-{
- double dist = metric::EuclideanDistance::Evaluate(center, other);
-
- // Now expand the radius as necessary.
- if (dist > radius)
- radius = dist;
-
- return *this;
-}*/
-
-/**
- * Expand the bound to include the given point.
- */
-template<typename VecType>
-template<typename MatType>
-const BallBound<VecType>&
-BallBound<VecType>::operator|=(const MatType& data)
-{
- if (radius < 0)
- {
- center = data.col(0);
- radius = 0;
- }
-
- // Now iteratively add points. There is probably a closed-form solution to
- // find the minimum bounding circle, and it is probably faster.
- for (size_t i = 1; i < data.n_cols; ++i)
- {
- double dist = metric::EuclideanDistance::Evaluate(center, (VecType)
- data.col(i)) - radius;
-
- if (dist > 0)
- {
- // Move (dist / 2) towards the new point and increase radius by
- // (dist / 2).
- arma::vec diff = data.col(i) - center;
- center += 0.5 * diff;
- radius += 0.5 * dist;
- }
- }
-
- return *this;
-}
-/**
- * Returns a string representation of this object.
- */
-template<typename VecType>
-std::string BallBound<VecType>::ToString() const
-{
- std::ostringstream convert;
- convert << "BallBound [" << this << "]" << std::endl;
- convert << "Radius: " << radius << std::endl;
- convert << "Center: " << std::endl << center;
- return convert.str();
-}
-
-}; // namespace bound
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_TREE_DBALLBOUND_IMPL_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/ballbound_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/ballbound_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,224 @@
+/**
+ * @file ballbound_impl.hpp
+ *
+ * Bounds that are useful for binary space partitioning trees.
+ * Implementation of BallBound ball bound metric policy class.
+ *
+ * @experimental
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
+#define __MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
+
+// In case it hasn't been included already.
+#include "ballbound.hpp"
+
+#include <string>
+
+namespace mlpack {
+namespace bound {
+
+//! Get the range in a certain dimension.
+template<typename VecType>
+math::Range BallBound<VecType>::operator[](const size_t i) const
+{
+ if (radius < 0)
+ return math::Range();
+ else
+ return math::Range(center[i] - radius, center[i] + radius);
+}
+
+/**
+ * Determines if a point is within the bound.
+ */
+template<typename VecType>
+bool BallBound<VecType>::Contains(const VecType& point) const
+{
+ if (radius < 0)
+ return false;
+ else
+ return metric::EuclideanDistance::Evaluate(center, point) <= radius;
+}
+
+/**
+ * Gets the center.
+ *
+ * Don't really use this directly. This is only here for consistency
+ * with DHrectBound, so it can plug in more directly if a "centroid"
+ * is needed.
+ */
+template<typename VecType>
+void BallBound<VecType>::CalculateMidpoint(VecType& centroid) const
+{
+ centroid = center;
+}
+
+/**
+ * Calculates minimum bound-to-point squared distance.
+ */
+template<typename VecType>
+double BallBound<VecType>::MinDistance(const VecType& point) const
+{
+ if (radius < 0)
+ return DBL_MAX;
+ else
+ return math::ClampNonNegative(metric::EuclideanDistance::Evaluate(point,
+ center) - radius);
+}
+
+/**
+ * Calculates minimum bound-to-bound squared distance.
+ */
+template<typename VecType>
+double BallBound<VecType>::MinDistance(const BallBound& other) const
+{
+ if (radius < 0)
+ return DBL_MAX;
+ else
+ {
+ double delta = metric::EuclideanDistance::Evaluate(center, other.center)
+ - radius - other.radius;
+ return math::ClampNonNegative(delta);
+ }
+}
+
+/**
+ * Computes maximum distance.
+ */
+template<typename VecType>
+double BallBound<VecType>::MaxDistance(const VecType& point) const
+{
+ if (radius < 0)
+ return DBL_MAX;
+ else
+ return metric::EuclideanDistance::Evaluate(point, center) + radius;
+}
+
+/**
+ * Computes maximum distance.
+ */
+template<typename VecType>
+double BallBound<VecType>::MaxDistance(const BallBound& other) const
+{
+ if (radius < 0)
+ return DBL_MAX;
+ else
+ return metric::EuclideanDistance::Evaluate(other.center, center) + radius
+ + other.radius;
+}
+
+/**
+ * Calculates minimum and maximum bound-to-bound squared distance.
+ *
+ * Example: bound1.MinDistanceSq(other) for minimum squared distance.
+ */
+template<typename VecType>
+math::Range BallBound<VecType>::RangeDistance(const VecType& point)
+ const
+{
+ if (radius < 0)
+ return math::Range(DBL_MAX, DBL_MAX);
+ else
+ {
+ double dist = metric::EuclideanDistance::Evaluate(center, point);
+ return math::Range(math::ClampNonNegative(dist - radius),
+ dist + radius);
+ }
+}
+
+template<typename VecType>
+math::Range BallBound<VecType>::RangeDistance(
+ const BallBound& other) const
+{
+ if (radius < 0)
+ return math::Range(DBL_MAX, DBL_MAX);
+ else
+ {
+ double dist = metric::EuclideanDistance::Evaluate(center, other.center);
+ double sumradius = radius + other.radius;
+ return math::Range(math::ClampNonNegative(dist - sumradius),
+ dist + sumradius);
+ }
+}
+
+/**
+ * Expand the bound to include the given bound.
+ *
+template<typename VecType>
+const BallBound<VecType>&
+BallBound<VecType>::operator|=(
+ const BallBound<VecType>& other)
+{
+ double dist = metric::EuclideanDistance::Evaluate(center, other);
+
+ // Now expand the radius as necessary.
+ if (dist > radius)
+ radius = dist;
+
+ return *this;
+}*/
+
+/**
+ * Expand the bound to include the given point.
+ */
+template<typename VecType>
+template<typename MatType>
+const BallBound<VecType>&
+BallBound<VecType>::operator|=(const MatType& data)
+{
+ if (radius < 0)
+ {
+ center = data.col(0);
+ radius = 0;
+ }
+
+ // Now iteratively add points. There is probably a closed-form solution to
+ // find the minimum bounding circle, and it is probably faster.
+ for (size_t i = 1; i < data.n_cols; ++i)
+ {
+ double dist = metric::EuclideanDistance::Evaluate(center, (VecType)
+ data.col(i)) - radius;
+
+ if (dist > 0)
+ {
+ // Move (dist / 2) towards the new point and increase radius by
+ // (dist / 2).
+ arma::vec diff = data.col(i) - center;
+ center += 0.5 * diff;
+ radius += 0.5 * dist;
+ }
+ }
+
+ return *this;
+}
+/**
+ * Returns a string representation of this object.
+ */
+template<typename VecType>
+std::string BallBound<VecType>::ToString() const
+{
+ std::ostringstream convert;
+ convert << "BallBound [" << this << "]" << std::endl;
+ convert << "Radius: " << radius << std::endl;
+ convert << "Center: " << std::endl << center;
+ return convert.str();
+}
+
+}; // namespace bound
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_DBALLBOUND_IMPL_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,465 +0,0 @@
-/**
- * @file binary_space_tree.hpp
- *
- * Definition of generalized binary space partitioning tree (BinarySpaceTree).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
-
-#include <mlpack/core.hpp>
-
-#include "../statistic.hpp"
-
-namespace mlpack {
-namespace tree /** Trees and tree-building procedures. */ {
-
-/**
- * A binary space partitioning tree, such as a KD-tree or a ball tree. Once the
- * bound and type of dataset is defined, the tree will construct itself. Call
- * the constructor with the dataset to build the tree on, and the entire tree
- * will be built.
- *
- * This particular tree does not allow growth, so you cannot add or delete nodes
- * from it. If you need to add or delete a node, the better procedure is to
- * rebuild the tree entirely.
- *
- * This tree does take one parameter, which is the leaf size to be used.
- *
- * @tparam BoundType The bound used for each node. The valid types of bounds
- * and the necessary skeleton interface for this class can be found in
- * bounds/.
- * @tparam StatisticType Extra data contained in the node. See statistic.hpp
- * for the necessary skeleton interface.
- */
-template<typename BoundType,
- typename StatisticType = EmptyStatistic,
- typename MatType = arma::mat>
-class BinarySpaceTree
-{
- private:
- //! The left child node.
- BinarySpaceTree* left;
- //! The right child node.
- BinarySpaceTree* right;
- //! The parent node (NULL if this is the root of the tree).
- BinarySpaceTree* parent;
- //! The index of the first point in the dataset contained in this node (and
- //! its children).
- size_t begin;
- //! The number of points of the dataset contained in this node (and its
- //! children).
- size_t count;
- //! The leaf size.
- size_t leafSize;
- //! The bound object for this node.
- BoundType bound;
- //! Any extra data contained in the node.
- StatisticType stat;
- //! The dimension this node split on if it is a parent.
- size_t splitDimension;
- //! The distance to the furthest descendant, cached to speed things up.
- double furthestDescendantDistance;
- //! The dataset.
- MatType& dataset;
-
- public:
- //! So other classes can use TreeType::Mat.
- typedef MatType Mat;
-
- //! A single-tree traverser for binary space trees; see
- //! single_tree_traverser.hpp for implementation.
- template<typename RuleType>
- class SingleTreeTraverser;
-
- //! A dual-tree traverser for binary space trees; see dual_tree_traverser.hpp.
- template<typename RuleType>
- class DualTreeTraverser;
-
- /**
- * Construct this as the root node of a binary space tree using the given
- * dataset. This will modify the ordering of the points in the dataset!
- *
- * @param data Dataset to create tree from. This will be modified!
- * @param leafSize Size of each leaf in the tree.
- */
- BinarySpaceTree(MatType& data, const size_t leafSize = 20);
-
- /**
- * Construct this as the root node of a binary space tree using the given
- * dataset. This will modify the ordering of points in the dataset! A
- * mapping of the old point indices to the new point indices is filled.
- *
- * @param data Dataset to create tree from. This will be modified!
- * @param oldFromNew Vector which will be filled with the old positions for
- * each new point.
- * @param leafSize Size of each leaf in the tree.
- */
- BinarySpaceTree(MatType& data,
- std::vector<size_t>& oldFromNew,
- const size_t leafSize = 20);
-
- /**
- * Construct this as the root node of a binary space tree using the given
- * dataset. This will modify the ordering of points in the dataset! A
- * mapping of the old point indices to the new point indices is filled, as
- * well as a mapping of the new point indices to the old point indices.
- *
- * @param data Dataset to create tree from. This will be modified!
- * @param oldFromNew Vector which will be filled with the old positions for
- * each new point.
- * @param newFromOld Vector which will be filled with the new positions for
- * each old point.
- * @param leafSize Size of each leaf in the tree.
- */
- BinarySpaceTree(MatType& data,
- std::vector<size_t>& oldFromNew,
- std::vector<size_t>& newFromOld,
- const size_t leafSize = 20);
-
- /**
- * Construct this node on a subset of the given matrix, starting at column
- * begin and using count points. The ordering of that subset of points
- * will be modified! This is used for recursive tree-building by the other
- * constructors which don't specify point indices.
- *
- * @param data Dataset to create tree from. This will be modified!
- * @param begin Index of point to start tree construction with.
- * @param count Number of points to use to construct tree.
- * @param leafSize Size of each leaf in the tree.
- */
- BinarySpaceTree(MatType& data,
- const size_t begin,
- const size_t count,
- BinarySpaceTree* parent = NULL,
- const size_t leafSize = 20);
-
- /**
- * Construct this node on a subset of the given matrix, starting at column
- * begin_in and using count_in points. The ordering of that subset of points
- * will be modified! This is used for recursive tree-building by the other
- * constructors which don't specify point indices.
- *
- * A mapping of the old point indices to the new point indices is filled, but
- * it is expected that the vector is already allocated with size greater than
- * or equal to (begin_in + count_in), and if that is not true, invalid memory
- * reads (and writes) will occur.
- *
- * @param data Dataset to create tree from. This will be modified!
- * @param begin Index of point to start tree construction with.
- * @param count Number of points to use to construct tree.
- * @param oldFromNew Vector which will be filled with the old positions for
- * each new point.
- * @param leafSize Size of each leaf in the tree.
- */
- BinarySpaceTree(MatType& data,
- const size_t begin,
- const size_t count,
- std::vector<size_t>& oldFromNew,
- BinarySpaceTree* parent = NULL,
- const size_t leafSize = 20);
-
- /**
- * Construct this node on a subset of the given matrix, starting at column
- * begin_in and using count_in points. The ordering of that subset of points
- * will be modified! This is used for recursive tree-building by the other
- * constructors which don't specify point indices.
- *
- * A mapping of the old point indices to the new point indices is filled, as
- * well as a mapping of the new point indices to the old point indices. It is
- * expected that the vector is already allocated with size greater than or
- * equal to (begin_in + count_in), and if that is not true, invalid memory
- * reads (and writes) will occur.
- *
- * @param data Dataset to create tree from. This will be modified!
- * @param begin Index of point to start tree construction with.
- * @param count Number of points to use to construct tree.
- * @param oldFromNew Vector which will be filled with the old positions for
- * each new point.
- * @param newFromOld Vector which will be filled with the new positions for
- * each old point.
- * @param leafSize Size of each leaf in the tree.
- */
- BinarySpaceTree(MatType& data,
- const size_t begin,
- const size_t count,
- std::vector<size_t>& oldFromNew,
- std::vector<size_t>& newFromOld,
- BinarySpaceTree* parent = NULL,
- const size_t leafSize = 20);
-
- /**
- * Create a binary space tree by copying the other tree. Be careful! This
- * can take a long time and use a lot of memory.
- *
- * @param other Tree to be replicated.
- */
- BinarySpaceTree(const BinarySpaceTree& other);
-
- /**
- * Deletes this node, deallocating the memory for the children and calling
- * their destructors in turn. This will invalidate any pointers or references
- * to any nodes which are children of this one.
- */
- ~BinarySpaceTree();
-
- /**
- * Find a node in this tree by its begin and count (const).
- *
- * Every node is uniquely identified by these two numbers.
- * This is useful for communicating position over the network,
- * when pointers would be invalid.
- *
- * @param begin The begin() of the node to find.
- * @param count The count() of the node to find.
- * @return The found node, or NULL if not found.
- */
- const BinarySpaceTree* FindByBeginCount(size_t begin,
- size_t count) const;
-
- /**
- * Find a node in this tree by its begin and count.
- *
- * Every node is uniquely identified by these two numbers.
- * This is useful for communicating position over the network,
- * when pointers would be invalid.
- *
- * @param begin The begin() of the node to find.
- * @param count The count() of the node to find.
- * @return The found node, or NULL if not found.
- */
- BinarySpaceTree* FindByBeginCount(size_t begin, size_t count);
-
- //! Return the bound object for this node.
- const BoundType& Bound() const { return bound; }
- //! Return the bound object for this node.
- BoundType& Bound() { return bound; }
-
- //! Return the statistic object for this node.
- const StatisticType& Stat() const { return stat; }
- //! Return the statistic object for this node.
- StatisticType& Stat() { return stat; }
-
- //! Return whether or not this node is a leaf (true if it has no children).
- bool IsLeaf() const;
-
- //! Return the leaf size.
- size_t LeafSize() const { return leafSize; }
- //! Modify the leaf size.
- size_t& LeafSize() { return leafSize; }
-
- //! Fills the tree to the specified level.
- size_t ExtendTree(const size_t level);
-
- //! Gets the left child of this node.
- BinarySpaceTree* Left() const { return left; }
- //! Modify the left child of this node.
- BinarySpaceTree*& Left() { return left; }
-
- //! Gets the right child of this node.
- BinarySpaceTree* Right() const { return right; }
- //! Modify the right child of this node.
- BinarySpaceTree*& Right() { return right; }
-
- //! Gets the parent of this node.
- BinarySpaceTree* Parent() const { return parent; }
- //! Modify the parent of this node.
- BinarySpaceTree*& Parent() { return parent; }
-
- //! Get the split dimension for this node.
- size_t SplitDimension() const { return splitDimension; }
- //! Modify the split dimension for this node.
- size_t& SplitDimension() { return splitDimension; }
-
- //! Get the dataset which the tree is built on.
- const arma::mat& Dataset() const { return dataset; }
- //! Modify the dataset which the tree is built on. Be careful!
- arma::mat& Dataset() { return dataset; }
-
- //! Get the metric which the tree uses.
- typename BoundType::MetricType Metric() const { return bound.Metric(); }
-
- //! Get the centroid of the node and store it in the given vector.
- void Centroid(arma::vec& centroid) { bound.Centroid(centroid); }
-
- //! Return the number of children in this node.
- size_t NumChildren() const;
-
- /**
- * Return the furthest possible descendant distance. This returns the maximum
- * distance from the centroid to the edge of the bound and not the empirical
- * quantity which is the actual furthest descendant distance. So the actual
- * furthest descendant distance may be less than what this method returns (but
- * it will never be greater than this).
- */
- double FurthestDescendantDistance() const;
-
- /**
- * Return the specified child (0 will be left, 1 will be right). If the index
- * is greater than 1, this will return the right child.
- *
- * @param child Index of child to return.
- */
- BinarySpaceTree& Child(const size_t child) const;
-
- //! Return the number of points in this node (0 if not a leaf).
- size_t NumPoints() const;
-
- /**
- * Return the index (with reference to the dataset) of a particular point in
- * this node. This will happily return invalid indices if the given index is
- * greater than the number of points in this node (obtained with NumPoints())
- * -- be careful.
- *
- * @param index Index of point for which a dataset index is wanted.
- */
- size_t Point(const size_t index) const;
-
- //! Return the minimum distance to another node.
- double MinDistance(const BinarySpaceTree* other) const
- {
- return bound.MinDistance(other->Bound());
- }
-
- //! Return the maximum distance to another node.
- double MaxDistance(const BinarySpaceTree* other) const
- {
- return bound.MaxDistance(other->Bound());
- }
-
- //! Return the minimum distance to another point.
- double MinDistance(const arma::vec& point) const
- {
- return bound.MinDistance(point);
- }
-
- //! Return the maximum distance to another point.
- double MaxDistance(const arma::vec& point) const
- {
- return bound.MaxDistance(point);
- }
-
- /**
- * Returns the dimension this parent's children are split on.
- */
- size_t GetSplitDimension() const;
-
- /**
- * Obtains the number of nodes in the tree, starting with this.
- */
- size_t TreeSize() const;
-
- /**
- * Obtains the number of levels below this node in the tree, starting with
- * this.
- */
- size_t TreeDepth() const;
-
- //! Return the index of the beginning point of this subset.
- size_t Begin() const { return begin; }
- //! Modify the index of the beginning point of this subset.
- size_t& Begin() { return begin; }
-
- /**
- * Gets the index one beyond the last index in the subset.
- */
- size_t End() const;
-
- //! Return the number of points in this subset.
- size_t Count() const { return count; }
- //! Modify the number of points in this subset.
- size_t& Count() { return count; }
-
- //! Returns false: this tree type does not have self children.
- static bool HasSelfChildren() { return false; }
-
- private:
- /**
- * Private copy constructor, available only to fill (pad) the tree to a
- * specified level.
- */
- BinarySpaceTree(const size_t begin,
- const size_t count,
- BoundType bound,
- StatisticType stat,
- const int leafSize = 20) :
- left(NULL),
- right(NULL),
- begin(begin),
- count(count),
- bound(bound),
- stat(stat),
- leafSize(leafSize) { }
-
- BinarySpaceTree* CopyMe()
- {
- return new BinarySpaceTree(begin, count, bound, stat, leafSize);
- }
-
- /**
- * Splits the current node, assigning its left and right children recursively.
- *
- * @param data Dataset which we are using.
- */
- void SplitNode(MatType& data);
-
- /**
- * Splits the current node, assigning its left and right children recursively.
- * Also returns a list of the changed indices.
- *
- * @param data Dataset which we are using.
- * @param oldFromNew Vector holding permuted indices.
- */
- void SplitNode(MatType& data, std::vector<size_t>& oldFromNew);
-
- /**
- * Find the index to split on for this node, given that we are splitting in
- * the given split dimension on the specified split value.
- *
- * @param data Dataset which we are using.
- * @param splitDim Dimension of dataset to split on.
- * @param splitVal Value to split on, in the given split dimension.
- */
- size_t GetSplitIndex(MatType& data, int splitDim, double splitVal);
-
- /**
- * Find the index to split on for this node, given that we are splitting in
- * the given split dimension on the specified split value. Also returns a
- * list of the changed indices.
- *
- * @param data Dataset which we are using.
- * @param splitDim Dimension of dataset to split on.
- * @param splitVal Value to split on, in the given split dimension.
- * @param oldFromNew Vector holding permuted indices.
- */
- size_t GetSplitIndex(MatType& data, int splitDim, double splitVal,
- std::vector<size_t>& oldFromNew);
- public:
- /**
- * Returns a string representation of this object.
- */
- std::string ToString() const;
-
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-// Include implementation.
-#include "binary_space_tree_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,465 @@
+/**
+ * @file binary_space_tree.hpp
+ *
+ * Definition of generalized binary space partitioning tree (BinarySpaceTree).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
+
+#include <mlpack/core.hpp>
+
+#include "../statistic.hpp"
+
+namespace mlpack {
+namespace tree /** Trees and tree-building procedures. */ {
+
+/**
+ * A binary space partitioning tree, such as a KD-tree or a ball tree. Once the
+ * bound and type of dataset is defined, the tree will construct itself. Call
+ * the constructor with the dataset to build the tree on, and the entire tree
+ * will be built.
+ *
+ * This particular tree does not allow growth, so you cannot add or delete nodes
+ * from it. If you need to add or delete a node, the better procedure is to
+ * rebuild the tree entirely.
+ *
+ * This tree does take one parameter, which is the leaf size to be used.
+ *
+ * @tparam BoundType The bound used for each node. The valid types of bounds
+ * and the necessary skeleton interface for this class can be found in
+ * bounds/.
+ * @tparam StatisticType Extra data contained in the node. See statistic.hpp
+ * for the necessary skeleton interface.
+ */
+template<typename BoundType,
+ typename StatisticType = EmptyStatistic,
+ typename MatType = arma::mat>
+class BinarySpaceTree
+{
+ private:
+ //! The left child node.
+ BinarySpaceTree* left;
+ //! The right child node.
+ BinarySpaceTree* right;
+ //! The parent node (NULL if this is the root of the tree).
+ BinarySpaceTree* parent;
+ //! The index of the first point in the dataset contained in this node (and
+ //! its children).
+ size_t begin;
+ //! The number of points of the dataset contained in this node (and its
+ //! children).
+ size_t count;
+ //! The leaf size.
+ size_t leafSize;
+ //! The bound object for this node.
+ BoundType bound;
+ //! Any extra data contained in the node.
+ StatisticType stat;
+ //! The dimension this node split on if it is a parent.
+ size_t splitDimension;
+ //! The distance to the furthest descendant, cached to speed things up.
+ double furthestDescendantDistance;
+ //! The dataset.
+ MatType& dataset;
+
+ public:
+ //! So other classes can use TreeType::Mat.
+ typedef MatType Mat;
+
+ //! A single-tree traverser for binary space trees; see
+ //! single_tree_traverser.hpp for implementation.
+ template<typename RuleType>
+ class SingleTreeTraverser;
+
+ //! A dual-tree traverser for binary space trees; see dual_tree_traverser.hpp.
+ template<typename RuleType>
+ class DualTreeTraverser;
+
+ /**
+ * Construct this as the root node of a binary space tree using the given
+ * dataset. This will modify the ordering of the points in the dataset!
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param leafSize Size of each leaf in the tree.
+ */
+ BinarySpaceTree(MatType& data, const size_t leafSize = 20);
+
+ /**
+ * Construct this as the root node of a binary space tree using the given
+ * dataset. This will modify the ordering of points in the dataset! A
+ * mapping of the old point indices to the new point indices is filled.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ * @param leafSize Size of each leaf in the tree.
+ */
+ BinarySpaceTree(MatType& data,
+ std::vector<size_t>& oldFromNew,
+ const size_t leafSize = 20);
+
+ /**
+ * Construct this as the root node of a binary space tree using the given
+ * dataset. This will modify the ordering of points in the dataset! A
+ * mapping of the old point indices to the new point indices is filled, as
+ * well as a mapping of the new point indices to the old point indices.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ * @param newFromOld Vector which will be filled with the new positions for
+ * each old point.
+ * @param leafSize Size of each leaf in the tree.
+ */
+ BinarySpaceTree(MatType& data,
+ std::vector<size_t>& oldFromNew,
+ std::vector<size_t>& newFromOld,
+ const size_t leafSize = 20);
+
+ /**
+ * Construct this node on a subset of the given matrix, starting at column
+ * begin and using count points. The ordering of that subset of points
+ * will be modified! This is used for recursive tree-building by the other
+ * constructors which don't specify point indices.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param begin Index of point to start tree construction with.
+ * @param count Number of points to use to construct tree.
+ * @param leafSize Size of each leaf in the tree.
+ */
+ BinarySpaceTree(MatType& data,
+ const size_t begin,
+ const size_t count,
+ BinarySpaceTree* parent = NULL,
+ const size_t leafSize = 20);
+
+ /**
+ * Construct this node on a subset of the given matrix, starting at column
+ * begin_in and using count_in points. The ordering of that subset of points
+ * will be modified! This is used for recursive tree-building by the other
+ * constructors which don't specify point indices.
+ *
+ * A mapping of the old point indices to the new point indices is filled, but
+ * it is expected that the vector is already allocated with size greater than
+ * or equal to (begin_in + count_in), and if that is not true, invalid memory
+ * reads (and writes) will occur.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param begin Index of point to start tree construction with.
+ * @param count Number of points to use to construct tree.
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ * @param leafSize Size of each leaf in the tree.
+ */
+ BinarySpaceTree(MatType& data,
+ const size_t begin,
+ const size_t count,
+ std::vector<size_t>& oldFromNew,
+ BinarySpaceTree* parent = NULL,
+ const size_t leafSize = 20);
+
+ /**
+ * Construct this node on a subset of the given matrix, starting at column
+ * begin_in and using count_in points. The ordering of that subset of points
+ * will be modified! This is used for recursive tree-building by the other
+ * constructors which don't specify point indices.
+ *
+ * A mapping of the old point indices to the new point indices is filled, as
+ * well as a mapping of the new point indices to the old point indices. It is
+ * expected that the vector is already allocated with size greater than or
+ * equal to (begin_in + count_in), and if that is not true, invalid memory
+ * reads (and writes) will occur.
+ *
+ * @param data Dataset to create tree from. This will be modified!
+ * @param begin Index of point to start tree construction with.
+ * @param count Number of points to use to construct tree.
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ * @param newFromOld Vector which will be filled with the new positions for
+ * each old point.
+ * @param leafSize Size of each leaf in the tree.
+ */
+ BinarySpaceTree(MatType& data,
+ const size_t begin,
+ const size_t count,
+ std::vector<size_t>& oldFromNew,
+ std::vector<size_t>& newFromOld,
+ BinarySpaceTree* parent = NULL,
+ const size_t leafSize = 20);
+
+ /**
+ * Create a binary space tree by copying the other tree. Be careful! This
+ * can take a long time and use a lot of memory.
+ *
+ * @param other Tree to be replicated.
+ */
+ BinarySpaceTree(const BinarySpaceTree& other);
+
+ /**
+ * Deletes this node, deallocating the memory for the children and calling
+ * their destructors in turn. This will invalidate any pointers or references
+ * to any nodes which are children of this one.
+ */
+ ~BinarySpaceTree();
+
+ /**
+ * Find a node in this tree by its begin and count (const).
+ *
+ * Every node is uniquely identified by these two numbers.
+ * This is useful for communicating position over the network,
+ * when pointers would be invalid.
+ *
+ * @param begin The begin() of the node to find.
+ * @param count The count() of the node to find.
+ * @return The found node, or NULL if not found.
+ */
+ const BinarySpaceTree* FindByBeginCount(size_t begin,
+ size_t count) const;
+
+ /**
+ * Find a node in this tree by its begin and count.
+ *
+ * Every node is uniquely identified by these two numbers.
+ * This is useful for communicating position over the network,
+ * when pointers would be invalid.
+ *
+ * @param begin The begin() of the node to find.
+ * @param count The count() of the node to find.
+ * @return The found node, or NULL if not found.
+ */
+ BinarySpaceTree* FindByBeginCount(size_t begin, size_t count);
+
+ //! Return the bound object for this node.
+ const BoundType& Bound() const { return bound; }
+ //! Return the bound object for this node.
+ BoundType& Bound() { return bound; }
+
+ //! Return the statistic object for this node.
+ const StatisticType& Stat() const { return stat; }
+ //! Return the statistic object for this node.
+ StatisticType& Stat() { return stat; }
+
+ //! Return whether or not this node is a leaf (true if it has no children).
+ bool IsLeaf() const;
+
+ //! Return the leaf size.
+ size_t LeafSize() const { return leafSize; }
+ //! Modify the leaf size.
+ size_t& LeafSize() { return leafSize; }
+
+ //! Fills the tree to the specified level.
+ size_t ExtendTree(const size_t level);
+
+ //! Gets the left child of this node.
+ BinarySpaceTree* Left() const { return left; }
+ //! Modify the left child of this node.
+ BinarySpaceTree*& Left() { return left; }
+
+ //! Gets the right child of this node.
+ BinarySpaceTree* Right() const { return right; }
+ //! Modify the right child of this node.
+ BinarySpaceTree*& Right() { return right; }
+
+ //! Gets the parent of this node.
+ BinarySpaceTree* Parent() const { return parent; }
+ //! Modify the parent of this node.
+ BinarySpaceTree*& Parent() { return parent; }
+
+ //! Get the split dimension for this node.
+ size_t SplitDimension() const { return splitDimension; }
+ //! Modify the split dimension for this node.
+ size_t& SplitDimension() { return splitDimension; }
+
+ //! Get the dataset which the tree is built on.
+ const arma::mat& Dataset() const { return dataset; }
+ //! Modify the dataset which the tree is built on. Be careful!
+ arma::mat& Dataset() { return dataset; }
+
+ //! Get the metric which the tree uses.
+ typename BoundType::MetricType Metric() const { return bound.Metric(); }
+
+ //! Get the centroid of the node and store it in the given vector.
+ void Centroid(arma::vec& centroid) { bound.Centroid(centroid); }
+
+ //! Return the number of children in this node.
+ size_t NumChildren() const;
+
+ /**
+ * Return the furthest possible descendant distance. This returns the maximum
+ * distance from the centroid to the edge of the bound and not the empirical
+ * quantity which is the actual furthest descendant distance. So the actual
+ * furthest descendant distance may be less than what this method returns (but
+ * it will never be greater than this).
+ */
+ double FurthestDescendantDistance() const;
+
+ /**
+ * Return the specified child (0 will be left, 1 will be right). If the index
+ * is greater than 1, this will return the right child.
+ *
+ * @param child Index of child to return.
+ */
+ BinarySpaceTree& Child(const size_t child) const;
+
+ //! Return the number of points in this node (0 if not a leaf).
+ size_t NumPoints() const;
+
+ /**
+ * Return the index (with reference to the dataset) of a particular point in
+ * this node. This will happily return invalid indices if the given index is
+ * greater than the number of points in this node (obtained with NumPoints())
+ * -- be careful.
+ *
+ * @param index Index of point for which a dataset index is wanted.
+ */
+ size_t Point(const size_t index) const;
+
+ //! Return the minimum distance to another node.
+ double MinDistance(const BinarySpaceTree* other) const
+ {
+ return bound.MinDistance(other->Bound());
+ }
+
+ //! Return the maximum distance to another node.
+ double MaxDistance(const BinarySpaceTree* other) const
+ {
+ return bound.MaxDistance(other->Bound());
+ }
+
+ //! Return the minimum distance to another point.
+ double MinDistance(const arma::vec& point) const
+ {
+ return bound.MinDistance(point);
+ }
+
+ //! Return the maximum distance to another point.
+ double MaxDistance(const arma::vec& point) const
+ {
+ return bound.MaxDistance(point);
+ }
+
+ /**
+ * Returns the dimension this parent's children are split on.
+ */
+ size_t GetSplitDimension() const;
+
+ /**
+ * Obtains the number of nodes in the tree, starting with this.
+ */
+ size_t TreeSize() const;
+
+ /**
+ * Obtains the number of levels below this node in the tree, starting with
+ * this.
+ */
+ size_t TreeDepth() const;
+
+ //! Return the index of the beginning point of this subset.
+ size_t Begin() const { return begin; }
+ //! Modify the index of the beginning point of this subset.
+ size_t& Begin() { return begin; }
+
+ /**
+ * Gets the index one beyond the last index in the subset.
+ */
+ size_t End() const;
+
+ //! Return the number of points in this subset.
+ size_t Count() const { return count; }
+ //! Modify the number of points in this subset.
+ size_t& Count() { return count; }
+
+ //! Returns false: this tree type does not have self children.
+ static bool HasSelfChildren() { return false; }
+
+ private:
+ /**
+ * Private copy constructor, available only to fill (pad) the tree to a
+ * specified level.
+ */
+ BinarySpaceTree(const size_t begin,
+ const size_t count,
+ BoundType bound,
+ StatisticType stat,
+ const int leafSize = 20) :
+ left(NULL),
+ right(NULL),
+ begin(begin),
+ count(count),
+ bound(bound),
+ stat(stat),
+ leafSize(leafSize) { }
+
+ BinarySpaceTree* CopyMe()
+ {
+ return new BinarySpaceTree(begin, count, bound, stat, leafSize);
+ }
+
+ /**
+ * Splits the current node, assigning its left and right children recursively.
+ *
+ * @param data Dataset which we are using.
+ */
+ void SplitNode(MatType& data);
+
+ /**
+ * Splits the current node, assigning its left and right children recursively.
+ * Also returns a list of the changed indices.
+ *
+ * @param data Dataset which we are using.
+ * @param oldFromNew Vector holding permuted indices.
+ */
+ void SplitNode(MatType& data, std::vector<size_t>& oldFromNew);
+
+ /**
+ * Find the index to split on for this node, given that we are splitting in
+ * the given split dimension on the specified split value.
+ *
+ * @param data Dataset which we are using.
+ * @param splitDim Dimension of dataset to split on.
+ * @param splitVal Value to split on, in the given split dimension.
+ */
+ size_t GetSplitIndex(MatType& data, int splitDim, double splitVal);
+
+ /**
+ * Find the index to split on for this node, given that we are splitting in
+ * the given split dimension on the specified split value. Also returns a
+ * list of the changed indices.
+ *
+ * @param data Dataset which we are using.
+ * @param splitDim Dimension of dataset to split on.
+ * @param splitVal Value to split on, in the given split dimension.
+ * @param oldFromNew Vector holding permuted indices.
+ */
+ size_t GetSplitIndex(MatType& data, int splitDim, double splitVal,
+ std::vector<size_t>& oldFromNew);
+ public:
+ /**
+ * Returns a string representation of this object.
+ */
+ std::string ToString() const;
+
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "binary_space_tree_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,681 +0,0 @@
-/**
- * @file binary_space_tree_impl.hpp
- *
- * Implementation of generalized space partitioning tree.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
-
-// In case it wasn't included already for some reason.
-#include "binary_space_tree.hpp"
-
-#include <mlpack/core/util/cli.hpp>
-#include <mlpack/core/util/log.hpp>
-#include <mlpack/core/util/string_util.hpp>
-
-namespace mlpack {
-namespace tree {
-
-// Each of these overloads is kept as a separate function to keep the overhead
-// from the two std::vectors out, if possible.
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
- MatType& data,
- const size_t leafSize) :
- left(NULL),
- right(NULL),
- parent(NULL),
- begin(0), /* This root node starts at index 0, */
- count(data.n_cols), /* and spans all of the dataset. */
- leafSize(leafSize),
- bound(data.n_rows),
- dataset(data)
-{
- // Do the actual splitting of this node.
- SplitNode(data);
-
- // Create the statistic depending on if we are a leaf or not.
- stat = StatisticType(*this);
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
- MatType& data,
- std::vector<size_t>& oldFromNew,
- const size_t leafSize) :
- left(NULL),
- right(NULL),
- parent(NULL),
- begin(0),
- count(data.n_cols),
- leafSize(leafSize),
- bound(data.n_rows),
- dataset(data)
-{
- // Initialize oldFromNew correctly.
- oldFromNew.resize(data.n_cols);
- for (size_t i = 0; i < data.n_cols; i++)
- oldFromNew[i] = i; // Fill with unharmed indices.
-
- // Now do the actual splitting.
- SplitNode(data, oldFromNew);
-
- // Create the statistic depending on if we are a leaf or not.
- stat = StatisticType(*this);
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
- MatType& data,
- std::vector<size_t>& oldFromNew,
- std::vector<size_t>& newFromOld,
- const size_t leafSize) :
- left(NULL),
- right(NULL),
- parent(NULL),
- begin(0),
- count(data.n_cols),
- leafSize(leafSize),
- bound(data.n_rows),
- dataset(data)
-{
- // Initialize the oldFromNew vector correctly.
- oldFromNew.resize(data.n_cols);
- for (size_t i = 0; i < data.n_cols; i++)
- oldFromNew[i] = i; // Fill with unharmed indices.
-
- // Now do the actual splitting.
- SplitNode(data, oldFromNew);
-
- // Create the statistic depending on if we are a leaf or not.
- stat = StatisticType(*this);
-
- // Map the newFromOld indices correctly.
- newFromOld.resize(data.n_cols);
- for (size_t i = 0; i < data.n_cols; i++)
- newFromOld[oldFromNew[i]] = i;
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
- MatType& data,
- const size_t begin,
- const size_t count,
- BinarySpaceTree* parent,
- const size_t leafSize) :
- left(NULL),
- right(NULL),
- parent(parent),
- begin(begin),
- count(count),
- leafSize(leafSize),
- bound(data.n_rows),
- dataset(data)
-{
- // Perform the actual splitting.
- SplitNode(data);
-
- // Create the statistic depending on if we are a leaf or not.
- stat = StatisticType(*this);
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
- MatType& data,
- const size_t begin,
- const size_t count,
- std::vector<size_t>& oldFromNew,
- BinarySpaceTree* parent,
- const size_t leafSize) :
- left(NULL),
- right(NULL),
- parent(parent),
- begin(begin),
- count(count),
- leafSize(leafSize),
- bound(data.n_rows),
- dataset(data)
-{
- // Hopefully the vector is initialized correctly! We can't check that
- // entirely but we can do a minor sanity check.
- assert(oldFromNew.size() == data.n_cols);
-
- // Perform the actual splitting.
- SplitNode(data, oldFromNew);
-
- // Create the statistic depending on if we are a leaf or not.
- stat = StatisticType(*this);
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
- MatType& data,
- const size_t begin,
- const size_t count,
- std::vector<size_t>& oldFromNew,
- std::vector<size_t>& newFromOld,
- BinarySpaceTree* parent,
- const size_t leafSize) :
- left(NULL),
- right(NULL),
- parent(parent),
- begin(begin),
- count(count),
- leafSize(leafSize),
- bound(data.n_rows),
- dataset(data)
-{
- // Hopefully the vector is initialized correctly! We can't check that
- // entirely but we can do a minor sanity check.
- Log::Assert(oldFromNew.size() == data.n_cols);
-
- // Perform the actual splitting.
- SplitNode(data, oldFromNew);
-
- // Create the statistic depending on if we are a leaf or not.
- stat = StatisticType(*this);
-
- // Map the newFromOld indices correctly.
- newFromOld.resize(data.n_cols);
- for (size_t i = 0; i < data.n_cols; i++)
- newFromOld[oldFromNew[i]] = i;
-}
-
-/*
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree() :
- left(NULL),
- right(NULL),
- parent(NULL),
- begin(0),
- count(0),
- bound(),
- stat(),
- leafSize(20) // Default leaf size is 20.
-{
- // Nothing to do.
-}*/
-
-/**
- * Create a binary space tree by copying the other tree. Be careful! This can
- * take a long time and use a lot of memory.
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
- const BinarySpaceTree& other) :
- left(NULL),
- right(NULL),
- parent(other.parent),
- begin(other.begin),
- count(other.count),
- leafSize(other.leafSize),
- bound(other.bound),
- stat(other.stat),
- splitDimension(other.splitDimension),
- furthestDescendantDistance(other.furthestDescendantDistance),
- dataset(other.dataset)
-{
- // Create left and right children (if any).
- if (other.Left())
- {
- left = new BinarySpaceTree(*other.Left());
- left->Parent() = this; // Set parent to this, not other tree.
- }
-
- if (other.Right())
- {
- right = new BinarySpaceTree(*other.Right());
- right->Parent() = this; // Set parent to this, not other tree.
- }
-}
-
-/**
- * Deletes this node, deallocating the memory for the children and calling their
- * destructors in turn. This will invalidate any pointers or references to any
- * nodes which are children of this one.
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::~BinarySpaceTree()
-{
- if (left)
- delete left;
- if (right)
- delete right;
-}
-
-/**
- * Find a node in this tree by its begin and count.
- *
- * Every node is uniquely identified by these two numbers.
- * This is useful for communicating position over the network,
- * when pointers would be invalid.
- *
- * @param queryBegin The Begin() of the node to find.
- * @param queryCount The Count() of the node to find.
- * @return The found node, or NULL if nothing is found.
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-const BinarySpaceTree<BoundType, StatisticType, MatType>*
-BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
- size_t queryBegin,
- size_t queryCount) const
-{
- Log::Assert(queryBegin >= begin);
- Log::Assert(queryCount <= count);
-
- if (begin == queryBegin && count == queryCount)
- return this;
- else if (IsLeaf())
- return NULL;
- else if (queryBegin < right->Begin())
- return left->FindByBeginCount(queryBegin, queryCount);
- else
- return right->FindByBeginCount(queryBegin, queryCount);
-}
-
-/**
- * Find a node in this tree by its begin and count (const).
- *
- * Every node is uniquely identified by these two numbers.
- * This is useful for communicating position over the network,
- * when pointers would be invalid.
- *
- * @param queryBegin the Begin() of the node to find
- * @param queryCount the Count() of the node to find
- * @return the found node, or NULL
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>*
-BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
- const size_t queryBegin,
- const size_t queryCount)
-{
- mlpack::Log::Assert(begin >= queryBegin);
- mlpack::Log::Assert(count <= queryCount);
-
- if (begin == queryBegin && count == queryCount)
- return this;
- else if (IsLeaf())
- return NULL;
- else if (queryBegin < left->End())
- return left->FindByBeginCount(queryBegin, queryCount);
- else if (right)
- return right->FindByBeginCount(queryBegin, queryCount);
- else
- return NULL;
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::ExtendTree(
- size_t level)
-{
- --level;
- // Return the number of nodes duplicated.
- size_t nodesDuplicated = 0;
- if (level > 0)
- {
- if (!left)
- {
- left = CopyMe();
- ++nodesDuplicated;
- }
- nodesDuplicated += left->ExtendTree(level);
- if (right)
- {
- nodesDuplicated += right->ExtendTree(level);
- }
- }
- return nodesDuplicated;
-}
-
-/* TODO: we can likely calculate this earlier, then store the
- * result in a private member variable; for now, we can
- * just calculate as needed...
- *
- * Also, perhaps we should rewrite these recursive functions
- * to avoid exceeding the stack limit
- */
-
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeSize() const
-{
- // Recursively count the nodes on each side of the tree. The plus one is
- // because we have to count this node, too.
- return 1 + (left ? left->TreeSize() : 0) + (right ? right->TreeSize() : 0);
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeDepth() const
-{
- // Recursively count the depth on each side of the tree. The plus one is
- // because we have to count this node, too.
- return 1 + std::max((left ? left->TreeDepth() : 0),
- (right ? right->TreeDepth() : 0));
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-inline bool BinarySpaceTree<BoundType, StatisticType, MatType>::IsLeaf() const
-{
- return !left;
-}
-
-/**
- * Returns the number of children in this node.
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
- BinarySpaceTree<BoundType, StatisticType, MatType>::NumChildren() const
-{
- if (left && right)
- return 2;
- if (left)
- return 1;
-
- return 0;
-}
-
-/**
- * Return the furthest possible descendant distance. This returns the maximum
- * distance from the centroid to the edge of the bound and not the empirical
- * quantity which is the actual furthest descendant distance. So the actual
- * furthest descendant distance may be less than what this method returns (but
- * it will never be greater than this).
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
- FurthestDescendantDistance() const
-{
- return furthestDescendantDistance;
-}
-
-/**
- * Return the specified child.
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline BinarySpaceTree<BoundType, StatisticType, MatType>&
- BinarySpaceTree<BoundType, StatisticType, MatType>::Child(
- const size_t child) const
-{
- if (child == 0)
- return *left;
- else
- return *right;
-}
-
-/**
- * Return the number of points contained in this node.
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::NumPoints() const
-{
- if (left)
- return 0;
-
- return count;
-}
-
-/**
- * Return the index of a particular point contained in this node.
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::Point(const size_t index)
- const
-{
- return (begin + index);
-}
-
-/**
- * Gets the index one beyond the last index in the series.
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::End() const
-{
- return begin + count;
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-void
- BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(MatType& data)
-{
- // We need to expand the bounds of this node properly.
- bound |= data.cols(begin, begin + count - 1);
-
- // Calculate the furthest descendant distance.
- furthestDescendantDistance = 0.5 * bound.Diameter();
-
- // Now, check if we need to split at all.
- if (count <= leafSize)
- return; // We can't split this.
-
- // Figure out which dimension to split on.
- size_t splitDim = data.n_rows; // Indicate invalid by maxDim + 1.
- double maxWidth = -1;
-
- // Find the split dimension.
- for (size_t d = 0; d < data.n_rows; d++)
- {
- double width = bound[d].Width();
-
- if (width > maxWidth)
- {
- maxWidth = width;
- splitDim = d;
- }
- }
- splitDimension = splitDim;
-
- // Split in the middle of that dimension.
- double splitVal = bound[splitDim].Mid();
-
- if (maxWidth == 0) // All these points are the same. We can't split.
- return;
-
- // Perform the actual splitting. This will order the dataset such that points
- // with value in dimension split_dim less than or equal to splitVal are on
- // the left of splitCol, and points with value in dimension splitDim greater
- // than splitVal are on the right side of splitCol.
- size_t splitCol = GetSplitIndex(data, splitDim, splitVal);
-
- // Now that we know the split column, we will recursively split the children
- // by calling their constructors (which perform this splitting process).
- left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
- splitCol - begin, this, leafSize);
- right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
- begin + count - splitCol, this, leafSize);
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
- MatType& data,
- std::vector<size_t>& oldFromNew)
-{
- // This should be a single function for Bound.
- // We need to expand the bounds of this node properly.
- bound |= data.cols(begin, begin + count - 1);
-
- // Calculate the furthest descendant distance.
- furthestDescendantDistance = 0.5 * bound.Diameter();
-
- // First, check if we need to split at all.
- if (count <= leafSize)
- return; // We can't split this.
-
- // Figure out which dimension to split on.
- size_t splitDim = data.n_rows; // Indicate invalid by max_dim + 1.
- double maxWidth = -1;
-
- // Find the split dimension.
- for (size_t d = 0; d < data.n_rows; d++)
- {
- double width = bound[d].Width();
-
- if (width > maxWidth)
- {
- maxWidth = width;
- splitDim = d;
- }
- }
- splitDimension = splitDim;
-
- // Split in the middle of that dimension.
- double splitVal = bound[splitDim].Mid();
-
- if (maxWidth == 0) // All these points are the same. We can't split.
- return;
-
- // Perform the actual splitting. This will order the dataset such that points
- // with value in dimension split_dim less than or equal to splitVal are on
- // the left of splitCol, and points with value in dimension splitDim greater
- // than splitVal are on the right side of splitCol.
- size_t splitCol = GetSplitIndex(data, splitDim, splitVal, oldFromNew);
-
- // Now that we know the split column, we will recursively split the children
- // by calling their constructors (which perform this splitting process).
- left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
- splitCol - begin, oldFromNew, this, leafSize);
- right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
- begin + count - splitCol, oldFromNew, this, leafSize);
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
- MatType& data,
- int splitDim,
- double splitVal)
-{
- // This method modifies the input dataset. We loop both from the left and
- // right sides of the points contained in this node. The points less than
- // split_val should be on the left side of the matrix, and the points greater
- // than split_val should be on the right side of the matrix.
- size_t left = begin;
- size_t right = begin + count - 1;
-
- // First half-iteration of the loop is out here because the termination
- // condition is in the middle.
- while ((data(splitDim, left) < splitVal) && (left <= right))
- left++;
- while ((data(splitDim, right) >= splitVal) && (left <= right))
- right--;
-
- while (left <= right)
- {
- // Swap columns.
- data.swap_cols(left, right);
-
- // See how many points on the left are correct. When they are correct,
- // increase the left counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it later.
- while ((data(splitDim, left) < splitVal) && (left <= right))
- left++;
-
- // Now see how many points on the right are correct. When they are correct,
- // decrease the right counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it with the wrong point we found in the
- // previous loop.
- while ((data(splitDim, right) >= splitVal) && (left <= right))
- right--;
- }
-
- Log::Assert(left == right + 1);
-
- return left;
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
- MatType& data,
- int splitDim,
- double splitVal,
- std::vector<size_t>& oldFromNew)
-{
- // This method modifies the input dataset. We loop both from the left and
- // right sides of the points contained in this node. The points less than
- // split_val should be on the left side of the matrix, and the points greater
- // than split_val should be on the right side of the matrix.
- size_t left = begin;
- size_t right = begin + count - 1;
-
- // First half-iteration of the loop is out here because the termination
- // condition is in the middle.
- while ((data(splitDim, left) < splitVal) && (left <= right))
- left++;
- while ((data(splitDim, right) >= splitVal) && (left <= right))
- right--;
-
- while (left <= right)
- {
- // Swap columns.
- data.swap_cols(left, right);
-
- // Update the indices for what we changed.
- size_t t = oldFromNew[left];
- oldFromNew[left] = oldFromNew[right];
- oldFromNew[right] = t;
-
- // See how many points on the left are correct. When they are correct,
- // increase the left counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it later.
- while ((data(splitDim, left) < splitVal) && (left <= right))
- left++;
-
- // Now see how many points on the right are correct. When they are correct,
- // decrease the right counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it with the wrong point we found in the
- // previous loop.
- while ((data(splitDim, right) >= splitVal) && (left <= right))
- right--;
- }
-
- Log::Assert(left == right + 1);
-
- return left;
-}
-
-/**
- * Returns a string representation of this object.
- */
-template<typename BoundType, typename StatisticType, typename MatType>
-std::string BinarySpaceTree<BoundType, StatisticType, MatType>::ToString() const
-{
- std::ostringstream convert;
- convert << "BinarySpaceTree [" << this << "]" << std::endl;
- convert << "begin: " << begin << std::endl;
- convert << "count: " << count << std::endl;
- convert << "bound: " << mlpack::util::Indent(bound.ToString());
- convert << "statistic: " << stat.ToString();
- convert << "leaf size: " << leafSize << std::endl;
- convert << "splitDimension: " << splitDimension << std::endl;
- if (left != NULL)
- {
- convert << "left:" << std::endl;
- convert << mlpack::util::Indent(left->ToString());
- }
- if (right != NULL)
- {
- convert << "right:" << std::endl;
- convert << mlpack::util::Indent(right->ToString());
- }
- return convert.str();
-}
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,681 @@
+/**
+ * @file binary_space_tree_impl.hpp
+ *
+ * Implementation of generalized space partitioning tree.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
+
+// In case it wasn't included already for some reason.
+#include "binary_space_tree.hpp"
+
+#include <mlpack/core/util/cli.hpp>
+#include <mlpack/core/util/log.hpp>
+#include <mlpack/core/util/string_util.hpp>
+
+namespace mlpack {
+namespace tree {
+
+// Each of these overloads is kept as a separate function to keep the overhead
+// from the two std::vectors out, if possible.
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
+ const size_t leafSize) :
+ left(NULL),
+ right(NULL),
+ parent(NULL),
+ begin(0), /* This root node starts at index 0, */
+ count(data.n_cols), /* and spans all of the dataset. */
+ leafSize(leafSize),
+ bound(data.n_rows),
+ dataset(data)
+{
+ // Do the actual splitting of this node.
+ SplitNode(data);
+
+ // Create the statistic depending on if we are a leaf or not.
+ stat = StatisticType(*this);
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
+ std::vector<size_t>& oldFromNew,
+ const size_t leafSize) :
+ left(NULL),
+ right(NULL),
+ parent(NULL),
+ begin(0),
+ count(data.n_cols),
+ leafSize(leafSize),
+ bound(data.n_rows),
+ dataset(data)
+{
+ // Initialize oldFromNew correctly.
+ oldFromNew.resize(data.n_cols);
+ for (size_t i = 0; i < data.n_cols; i++)
+ oldFromNew[i] = i; // Fill with unharmed indices.
+
+ // Now do the actual splitting.
+ SplitNode(data, oldFromNew);
+
+ // Create the statistic depending on if we are a leaf or not.
+ stat = StatisticType(*this);
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
+ std::vector<size_t>& oldFromNew,
+ std::vector<size_t>& newFromOld,
+ const size_t leafSize) :
+ left(NULL),
+ right(NULL),
+ parent(NULL),
+ begin(0),
+ count(data.n_cols),
+ leafSize(leafSize),
+ bound(data.n_rows),
+ dataset(data)
+{
+ // Initialize the oldFromNew vector correctly.
+ oldFromNew.resize(data.n_cols);
+ for (size_t i = 0; i < data.n_cols; i++)
+ oldFromNew[i] = i; // Fill with unharmed indices.
+
+ // Now do the actual splitting.
+ SplitNode(data, oldFromNew);
+
+ // Create the statistic depending on if we are a leaf or not.
+ stat = StatisticType(*this);
+
+ // Map the newFromOld indices correctly.
+ newFromOld.resize(data.n_cols);
+ for (size_t i = 0; i < data.n_cols; i++)
+ newFromOld[oldFromNew[i]] = i;
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ BinarySpaceTree* parent,
+ const size_t leafSize) :
+ left(NULL),
+ right(NULL),
+ parent(parent),
+ begin(begin),
+ count(count),
+ leafSize(leafSize),
+ bound(data.n_rows),
+ dataset(data)
+{
+ // Perform the actual splitting.
+ SplitNode(data);
+
+ // Create the statistic depending on if we are a leaf or not.
+ stat = StatisticType(*this);
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ std::vector<size_t>& oldFromNew,
+ BinarySpaceTree* parent,
+ const size_t leafSize) :
+ left(NULL),
+ right(NULL),
+ parent(parent),
+ begin(begin),
+ count(count),
+ leafSize(leafSize),
+ bound(data.n_rows),
+ dataset(data)
+{
+ // Hopefully the vector is initialized correctly! We can't check that
+ // entirely but we can do a minor sanity check.
+ assert(oldFromNew.size() == data.n_cols);
+
+ // Perform the actual splitting.
+ SplitNode(data, oldFromNew);
+
+ // Create the statistic depending on if we are a leaf or not.
+ stat = StatisticType(*this);
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ std::vector<size_t>& oldFromNew,
+ std::vector<size_t>& newFromOld,
+ BinarySpaceTree* parent,
+ const size_t leafSize) :
+ left(NULL),
+ right(NULL),
+ parent(parent),
+ begin(begin),
+ count(count),
+ leafSize(leafSize),
+ bound(data.n_rows),
+ dataset(data)
+{
+ // Hopefully the vector is initialized correctly! We can't check that
+ // entirely but we can do a minor sanity check.
+ Log::Assert(oldFromNew.size() == data.n_cols);
+
+ // Perform the actual splitting.
+ SplitNode(data, oldFromNew);
+
+ // Create the statistic depending on if we are a leaf or not.
+ stat = StatisticType(*this);
+
+ // Map the newFromOld indices correctly.
+ newFromOld.resize(data.n_cols);
+ for (size_t i = 0; i < data.n_cols; i++)
+ newFromOld[oldFromNew[i]] = i;
+}
+
+/*
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree() :
+ left(NULL),
+ right(NULL),
+ parent(NULL),
+ begin(0),
+ count(0),
+ bound(),
+ stat(),
+ leafSize(20) // Default leaf size is 20.
+{
+ // Nothing to do.
+}*/
+
+/**
+ * Create a binary space tree by copying the other tree. Be careful! This can
+ * take a long time and use a lot of memory.
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ const BinarySpaceTree& other) :
+ left(NULL),
+ right(NULL),
+ parent(other.parent),
+ begin(other.begin),
+ count(other.count),
+ leafSize(other.leafSize),
+ bound(other.bound),
+ stat(other.stat),
+ splitDimension(other.splitDimension),
+ furthestDescendantDistance(other.furthestDescendantDistance),
+ dataset(other.dataset)
+{
+ // Create left and right children (if any).
+ if (other.Left())
+ {
+ left = new BinarySpaceTree(*other.Left());
+ left->Parent() = this; // Set parent to this, not other tree.
+ }
+
+ if (other.Right())
+ {
+ right = new BinarySpaceTree(*other.Right());
+ right->Parent() = this; // Set parent to this, not other tree.
+ }
+}
+
+/**
+ * Deletes this node, deallocating the memory for the children and calling their
+ * destructors in turn. This will invalidate any pointers or references to any
+ * nodes which are children of this one.
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::~BinarySpaceTree()
+{
+ if (left)
+ delete left;
+ if (right)
+ delete right;
+}
+
+/**
+ * Find a node in this tree by its begin and count.
+ *
+ * Every node is uniquely identified by these two numbers.
+ * This is useful for communicating position over the network,
+ * when pointers would be invalid.
+ *
+ * @param queryBegin The Begin() of the node to find.
+ * @param queryCount The Count() of the node to find.
+ * @return The found node, or NULL if nothing is found.
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+const BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+ size_t queryBegin,
+ size_t queryCount) const
+{
+ Log::Assert(queryBegin >= begin);
+ Log::Assert(queryCount <= count);
+
+ if (begin == queryBegin && count == queryCount)
+ return this;
+ else if (IsLeaf())
+ return NULL;
+ else if (queryBegin < right->Begin())
+ return left->FindByBeginCount(queryBegin, queryCount);
+ else
+ return right->FindByBeginCount(queryBegin, queryCount);
+}
+
+/**
+ * Find a node in this tree by its begin and count (const).
+ *
+ * Every node is uniquely identified by these two numbers.
+ * This is useful for communicating position over the network,
+ * when pointers would be invalid.
+ *
+ * @param queryBegin the Begin() of the node to find
+ * @param queryCount the Count() of the node to find
+ * @return the found node, or NULL
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+ const size_t queryBegin,
+ const size_t queryCount)
+{
+ mlpack::Log::Assert(begin >= queryBegin);
+ mlpack::Log::Assert(count <= queryCount);
+
+ if (begin == queryBegin && count == queryCount)
+ return this;
+ else if (IsLeaf())
+ return NULL;
+ else if (queryBegin < left->End())
+ return left->FindByBeginCount(queryBegin, queryCount);
+ else if (right)
+ return right->FindByBeginCount(queryBegin, queryCount);
+ else
+ return NULL;
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::ExtendTree(
+ size_t level)
+{
+ --level;
+ // Return the number of nodes duplicated.
+ size_t nodesDuplicated = 0;
+ if (level > 0)
+ {
+ if (!left)
+ {
+ left = CopyMe();
+ ++nodesDuplicated;
+ }
+ nodesDuplicated += left->ExtendTree(level);
+ if (right)
+ {
+ nodesDuplicated += right->ExtendTree(level);
+ }
+ }
+ return nodesDuplicated;
+}
+
+/* TODO: we can likely calculate this earlier, then store the
+ * result in a private member variable; for now, we can
+ * just calculate as needed...
+ *
+ * Also, perhaps we should rewrite these recursive functions
+ * to avoid exceeding the stack limit
+ */
+
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeSize() const
+{
+ // Recursively count the nodes on each side of the tree. The plus one is
+ // because we have to count this node, too.
+ return 1 + (left ? left->TreeSize() : 0) + (right ? right->TreeSize() : 0);
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeDepth() const
+{
+ // Recursively count the depth on each side of the tree. The plus one is
+ // because we have to count this node, too.
+ return 1 + std::max((left ? left->TreeDepth() : 0),
+ (right ? right->TreeDepth() : 0));
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+inline bool BinarySpaceTree<BoundType, StatisticType, MatType>::IsLeaf() const
+{
+ return !left;
+}
+
+/**
+ * Returns the number of children in this node.
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t
+ BinarySpaceTree<BoundType, StatisticType, MatType>::NumChildren() const
+{
+ if (left && right)
+ return 2;
+ if (left)
+ return 1;
+
+ return 0;
+}
+
+/**
+ * Return the furthest possible descendant distance. This returns the maximum
+ * distance from the centroid to the edge of the bound and not the empirical
+ * quantity which is the actual furthest descendant distance. So the actual
+ * furthest descendant distance may be less than what this method returns (but
+ * it will never be greater than this).
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
+ FurthestDescendantDistance() const
+{
+ return furthestDescendantDistance;
+}
+
+/**
+ * Return the specified child.
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+inline BinarySpaceTree<BoundType, StatisticType, MatType>&
+ BinarySpaceTree<BoundType, StatisticType, MatType>::Child(
+ const size_t child) const
+{
+ if (child == 0)
+ return *left;
+ else
+ return *right;
+}
+
+/**
+ * Return the number of points contained in this node.
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t
+BinarySpaceTree<BoundType, StatisticType, MatType>::NumPoints() const
+{
+ if (left)
+ return 0;
+
+ return count;
+}
+
+/**
+ * Return the index of a particular point contained in this node.
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t
+BinarySpaceTree<BoundType, StatisticType, MatType>::Point(const size_t index)
+ const
+{
+ return (begin + index);
+}
+
+/**
+ * Gets the index one beyond the last index in the series.
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::End() const
+{
+ return begin + count;
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+void
+ BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(MatType& data)
+{
+ // We need to expand the bounds of this node properly.
+ bound |= data.cols(begin, begin + count - 1);
+
+ // Calculate the furthest descendant distance.
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+
+ // Now, check if we need to split at all.
+ if (count <= leafSize)
+ return; // We can't split this.
+
+ // Figure out which dimension to split on.
+ size_t splitDim = data.n_rows; // Indicate invalid by maxDim + 1.
+ double maxWidth = -1;
+
+ // Find the split dimension.
+ for (size_t d = 0; d < data.n_rows; d++)
+ {
+ double width = bound[d].Width();
+
+ if (width > maxWidth)
+ {
+ maxWidth = width;
+ splitDim = d;
+ }
+ }
+ splitDimension = splitDim;
+
+ // Split in the middle of that dimension.
+ double splitVal = bound[splitDim].Mid();
+
+ if (maxWidth == 0) // All these points are the same. We can't split.
+ return;
+
+ // Perform the actual splitting. This will order the dataset such that points
+ // with value in dimension split_dim less than or equal to splitVal are on
+ // the left of splitCol, and points with value in dimension splitDim greater
+ // than splitVal are on the right side of splitCol.
+ size_t splitCol = GetSplitIndex(data, splitDim, splitVal);
+
+ // Now that we know the split column, we will recursively split the children
+ // by calling their constructors (which perform this splitting process).
+ left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
+ splitCol - begin, this, leafSize);
+ right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
+ begin + count - splitCol, this, leafSize);
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
+ MatType& data,
+ std::vector<size_t>& oldFromNew)
+{
+ // This should be a single function for Bound.
+ // We need to expand the bounds of this node properly.
+ bound |= data.cols(begin, begin + count - 1);
+
+ // Calculate the furthest descendant distance.
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+
+ // First, check if we need to split at all.
+ if (count <= leafSize)
+ return; // We can't split this.
+
+ // Figure out which dimension to split on.
+ size_t splitDim = data.n_rows; // Indicate invalid by max_dim + 1.
+ double maxWidth = -1;
+
+ // Find the split dimension.
+ for (size_t d = 0; d < data.n_rows; d++)
+ {
+ double width = bound[d].Width();
+
+ if (width > maxWidth)
+ {
+ maxWidth = width;
+ splitDim = d;
+ }
+ }
+ splitDimension = splitDim;
+
+ // Split in the middle of that dimension.
+ double splitVal = bound[splitDim].Mid();
+
+ if (maxWidth == 0) // All these points are the same. We can't split.
+ return;
+
+ // Perform the actual splitting. This will order the dataset such that points
+ // with value in dimension split_dim less than or equal to splitVal are on
+ // the left of splitCol, and points with value in dimension splitDim greater
+ // than splitVal are on the right side of splitCol.
+ size_t splitCol = GetSplitIndex(data, splitDim, splitVal, oldFromNew);
+
+ // Now that we know the split column, we will recursively split the children
+ // by calling their constructors (which perform this splitting process).
+ left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
+ splitCol - begin, oldFromNew, this, leafSize);
+ right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
+ begin + count - splitCol, oldFromNew, this, leafSize);
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
+ MatType& data,
+ int splitDim,
+ double splitVal)
+{
+ // This method modifies the input dataset. We loop both from the left and
+ // right sides of the points contained in this node. The points less than
+ // split_val should be on the left side of the matrix, and the points greater
+ // than split_val should be on the right side of the matrix.
+ size_t left = begin;
+ size_t right = begin + count - 1;
+
+ // First half-iteration of the loop is out here because the termination
+ // condition is in the middle.
+ while ((data(splitDim, left) < splitVal) && (left <= right))
+ left++;
+ while ((data(splitDim, right) >= splitVal) && (left <= right))
+ right--;
+
+ while (left <= right)
+ {
+ // Swap columns.
+ data.swap_cols(left, right);
+
+ // See how many points on the left are correct. When they are correct,
+ // increase the left counter accordingly. When we encounter one that isn't
+ // correct, stop. We will switch it later.
+ while ((data(splitDim, left) < splitVal) && (left <= right))
+ left++;
+
+ // Now see how many points on the right are correct. When they are correct,
+ // decrease the right counter accordingly. When we encounter one that isn't
+ // correct, stop. We will switch it with the wrong point we found in the
+ // previous loop.
+ while ((data(splitDim, right) >= splitVal) && (left <= right))
+ right--;
+ }
+
+ Log::Assert(left == right + 1);
+
+ return left;
+}
+
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
+ MatType& data,
+ int splitDim,
+ double splitVal,
+ std::vector<size_t>& oldFromNew)
+{
+ // This method modifies the input dataset. We loop both from the left and
+ // right sides of the points contained in this node. The points less than
+ // split_val should be on the left side of the matrix, and the points greater
+ // than split_val should be on the right side of the matrix.
+ size_t left = begin;
+ size_t right = begin + count - 1;
+
+ // First half-iteration of the loop is out here because the termination
+ // condition is in the middle.
+ while ((data(splitDim, left) < splitVal) && (left <= right))
+ left++;
+ while ((data(splitDim, right) >= splitVal) && (left <= right))
+ right--;
+
+ while (left <= right)
+ {
+ // Swap columns.
+ data.swap_cols(left, right);
+
+ // Update the indices for what we changed.
+ size_t t = oldFromNew[left];
+ oldFromNew[left] = oldFromNew[right];
+ oldFromNew[right] = t;
+
+ // See how many points on the left are correct. When they are correct,
+ // increase the left counter accordingly. When we encounter one that isn't
+ // correct, stop. We will switch it later.
+ while ((data(splitDim, left) < splitVal) && (left <= right))
+ left++;
+
+ // Now see how many points on the right are correct. When they are correct,
+ // decrease the right counter accordingly. When we encounter one that isn't
+ // correct, stop. We will switch it with the wrong point we found in the
+ // previous loop.
+ while ((data(splitDim, right) >= splitVal) && (left <= right))
+ right--;
+ }
+
+ Log::Assert(left == right + 1);
+
+ return left;
+}
+
+/**
+ * Returns a string representation of this object.
+ */
+template<typename BoundType, typename StatisticType, typename MatType>
+std::string BinarySpaceTree<BoundType, StatisticType, MatType>::ToString() const
+{
+ std::ostringstream convert;
+ convert << "BinarySpaceTree [" << this << "]" << std::endl;
+ convert << "begin: " << begin << std::endl;
+ convert << "count: " << count << std::endl;
+ convert << "bound: " << mlpack::util::Indent(bound.ToString());
+ convert << "statistic: " << stat.ToString();
+ convert << "leaf size: " << leafSize << std::endl;
+ convert << "splitDimension: " << splitDimension << std::endl;
+ if (left != NULL)
+ {
+ convert << "left:" << std::endl;
+ convert << mlpack::util::Indent(left->ToString());
+ }
+ if (right != NULL)
+ {
+ convert << "right:" << std::endl;
+ convert << mlpack::util::Indent(right->ToString());
+ }
+ return convert.str();
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,73 +0,0 @@
-/**
- * @file dual_tree_traverser.hpp
- * @author Ryan Curtin
- *
- * Defines the DualTreeTraverser for the BinarySpaceTree tree type. This is a
- * nested class of BinarySpaceTree which traverses two trees in a depth-first
- * manner with a given set of rules which indicate the branches which can be
- * pruned and the order in which to recurse.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
-
-#include <mlpack/core.hpp>
-
-#include "binary_space_tree.hpp"
-
-namespace mlpack {
-namespace tree {
-
-template<typename BoundType, typename StatisticType, typename MatType>
-template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType>::DualTreeTraverser
-{
- public:
- /**
- * Instantiate the dual-tree traverser with the given rule set.
- */
- DualTreeTraverser(RuleType& rule);
-
- /**
- * Traverse the two trees. This does not reset the number of prunes.
- *
- * @param queryNode The query node to be traversed.
- * @param referenceNode The reference node to be traversed.
- */
- void Traverse(BinarySpaceTree& queryNode, BinarySpaceTree& referenceNode);
-
- //! Get the number of prunes.
- size_t NumPrunes() const { return numPrunes; }
- //! Modify the number of prunes.
- size_t& NumPrunes() { return numPrunes; }
-
- private:
- //! Reference to the rules with which the trees will be traversed.
- RuleType& rule;
-
- //! The number of nodes which have been pruned during traversal.
- size_t numPrunes;
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-// Include implementation.
-#include "dual_tree_traverser_impl.hpp"
-
-#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
-
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,73 @@
+/**
+ * @file dual_tree_traverser.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the DualTreeTraverser for the BinarySpaceTree tree type. This is a
+ * nested class of BinarySpaceTree which traverses two trees in a depth-first
+ * manner with a given set of rules which indicate the branches which can be
+ * pruned and the order in which to recurse.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+
+#include "binary_space_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename BoundType, typename StatisticType, typename MatType>
+template<typename RuleType>
+class BinarySpaceTree<BoundType, StatisticType, MatType>::DualTreeTraverser
+{
+ public:
+ /**
+ * Instantiate the dual-tree traverser with the given rule set.
+ */
+ DualTreeTraverser(RuleType& rule);
+
+ /**
+ * Traverse the two trees. This does not reset the number of prunes.
+ *
+ * @param queryNode The query node to be traversed.
+ * @param referenceNode The reference node to be traversed.
+ */
+ void Traverse(BinarySpaceTree& queryNode, BinarySpaceTree& referenceNode);
+
+ //! Get the number of prunes.
+ size_t NumPrunes() const { return numPrunes; }
+ //! Modify the number of prunes.
+ size_t& NumPrunes() { return numPrunes; }
+
+ private:
+ //! Reference to the rules with which the trees will be traversed.
+ RuleType& rule;
+
+ //! The number of nodes which have been pruned during traversal.
+ size_t numPrunes;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "dual_tree_traverser_impl.hpp"
+
+#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
+
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,255 +0,0 @@
-/**
- * @file dual_tree_traverser_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of the DualTreeTraverser for BinarySpaceTree. This is a way
- * to perform a dual-tree traversal of two trees. The trees must be the same
- * type.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "dual_tree_traverser.hpp"
-
-namespace mlpack {
-namespace tree {
-
-template<typename BoundType, typename StatisticType, typename MatType>
-template<typename RuleType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::
-DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
- rule(rule),
- numPrunes(0)
-{ /* Nothing to do. */ }
-
-template<typename BoundType, typename StatisticType, typename MatType>
-template<typename RuleType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::
-DualTreeTraverser<RuleType>::Traverse(
- BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
- BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
-{
- // If both are leaves, we must evaluate the base case.
- if (queryNode.IsLeaf() && referenceNode.IsLeaf())
- {
- // Loop through each of the points in each node.
- for (size_t query = queryNode.Begin(); query < queryNode.End(); ++query)
- {
- // See if we need to investigate this point (this function should be
- // implemented for the single-tree recursion too).
- const double score = rule.Score(query, referenceNode);
-
- if (score == DBL_MAX)
- continue; // We can't improve this particular point.
-
- for (size_t ref = referenceNode.Begin(); ref < referenceNode.End(); ++ref)
- rule.BaseCase(query, ref);
- }
- }
- else if ((!queryNode.IsLeaf()) && referenceNode.IsLeaf())
- {
- // We have to recurse down the query node. In this case the recursion order
- // does not matter.
- double leftScore = rule.Score(*queryNode.Left(), referenceNode);
-
- if (leftScore != DBL_MAX)
- Traverse(*queryNode.Left(), referenceNode);
- else
- ++numPrunes;
-
- double rightScore = rule.Score(*queryNode.Right(), referenceNode);
-
- if (rightScore != DBL_MAX)
- Traverse(*queryNode.Right(), referenceNode);
- else
- ++numPrunes;
- }
- else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
- {
- // We have to recurse down the reference node. In this case the recursion
- // order does matter.
- double leftScore = rule.Score(queryNode, *referenceNode.Left());
- double rightScore = rule.Score(queryNode, *referenceNode.Right());
-
- if (leftScore < rightScore)
- {
- // Recurse to the left.
- Traverse(queryNode, *referenceNode.Left());
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
-
- if (rightScore != DBL_MAX)
- Traverse(queryNode, *referenceNode.Right());
- else
- ++numPrunes;
- }
- else if (rightScore < leftScore)
- {
- // Recurse to the right.
- Traverse(queryNode, *referenceNode.Right());
-
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
-
- if (leftScore != DBL_MAX)
- Traverse(queryNode, *referenceNode.Left());
- else
- ++numPrunes;
- }
- else // leftScore is equal to rightScore.
- {
- if (leftScore == DBL_MAX)
- {
- numPrunes += 2;
- }
- else
- {
- // Choose the left first.
- Traverse(queryNode, *referenceNode.Left());
-
- rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- Traverse(queryNode, *referenceNode.Right());
- else
- ++numPrunes;
- }
- }
- }
- else
- {
- // We have to recurse down both query and reference nodes. Because the
- // query descent order does not matter, we will go to the left query child
- // first.
- double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
- double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
-
- if (leftScore < rightScore)
- {
- // Recurse to the left.
- Traverse(*queryNode.Left(), *referenceNode.Left());
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- Traverse(*queryNode.Left(), *referenceNode.Right());
- else
- ++numPrunes;
- }
- else if (rightScore < leftScore)
- {
- // Recurse to the right.
- Traverse(*queryNode.Left(), *referenceNode.Right());
-
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
- leftScore);
-
- if (leftScore != DBL_MAX)
- Traverse(*queryNode.Left(), *referenceNode.Left());
- else
- ++numPrunes;
- }
- else
- {
- if (leftScore == DBL_MAX)
- {
- numPrunes += 2;
- }
- else
- {
- // Choose the left first.
- Traverse(*queryNode.Left(), *referenceNode.Left());
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- Traverse(*queryNode.Left(), *referenceNode.Right());
- else
- ++numPrunes;
- }
- }
-
- // Now recurse down the right query node.
- leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
- rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
-
- if (leftScore < rightScore)
- {
- // Recurse to the left.
- Traverse(*queryNode.Right(), *referenceNode.Left());
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- Traverse(*queryNode.Right(), *referenceNode.Right());
- else
- ++numPrunes;
- }
- else if (rightScore < leftScore)
- {
- // Recurse to the right.
- Traverse(*queryNode.Right(), *referenceNode.Right());
-
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
- leftScore);
-
- if (leftScore != DBL_MAX)
- Traverse(*queryNode.Right(), *referenceNode.Left());
- else
- ++numPrunes;
- }
- else
- {
- if (leftScore == DBL_MAX)
- {
- numPrunes += 2;
- }
- else
- {
- // Choose the left first.
- Traverse(*queryNode.Right(), *referenceNode.Left());
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- Traverse(*queryNode.Right(), *referenceNode.Right());
- else
- ++numPrunes;
- }
- }
- }
-}
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
-
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,255 @@
+/**
+ * @file dual_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the DualTreeTraverser for BinarySpaceTree. This is a way
+ * to perform a dual-tree traversal of two trees. The trees must be the same
+ * type.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "dual_tree_traverser.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename BoundType, typename StatisticType, typename MatType>
+template<typename RuleType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
+ rule(rule),
+ numPrunes(0)
+{ /* Nothing to do. */ }
+
+template<typename BoundType, typename StatisticType, typename MatType>
+template<typename RuleType>
+void BinarySpaceTree<BoundType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::Traverse(
+ BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
+ BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+{
+ // If both are leaves, we must evaluate the base case.
+ if (queryNode.IsLeaf() && referenceNode.IsLeaf())
+ {
+ // Loop through each of the points in each node.
+ for (size_t query = queryNode.Begin(); query < queryNode.End(); ++query)
+ {
+ // See if we need to investigate this point (this function should be
+ // implemented for the single-tree recursion too).
+ const double score = rule.Score(query, referenceNode);
+
+ if (score == DBL_MAX)
+ continue; // We can't improve this particular point.
+
+ for (size_t ref = referenceNode.Begin(); ref < referenceNode.End(); ++ref)
+ rule.BaseCase(query, ref);
+ }
+ }
+ else if ((!queryNode.IsLeaf()) && referenceNode.IsLeaf())
+ {
+ // We have to recurse down the query node. In this case the recursion order
+ // does not matter.
+ double leftScore = rule.Score(*queryNode.Left(), referenceNode);
+
+ if (leftScore != DBL_MAX)
+ Traverse(*queryNode.Left(), referenceNode);
+ else
+ ++numPrunes;
+
+ double rightScore = rule.Score(*queryNode.Right(), referenceNode);
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Right(), referenceNode);
+ else
+ ++numPrunes;
+ }
+ else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
+ {
+ // We have to recurse down the reference node. In this case the recursion
+ // order does matter.
+ double leftScore = rule.Score(queryNode, *referenceNode.Left());
+ double rightScore = rule.Score(queryNode, *referenceNode.Right());
+
+ if (leftScore < rightScore)
+ {
+ // Recurse to the left.
+ Traverse(queryNode, *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(queryNode, *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
+ {
+ // Recurse to the right.
+ Traverse(queryNode, *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
+
+ if (leftScore != DBL_MAX)
+ Traverse(queryNode, *referenceNode.Left());
+ else
+ ++numPrunes;
+ }
+ else // leftScore is equal to rightScore.
+ {
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2;
+ }
+ else
+ {
+ // Choose the left first.
+ Traverse(queryNode, *referenceNode.Left());
+
+ rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(queryNode, *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ }
+ }
+ else
+ {
+ // We have to recurse down both query and reference nodes. Because the
+ // query descent order does not matter, we will go to the left query child
+ // first.
+ double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
+ double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
+
+ if (leftScore < rightScore)
+ {
+ // Recurse to the left.
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
+ {
+ // Recurse to the right.
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
+ leftScore);
+
+ if (leftScore != DBL_MAX)
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+ else
+ ++numPrunes;
+ }
+ else
+ {
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2;
+ }
+ else
+ {
+ // Choose the left first.
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ }
+
+ // Now recurse down the right query node.
+ leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
+ rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
+
+ if (leftScore < rightScore)
+ {
+ // Recurse to the left.
+ Traverse(*queryNode.Right(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
+ {
+ // Recurse to the right.
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
+ leftScore);
+
+ if (leftScore != DBL_MAX)
+ Traverse(*queryNode.Right(), *referenceNode.Left());
+ else
+ ++numPrunes;
+ }
+ else
+ {
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2;
+ }
+ else
+ {
+ // Choose the left first.
+ Traverse(*queryNode.Right(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ }
+ }
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,72 +0,0 @@
-/**
- * @file single_tree_traverser.hpp
- * @author Ryan Curtin
- *
- * A nested class of BinarySpaceTree which traverses the entire tree with a
- * given set of rules which indicate the branches which can be pruned and the
- * order in which to recurse. This traverser is a depth-first traverser.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_SINGLE_TREE_TRAVERSER_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_SINGLE_TREE_TRAVERSER_HPP
-
-#include <mlpack/core.hpp>
-
-#include "binary_space_tree.hpp"
-
-namespace mlpack {
-namespace tree {
-
-template<typename BoundType, typename StatisticType, typename MatType>
-template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType>::SingleTreeTraverser
-{
- public:
- /**
- * Instantiate the single tree traverser with the given rule set.
- */
- SingleTreeTraverser(RuleType& rule);
-
- /**
- * Traverse the tree with the given point.
- *
- * @param queryIndex The index of the point in the query set which is being
- * used as the query point.
- * @param referenceNode The tree node to be traversed.
- */
- void Traverse(const size_t queryIndex, BinarySpaceTree& referenceNode);
-
- //! Get the number of prunes.
- size_t NumPrunes() const { return numPrunes; }
- //! Modify the number of prunes.
- size_t& NumPrunes() { return numPrunes; }
-
- private:
- //! Reference to the rules with which the tree will be traversed.
- RuleType& rule;
-
- //! The number of nodes which have been pruned during traversal.
- size_t numPrunes;
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-// Include implementation.
-#include "single_tree_traverser_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,72 @@
+/**
+ * @file single_tree_traverser.hpp
+ * @author Ryan Curtin
+ *
+ * A nested class of BinarySpaceTree which traverses the entire tree with a
+ * given set of rules which indicate the branches which can be pruned and the
+ * order in which to recurse. This traverser is a depth-first traverser.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_SINGLE_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_SINGLE_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+
+#include "binary_space_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename BoundType, typename StatisticType, typename MatType>
+template<typename RuleType>
+class BinarySpaceTree<BoundType, StatisticType, MatType>::SingleTreeTraverser
+{
+ public:
+ /**
+ * Instantiate the single tree traverser with the given rule set.
+ */
+ SingleTreeTraverser(RuleType& rule);
+
+ /**
+ * Traverse the tree with the given point.
+ *
+ * @param queryIndex The index of the point in the query set which is being
+ * used as the query point.
+ * @param referenceNode The tree node to be traversed.
+ */
+ void Traverse(const size_t queryIndex, BinarySpaceTree& referenceNode);
+
+ //! Get the number of prunes.
+ size_t NumPrunes() const { return numPrunes; }
+ //! Modify the number of prunes.
+ size_t& NumPrunes() { return numPrunes; }
+
+ private:
+ //! Reference to the rules with which the tree will be traversed.
+ RuleType& rule;
+
+ //! The number of nodes which have been pruned during traversal.
+ size_t numPrunes;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "single_tree_traverser_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,115 +0,0 @@
-/**
- * @file single_tree_traverser_impl.hpp
- * @author Ryan Curtin
- *
- * A nested class of BinarySpaceTree which traverses the entire tree with a
- * given set of rules which indicate the branches which can be pruned and the
- * order in which to recurse. This traverser is a depth-first traverser.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "single_tree_traverser.hpp"
-
-#include <stack>
-
-namespace mlpack {
-namespace tree {
-
-template<typename BoundType, typename StatisticType, typename MatType>
-template<typename RuleType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::
-SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
- rule(rule),
- numPrunes(0)
-{ /* Nothing to do. */ }
-
-template<typename BoundType, typename StatisticType, typename MatType>
-template<typename RuleType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::
-SingleTreeTraverser<RuleType>::Traverse(
- const size_t queryIndex,
- BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
-{
- // If we are a leaf, run the base case as necessary.
- if (referenceNode.IsLeaf())
- {
- for (size_t i = referenceNode.Begin(); i < referenceNode.End(); ++i)
- rule.BaseCase(queryIndex, i);
- }
- else
- {
- // If either score is DBL_MAX, we do not recurse into that node.
- double leftScore = rule.Score(queryIndex, *referenceNode.Left());
- double rightScore = rule.Score(queryIndex, *referenceNode.Right());
-
- if (leftScore < rightScore)
- {
- // Recurse to the left.
- Traverse(queryIndex, *referenceNode.Left());
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), rightScore);
-
- if (rightScore != DBL_MAX)
- Traverse(queryIndex, *referenceNode.Right()); // Recurse to the right.
- else
- ++numPrunes;
- }
- else if (rightScore < leftScore)
- {
- // Recurse to the right.
- Traverse(queryIndex, *referenceNode.Right());
-
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(queryIndex, *referenceNode.Left(), leftScore);
-
- if (leftScore != DBL_MAX)
- Traverse(queryIndex, *referenceNode.Left()); // Recurse to the left.
- else
- ++numPrunes;
- }
- else // leftScore is equal to rightScore.
- {
- if (leftScore == DBL_MAX)
- {
- numPrunes += 2; // Pruned both left and right.
- }
- else
- {
- // Choose the left first.
- Traverse(queryIndex, *referenceNode.Left());
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(queryIndex, *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- Traverse(queryIndex, *referenceNode.Right());
- else
- ++numPrunes;
- }
- }
- }
-}
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,115 @@
+/**
+ * @file single_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * A nested class of BinarySpaceTree which traverses the entire tree with a
+ * given set of rules which indicate the branches which can be pruned and the
+ * order in which to recurse. This traverser is a depth-first traverser.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "single_tree_traverser.hpp"
+
+#include <stack>
+
+namespace mlpack {
+namespace tree {
+
+template<typename BoundType, typename StatisticType, typename MatType>
+template<typename RuleType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::
+SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
+ rule(rule),
+ numPrunes(0)
+{ /* Nothing to do. */ }
+
+template<typename BoundType, typename StatisticType, typename MatType>
+template<typename RuleType>
+void BinarySpaceTree<BoundType, StatisticType, MatType>::
+SingleTreeTraverser<RuleType>::Traverse(
+ const size_t queryIndex,
+ BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+{
+ // If we are a leaf, run the base case as necessary.
+ if (referenceNode.IsLeaf())
+ {
+ for (size_t i = referenceNode.Begin(); i < referenceNode.End(); ++i)
+ rule.BaseCase(queryIndex, i);
+ }
+ else
+ {
+ // If either score is DBL_MAX, we do not recurse into that node.
+ double leftScore = rule.Score(queryIndex, *referenceNode.Left());
+ double rightScore = rule.Score(queryIndex, *referenceNode.Right());
+
+ if (leftScore < rightScore)
+ {
+ // Recurse to the left.
+ Traverse(queryIndex, *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(queryIndex, *referenceNode.Right()); // Recurse to the right.
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
+ {
+ // Recurse to the right.
+ Traverse(queryIndex, *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(queryIndex, *referenceNode.Left(), leftScore);
+
+ if (leftScore != DBL_MAX)
+ Traverse(queryIndex, *referenceNode.Left()); // Recurse to the left.
+ else
+ ++numPrunes;
+ }
+ else // leftScore is equal to rightScore.
+ {
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2; // Pruned both left and right.
+ }
+ else
+ {
+ // Choose the left first.
+ Traverse(queryIndex, *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(queryIndex, *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ Traverse(queryIndex, *referenceNode.Right());
+ else
+ ++numPrunes;
+ }
+ }
+ }
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/traits.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/traits.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/traits.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,64 +0,0 @@
-/**
- * @file traits.hpp
- * @author Ryan Curtin
- *
- * Specialization of the TreeTraits class for the BinarySpaceTree type of tree.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_TRAITS_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_TRAITS_HPP
-
-#include <mlpack/core/tree/tree_traits.hpp>
-
-namespace mlpack {
-namespace tree {
-
-/**
- * This is a specialization of the TreeType class to the BinarySpaceTree tree
- * type. It defines characteristics of the binary space tree, and is used to
- * help write tree-independent (but still optimized) tree-based algorithms. See
- * mlpack/core/tree/tree_traits.hpp for more information.
- */
-template<typename BoundType,
- typename StatisticType,
- typename MatType>
-class TreeTraits<BinarySpaceTree<BoundType, StatisticType, MatType> >
-{
- public:
- /**
- * The binary space tree cannot easily calculate the distance from a node to
- * its parent; so BinarySpaceTree<...>::ParentDistance() does not exist.
- */
- static const bool HasParentDistance = false;
-
- /**
- * Each binary space tree node has two children which represent
- * non-overlapping subsets of the space which the node represents. Therefore,
- * children are not overlapping.
- */
- static const bool HasOverlappingChildren = false;
-
- /**
- * There is no guarantee that the first point in a node is its centroid.
- */
- static const bool FirstPointIsCentroid = false;
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/traits.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree/traits.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/traits.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree/traits.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,64 @@
+/**
+ * @file traits.hpp
+ * @author Ryan Curtin
+ *
+ * Specialization of the TreeTraits class for the BinarySpaceTree type of tree.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_TRAITS_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_TRAITS_HPP
+
+#include <mlpack/core/tree/tree_traits.hpp>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * This is a specialization of the TreeType class to the BinarySpaceTree tree
+ * type. It defines characteristics of the binary space tree, and is used to
+ * help write tree-independent (but still optimized) tree-based algorithms. See
+ * mlpack/core/tree/tree_traits.hpp for more information.
+ */
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType>
+class TreeTraits<BinarySpaceTree<BoundType, StatisticType, MatType> >
+{
+ public:
+ /**
+ * The binary space tree cannot easily calculate the distance from a node to
+ * its parent; so BinarySpaceTree<...>::ParentDistance() does not exist.
+ */
+ static const bool HasParentDistance = false;
+
+ /**
+ * Each binary space tree node has two children which represent
+ * non-overlapping subsets of the space which the node represents. Therefore,
+ * children are not overlapping.
+ */
+ static const bool HasOverlappingChildren = false;
+
+ /**
+ * There is no guarantee that the first point in a node is its centroid.
+ */
+ static const bool FirstPointIsCentroid = false;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,31 +0,0 @@
-/**
- * @file binary_space_tree.hpp
- * @author Ryan Curtin
- *
- * Include all the necessary files to use the BinarySpaceTree class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_HPP
-
-#include "bounds.hpp"
-#include "binary_space_tree/binary_space_tree.hpp"
-#include "binary_space_tree/single_tree_traverser.hpp"
-#include "binary_space_tree/dual_tree_traverser.hpp"
-#include "binary_space_tree/traits.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/binary_space_tree.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/binary_space_tree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,31 @@
+/**
+ * @file binary_space_tree.hpp
+ * @author Ryan Curtin
+ *
+ * Include all the necessary files to use the BinarySpaceTree class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_HPP
+
+#include "bounds.hpp"
+#include "binary_space_tree/binary_space_tree.hpp"
+#include "binary_space_tree/single_tree_traverser.hpp"
+#include "binary_space_tree/dual_tree_traverser.hpp"
+#include "binary_space_tree/traits.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/bounds.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/bounds.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/bounds.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,32 +0,0 @@
-/**
- * @file bounds.hpp
- *
- * Bounds that are useful for binary space partitioning trees.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef __MLPACK_CORE_TREE_BOUNDS_HPP
-#define __MLPACK_CORE_TREE_BOUNDS_HPP
-
-#include <mlpack/core/math/range.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include "hrectbound.hpp"
-#include "periodichrectbound.hpp"
-#include "ballbound.hpp"
-
-#endif // __MLPACK_CORE_TREE_BOUNDS_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/bounds.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/bounds.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/bounds.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/bounds.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,32 @@
+/**
+ * @file bounds.hpp
+ *
+ * Bounds that are useful for binary space partitioning trees.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef __MLPACK_CORE_TREE_BOUNDS_HPP
+#define __MLPACK_CORE_TREE_BOUNDS_HPP
+
+#include <mlpack/core/math/range.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include "hrectbound.hpp"
+#include "periodichrectbound.hpp"
+#include "ballbound.hpp"
+
+#endif // __MLPACK_CORE_TREE_BOUNDS_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,452 +0,0 @@
-/**
- * @file cover_tree.hpp
- * @author Ryan Curtin
- *
- * Definition of CoverTree, which can be used in place of the BinarySpaceTree.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-#include "first_point_is_root.hpp"
-#include "../statistic.hpp"
-
-namespace mlpack {
-namespace tree {
-
-/**
- * A cover tree is a tree specifically designed to speed up nearest-neighbor
- * computation in high-dimensional spaces. Each non-leaf node references a
- * point and has a nonzero number of children, including a "self-child" which
- * references the same point. A leaf node represents only one point.
- *
- * The tree can be thought of as a hierarchy with the root node at the top level
- * and the leaf nodes at the bottom level. Each level in the tree has an
- * assigned 'scale' i. The tree follows these three conditions:
- *
- * - nesting: the level C_i is a subset of the level C_{i - 1}.
- * - covering: all node in level C_{i - 1} have at least one node in the
- * level C_i with distance less than or equal to b^i (exactly one of these
- * is a parent of the point in level C_{i - 1}.
- * - separation: all nodes in level C_i have distance greater than b^i to all
- * other nodes in level C_i.
- *
- * The value 'b' refers to the base, which is a parameter of the tree. These
- * three properties make the cover tree very good for fast, high-dimensional
- * nearest-neighbor search.
- *
- * The theoretical structure of the tree contains many 'implicit' nodes which
- * only have a "self-child" (a child referencing the same point, but at a lower
- * scale level). This practical implementation only constructs explicit nodes
- * -- non-leaf nodes with more than one child. A leaf node has no children, and
- * its scale level is INT_MIN.
- *
- * For more information on cover trees, see
- *
- * @code
- * @inproceedings{
- * author = {Beygelzimer, Alina and Kakade, Sham and Langford, John},
- * title = {Cover trees for nearest neighbor},
- * booktitle = {Proceedings of the 23rd International Conference on Machine
- * Learning},
- * series = {ICML '06},
- * year = {2006},
- * pages = {97--104]
- * }
- * @endcode
- *
- * For information on runtime bounds of the nearest-neighbor computation using
- * cover trees, see the following paper, presented at NIPS 2009:
- *
- * @code
- * @inproceedings{
- * author = {Ram, P., and Lee, D., and March, W.B., and Gray, A.G.},
- * title = {Linear-time Algorithms for Pairwise Statistical Problems},
- * booktitle = {Advances in Neural Information Processing Systems 22},
- * editor = {Y. Bengio and D. Schuurmans and J. Lafferty and C.K.I. Williams
- * and A. Culotta},
- * pages = {1527--1535},
- * year = {2009}
- * }
- * @endcode
- *
- * The CoverTree class offers three template parameters; a custom metric type
- * can be used with MetricType (this class defaults to the L2-squared metric).
- * The root node's point can be chosen with the RootPointPolicy; by default, the
- * FirstPointIsRoot policy is used, meaning the first point in the dataset is
- * used. The StatisticType policy allows you to define statistics which can be
- * gathered during the creation of the tree.
- *
- * @tparam MetricType Metric type to use during tree construction.
- * @tparam RootPointPolicy Determines which point to use as the root node.
- * @tparam StatisticType Statistic to be used during tree creation.
- */
-template<typename MetricType = metric::LMetric<2, true>,
- typename RootPointPolicy = FirstPointIsRoot,
- typename StatisticType = EmptyStatistic>
-class CoverTree
-{
- public:
- typedef arma::mat Mat;
-
- /**
- * Create the cover tree with the given dataset and given base.
- * The dataset will not be modified during the building procedure (unlike
- * BinarySpaceTree).
- *
- * The last argument will be removed in mlpack 1.1.0 (see #274 and #273).
- *
- * @param dataset Reference to the dataset to build a tree on.
- * @param base Base to use during tree building (default 2.0).
- */
- CoverTree(const arma::mat& dataset,
- const double base = 2.0,
- MetricType* metric = NULL);
-
- /**
- * Create the cover tree with the given dataset and the given instantiated
- * metric. Optionally, set the base. The dataset will not be modified during
- * the building procedure (unlike BinarySpaceTree).
- *
- * @param dataset Reference to the dataset to build a tree on.
- * @param metric Instantiated metric to use during tree building.
- * @param base Base to use during tree building (default 2.0).
- */
- CoverTree(const arma::mat& dataset,
- MetricType& metric,
- const double base = 2.0);
-
- /**
- * Construct a child cover tree node. This constructor is not meant to be
- * used externally, but it could be used to insert another node into a tree.
- * This procedure uses only one vector for the near set, the far set, and the
- * used set (this is to prevent unnecessary memory allocation in recursive
- * calls to this constructor). Therefore, the size of the near set, far set,
- * and used set must be passed in. The near set will be entirely used up, and
- * some of the far set may be used. The value of usedSetSize will be set to
- * the number of points used in the construction of this node, and the value
- * of farSetSize will be modified to reflect the number of points in the far
- * set _after_ the construction of this node.
- *
- * If you are calling this manually, be careful that the given scale is
- * as small as possible, or you may be creating an implicit node in your tree.
- *
- * @param dataset Reference to the dataset to build a tree on.
- * @param base Base to use during tree building.
- * @param pointIndex Index of the point this node references.
- * @param scale Scale of this level in the tree.
- * @param parent Parent of this node (NULL indicates no parent).
- * @param parentDistance Distance to the parent node.
- * @param indices Array of indices, ordered [ nearSet | farSet | usedSet ];
- * will be modified to [ farSet | usedSet ].
- * @param distances Array of distances, ordered the same way as the indices.
- * These represent the distances between the point specified by pointIndex
- * and each point in the indices array.
- * @param nearSetSize Size of the near set; if 0, this will be a leaf.
- * @param farSetSize Size of the far set; may be modified (if this node uses
- * any points in the far set).
- * @param usedSetSize The number of points used will be added to this number.
- */
- CoverTree(const arma::mat& dataset,
- const double base,
- const size_t pointIndex,
- const int scale,
- CoverTree* parent,
- const double parentDistance,
- arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize,
- MetricType& metric = NULL);
-
- /**
- * Manually construct a cover tree node; no tree assembly is done in this
- * constructor, and children must be added manually (use Children()). This
- * constructor is useful when the tree is being "imported" into the CoverTree
- * class after being created in some other manner.
- *
- * @param dataset Reference to the dataset this node is a part of.
- * @param base Base that was used for tree building.
- * @param pointIndex Index of the point in the dataset which this node refers
- * to.
- * @param scale Scale of this node's level in the tree.
- * @param parent Parent node (NULL indicates no parent).
- * @param parentDistance Distance to parent node point.
- * @param furthestDescendantDistance Distance to furthest descendant point.
- * @param metric Instantiated metric (optional).
- */
- CoverTree(const arma::mat& dataset,
- const double base,
- const size_t pointIndex,
- const int scale,
- CoverTree* parent,
- const double parentDistance,
- const double furthestDescendantDistance,
- MetricType* metric = NULL);
-
- /**
- * Create a cover tree from another tree. Be careful! This may use a lot of
- * memory and take a lot of time.
- *
- * @param other Cover tree to copy from.
- */
- CoverTree(const CoverTree& other);
-
- /**
- * Delete this cover tree node and its children.
- */
- ~CoverTree();
-
- //! A single-tree cover tree traverser; see single_tree_traverser.hpp for
- //! implementation.
- template<typename RuleType>
- class SingleTreeTraverser;
-
- //! A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
- template<typename RuleType>
- class DualTreeTraverser;
-
- //! Get a reference to the dataset.
- const arma::mat& Dataset() const { return dataset; }
-
- //! Get the index of the point which this node represents.
- size_t Point() const { return point; }
- //! For compatibility with other trees; the argument is ignored.
- size_t Point(const size_t) const { return point; }
-
- // Fake
- CoverTree* Left() const { return NULL; }
- CoverTree* Right() const { return NULL; }
- size_t Begin() const { return 0; }
- size_t Count() const { return 0; }
- size_t End() const { return 0; }
- bool IsLeaf() const { return (children.size() == 0); }
- size_t NumPoints() const { return 1; }
-
- //! Get a particular child node.
- const CoverTree& Child(const size_t index) const { return *children[index]; }
- //! Modify a particular child node.
- CoverTree& Child(const size_t index) { return *children[index]; }
-
- //! Get the number of children.
- size_t NumChildren() const { return children.size(); }
-
- //! Get the children.
- const std::vector<CoverTree*>& Children() const { return children; }
- //! Modify the children manually (maybe not a great idea).
- std::vector<CoverTree*>& Children() { return children; }
-
- //! Get the scale of this node.
- int Scale() const { return scale; }
- //! Modify the scale of this node. Be careful...
- int& Scale() { return scale; }
-
- //! Get the base.
- double Base() const { return base; }
- //! Modify the base; don't do this, you'll break everything.
- double& Base() { return base; }
-
- //! Get the statistic for this node.
- const StatisticType& Stat() const { return stat; }
- //! Modify the statistic for this node.
- StatisticType& Stat() { return stat; }
-
- //! Return the minimum distance to another node.
- double MinDistance(const CoverTree* other) const;
-
- //! Return the minimum distance to another node given that the point-to-point
- //! distance has already been calculated.
- double MinDistance(const CoverTree* other, const double distance) const;
-
- //! Return the minimum distance to another point.
- double MinDistance(const arma::vec& other) const;
-
- //! Return the minimum distance to another point given that the distance from
- //! the center to the point has already been calculated.
- double MinDistance(const arma::vec& other, const double distance) const;
-
- //! Return the maximum distance to another node.
- double MaxDistance(const CoverTree* other) const;
-
- //! Return the maximum distance to another node given that the point-to-point
- //! distance has already been calculated.
- double MaxDistance(const CoverTree* other, const double distance) const;
-
- //! Return the maximum distance to another point.
- double MaxDistance(const arma::vec& other) const;
-
- //! Return the maximum distance to another point given that the distance from
- //! the center to the point has already been calculated.
- double MaxDistance(const arma::vec& other, const double distance) const;
-
- //! Returns true: this tree does have self-children.
- static bool HasSelfChildren() { return true; }
-
- //! Get the parent node.
- CoverTree* Parent() const { return parent; }
- //! Modify the parent node.
- CoverTree*& Parent() { return parent; }
-
- //! Get the distance to the parent.
- double ParentDistance() const { return parentDistance; }
- //! Modify the distance to the parent.
- double& ParentDistance() { return parentDistance; }
-
- //! Get the distance to the furthest descendant.
- double FurthestDescendantDistance() const
- { return furthestDescendantDistance; }
- //! Modify the distance to the furthest descendant.
- double& FurthestDescendantDistance() { return furthestDescendantDistance; }
-
- //! Get the centroid of the node and store it in the given vector.
- void Centroid(arma::vec& centroid) const { centroid = dataset.col(point); }
-
- //! Get the instantiated metric.
- MetricType& Metric() const { return *metric; }
-
- private:
- //! Reference to the matrix which this tree is built on.
- const arma::mat& dataset;
-
- //! Index of the point in the matrix which this node represents.
- size_t point;
-
- //! The list of children; the first is the self-child.
- std::vector<CoverTree*> children;
-
- //! Scale level of the node.
- int scale;
-
- //! The base used to construct the tree.
- double base;
-
- //! The instantiated statistic.
- StatisticType stat;
-
- //! The parent node (NULL if this is the root of the tree).
- CoverTree* parent;
-
- //! Distance to the parent.
- double parentDistance;
-
- //! Distance to the furthest descendant.
- double furthestDescendantDistance;
-
- //! Whether or not we need to destroy the metric in the destructor.
- bool localMetric;
-
- //! The metric used for this tree.
- MetricType* metric;
-
- /**
- * Create the children for this node.
- */
- void CreateChildren(arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize);
-
- /**
- * Fill the vector of distances with the distances between the point specified
- * by pointIndex and each point in the indices array. The distances of the
- * first pointSetSize points in indices are calculated (so, this does not
- * necessarily need to use all of the points in the arrays).
- *
- * @param pointIndex Point to build the distances for.
- * @param indices List of indices to compute distances for.
- * @param distances Vector to store calculated distances in.
- * @param pointSetSize Number of points in arrays to calculate distances for.
- */
- void ComputeDistances(const size_t pointIndex,
- const arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t pointSetSize);
- /**
- * Split the given indices and distances into a near and a far set, returning
- * the number of points in the near set. The distances must already be
- * initialized. This will order the indices and distances such that the
- * points in the near set make up the first part of the array and the far set
- * makes up the rest: [ nearSet | farSet ].
- *
- * @param indices List of indices; will be reordered.
- * @param distances List of distances; will be reordered.
- * @param bound If the distance is less than or equal to this bound, the point
- * is placed into the near set.
- * @param pointSetSize Size of point set (because we may be sorting a smaller
- * list than the indices vector will hold).
- */
- size_t SplitNearFar(arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t pointSetSize);
-
- /**
- * Assuming that the list of indices and distances is sorted as
- * [ childFarSet | childUsedSet | farSet | usedSet ],
- * resort the sets so the organization is
- * [ childFarSet | farSet | childUsedSet | usedSet ].
- *
- * The size_t parameters specify the sizes of each set in the array. Only the
- * ordering of the indices and distances arrays will be modified (not their
- * actual contents).
- *
- * The size of any of the four sets can be zero and this method will handle
- * that case accordingly.
- *
- * @param indices List of indices to sort.
- * @param distances List of distances to sort.
- * @param childFarSetSize Number of points in child far set (childFarSet).
- * @param childUsedSetSize Number of points in child used set (childUsedSet).
- * @param farSetSize Number of points in far set (farSet).
- */
- size_t SortPointSet(arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t childFarSetSize,
- const size_t childUsedSetSize,
- const size_t farSetSize);
-
- void MoveToUsedSet(arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t& nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize,
- arma::Col<size_t>& childIndices,
- const size_t childFarSetSize,
- const size_t childUsedSetSize);
- size_t PruneFarSet(arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t nearSetSize,
- const size_t pointSetSize);
- public:
- /**
- * Returns a string representation of this object.
- */
- std::string ToString() const;
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-// Include implementation.
-#include "cover_tree_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/cover_tree.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,452 @@
+/**
+ * @file cover_tree.hpp
+ * @author Ryan Curtin
+ *
+ * Definition of CoverTree, which can be used in place of the BinarySpaceTree.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include "first_point_is_root.hpp"
+#include "../statistic.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * A cover tree is a tree specifically designed to speed up nearest-neighbor
+ * computation in high-dimensional spaces. Each non-leaf node references a
+ * point and has a nonzero number of children, including a "self-child" which
+ * references the same point. A leaf node represents only one point.
+ *
+ * The tree can be thought of as a hierarchy with the root node at the top level
+ * and the leaf nodes at the bottom level. Each level in the tree has an
+ * assigned 'scale' i. The tree follows these three conditions:
+ *
+ * - nesting: the level C_i is a subset of the level C_{i - 1}.
+ * - covering: all nodes in level C_{i - 1} have at least one node in the
+ * level C_i with distance less than or equal to b^i (exactly one of these
+ * is a parent of the point in level C_{i - 1}).
+ * - separation: all nodes in level C_i have distance greater than b^i to all
+ * other nodes in level C_i.
+ *
+ * The value 'b' refers to the base, which is a parameter of the tree. These
+ * three properties make the cover tree very good for fast, high-dimensional
+ * nearest-neighbor search.
+ *
+ * The theoretical structure of the tree contains many 'implicit' nodes which
+ * only have a "self-child" (a child referencing the same point, but at a lower
+ * scale level). This practical implementation only constructs explicit nodes
+ * -- non-leaf nodes with more than one child. A leaf node has no children, and
+ * its scale level is INT_MIN.
+ *
+ * For more information on cover trees, see
+ *
+ * @code
+ * @inproceedings{
+ * author = {Beygelzimer, Alina and Kakade, Sham and Langford, John},
+ * title = {Cover trees for nearest neighbor},
+ * booktitle = {Proceedings of the 23rd International Conference on Machine
+ * Learning},
+ * series = {ICML '06},
+ * year = {2006},
+ * pages = {97--104}
+ * }
+ * @endcode
+ *
+ * For information on runtime bounds of the nearest-neighbor computation using
+ * cover trees, see the following paper, presented at NIPS 2009:
+ *
+ * @code
+ * @inproceedings{
+ * author = {Ram, P., and Lee, D., and March, W.B., and Gray, A.G.},
+ * title = {Linear-time Algorithms for Pairwise Statistical Problems},
+ * booktitle = {Advances in Neural Information Processing Systems 22},
+ * editor = {Y. Bengio and D. Schuurmans and J. Lafferty and C.K.I. Williams
+ * and A. Culotta},
+ * pages = {1527--1535},
+ * year = {2009}
+ * }
+ * @endcode
+ *
+ * The CoverTree class offers three template parameters; a custom metric type
+ * can be used with MetricType (this class defaults to the L2-squared metric).
+ * The root node's point can be chosen with the RootPointPolicy; by default, the
+ * FirstPointIsRoot policy is used, meaning the first point in the dataset is
+ * used. The StatisticType policy allows you to define statistics which can be
+ * gathered during the creation of the tree.
+ *
+ * @tparam MetricType Metric type to use during tree construction.
+ * @tparam RootPointPolicy Determines which point to use as the root node.
+ * @tparam StatisticType Statistic to be used during tree creation.
+ */
+template<typename MetricType = metric::LMetric<2, true>,
+ typename RootPointPolicy = FirstPointIsRoot,
+ typename StatisticType = EmptyStatistic>
+class CoverTree
+{
+ public:
+ typedef arma::mat Mat;
+
+ /**
+ * Create the cover tree with the given dataset and given base.
+ * The dataset will not be modified during the building procedure (unlike
+ * BinarySpaceTree).
+ *
+ * The last argument will be removed in mlpack 1.1.0 (see #274 and #273).
+ *
+ * @param dataset Reference to the dataset to build a tree on.
+ * @param base Base to use during tree building (default 2.0).
+ */
+ CoverTree(const arma::mat& dataset,
+ const double base = 2.0,
+ MetricType* metric = NULL);
+
+ /**
+ * Create the cover tree with the given dataset and the given instantiated
+ * metric. Optionally, set the base. The dataset will not be modified during
+ * the building procedure (unlike BinarySpaceTree).
+ *
+ * @param dataset Reference to the dataset to build a tree on.
+ * @param metric Instantiated metric to use during tree building.
+ * @param base Base to use during tree building (default 2.0).
+ */
+ CoverTree(const arma::mat& dataset,
+ MetricType& metric,
+ const double base = 2.0);
+
+ /**
+ * Construct a child cover tree node. This constructor is not meant to be
+ * used externally, but it could be used to insert another node into a tree.
+ * This procedure uses only one vector for the near set, the far set, and the
+ * used set (this is to prevent unnecessary memory allocation in recursive
+ * calls to this constructor). Therefore, the size of the near set, far set,
+ * and used set must be passed in. The near set will be entirely used up, and
+ * some of the far set may be used. The value of usedSetSize will be set to
+ * the number of points used in the construction of this node, and the value
+ * of farSetSize will be modified to reflect the number of points in the far
+ * set _after_ the construction of this node.
+ *
+ * If you are calling this manually, be careful that the given scale is
+ * as small as possible, or you may be creating an implicit node in your tree.
+ *
+ * @param dataset Reference to the dataset to build a tree on.
+ * @param base Base to use during tree building.
+ * @param pointIndex Index of the point this node references.
+ * @param scale Scale of this level in the tree.
+ * @param parent Parent of this node (NULL indicates no parent).
+ * @param parentDistance Distance to the parent node.
+ * @param indices Array of indices, ordered [ nearSet | farSet | usedSet ];
+ * will be modified to [ farSet | usedSet ].
+ * @param distances Array of distances, ordered the same way as the indices.
+ * These represent the distances between the point specified by pointIndex
+ * and each point in the indices array.
+ * @param nearSetSize Size of the near set; if 0, this will be a leaf.
+ * @param farSetSize Size of the far set; may be modified (if this node uses
+ * any points in the far set).
+ * @param usedSetSize The number of points used will be added to this number.
+ */
+ CoverTree(const arma::mat& dataset,
+ const double base,
+ const size_t pointIndex,
+ const int scale,
+ CoverTree* parent,
+ const double parentDistance,
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ MetricType& metric = NULL);
+
+ /**
+ * Manually construct a cover tree node; no tree assembly is done in this
+ * constructor, and children must be added manually (use Children()). This
+ * constructor is useful when the tree is being "imported" into the CoverTree
+ * class after being created in some other manner.
+ *
+ * @param dataset Reference to the dataset this node is a part of.
+ * @param base Base that was used for tree building.
+ * @param pointIndex Index of the point in the dataset which this node refers
+ * to.
+ * @param scale Scale of this node's level in the tree.
+ * @param parent Parent node (NULL indicates no parent).
+ * @param parentDistance Distance to parent node point.
+ * @param furthestDescendantDistance Distance to furthest descendant point.
+ * @param metric Instantiated metric (optional).
+ */
+ CoverTree(const arma::mat& dataset,
+ const double base,
+ const size_t pointIndex,
+ const int scale,
+ CoverTree* parent,
+ const double parentDistance,
+ const double furthestDescendantDistance,
+ MetricType* metric = NULL);
+
+ /**
+ * Create a cover tree from another tree. Be careful! This may use a lot of
+ * memory and take a lot of time.
+ *
+ * @param other Cover tree to copy from.
+ */
+ CoverTree(const CoverTree& other);
+
+ /**
+ * Delete this cover tree node and its children.
+ */
+ ~CoverTree();
+
+ //! A single-tree cover tree traverser; see single_tree_traverser.hpp for
+ //! implementation.
+ template<typename RuleType>
+ class SingleTreeTraverser;
+
+ //! A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
+ template<typename RuleType>
+ class DualTreeTraverser;
+
+ //! Get a reference to the dataset.
+ const arma::mat& Dataset() const { return dataset; }
+
+ //! Get the index of the point which this node represents.
+ size_t Point() const { return point; }
+ //! For compatibility with other trees; the argument is ignored.
+ size_t Point(const size_t) const { return point; }
+
+ // Fake
+ CoverTree* Left() const { return NULL; }
+ CoverTree* Right() const { return NULL; }
+ size_t Begin() const { return 0; }
+ size_t Count() const { return 0; }
+ size_t End() const { return 0; }
+ bool IsLeaf() const { return (children.size() == 0); }
+ size_t NumPoints() const { return 1; }
+
+ //! Get a particular child node.
+ const CoverTree& Child(const size_t index) const { return *children[index]; }
+ //! Modify a particular child node.
+ CoverTree& Child(const size_t index) { return *children[index]; }
+
+ //! Get the number of children.
+ size_t NumChildren() const { return children.size(); }
+
+ //! Get the children.
+ const std::vector<CoverTree*>& Children() const { return children; }
+ //! Modify the children manually (maybe not a great idea).
+ std::vector<CoverTree*>& Children() { return children; }
+
+ //! Get the scale of this node.
+ int Scale() const { return scale; }
+ //! Modify the scale of this node. Be careful...
+ int& Scale() { return scale; }
+
+ //! Get the base.
+ double Base() const { return base; }
+ //! Modify the base; don't do this, you'll break everything.
+ double& Base() { return base; }
+
+ //! Get the statistic for this node.
+ const StatisticType& Stat() const { return stat; }
+ //! Modify the statistic for this node.
+ StatisticType& Stat() { return stat; }
+
+ //! Return the minimum distance to another node.
+ double MinDistance(const CoverTree* other) const;
+
+ //! Return the minimum distance to another node given that the point-to-point
+ //! distance has already been calculated.
+ double MinDistance(const CoverTree* other, const double distance) const;
+
+ //! Return the minimum distance to another point.
+ double MinDistance(const arma::vec& other) const;
+
+ //! Return the minimum distance to another point given that the distance from
+ //! the center to the point has already been calculated.
+ double MinDistance(const arma::vec& other, const double distance) const;
+
+ //! Return the maximum distance to another node.
+ double MaxDistance(const CoverTree* other) const;
+
+ //! Return the maximum distance to another node given that the point-to-point
+ //! distance has already been calculated.
+ double MaxDistance(const CoverTree* other, const double distance) const;
+
+ //! Return the maximum distance to another point.
+ double MaxDistance(const arma::vec& other) const;
+
+ //! Return the maximum distance to another point given that the distance from
+ //! the center to the point has already been calculated.
+ double MaxDistance(const arma::vec& other, const double distance) const;
+
+ //! Returns true: this tree does have self-children.
+ static bool HasSelfChildren() { return true; }
+
+ //! Get the parent node.
+ CoverTree* Parent() const { return parent; }
+ //! Modify the parent node.
+ CoverTree*& Parent() { return parent; }
+
+ //! Get the distance to the parent.
+ double ParentDistance() const { return parentDistance; }
+ //! Modify the distance to the parent.
+ double& ParentDistance() { return parentDistance; }
+
+ //! Get the distance to the furthest descendant.
+ double FurthestDescendantDistance() const
+ { return furthestDescendantDistance; }
+ //! Modify the distance to the furthest descendant.
+ double& FurthestDescendantDistance() { return furthestDescendantDistance; }
+
+ //! Get the centroid of the node and store it in the given vector.
+ void Centroid(arma::vec& centroid) const { centroid = dataset.col(point); }
+
+ //! Get the instantiated metric.
+ MetricType& Metric() const { return *metric; }
+
+ private:
+ //! Reference to the matrix which this tree is built on.
+ const arma::mat& dataset;
+
+ //! Index of the point in the matrix which this node represents.
+ size_t point;
+
+ //! The list of children; the first is the self-child.
+ std::vector<CoverTree*> children;
+
+ //! Scale level of the node.
+ int scale;
+
+ //! The base used to construct the tree.
+ double base;
+
+ //! The instantiated statistic.
+ StatisticType stat;
+
+ //! The parent node (NULL if this is the root of the tree).
+ CoverTree* parent;
+
+ //! Distance to the parent.
+ double parentDistance;
+
+ //! Distance to the furthest descendant.
+ double furthestDescendantDistance;
+
+ //! Whether or not we need to destroy the metric in the destructor.
+ bool localMetric;
+
+ //! The metric used for this tree.
+ MetricType* metric;
+
+ /**
+ * Create the children for this node.
+ */
+ void CreateChildren(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize);
+
+ /**
+ * Fill the vector of distances with the distances between the point specified
+ * by pointIndex and each point in the indices array. The distances of the
+ * first pointSetSize points in indices are calculated (so, this does not
+ * necessarily need to use all of the points in the arrays).
+ *
+ * @param pointIndex Point to build the distances for.
+ * @param indices List of indices to compute distances for.
+ * @param distances Vector to store calculated distances in.
+ * @param pointSetSize Number of points in arrays to calculate distances for.
+ */
+ void ComputeDistances(const size_t pointIndex,
+ const arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t pointSetSize);
+ /**
+ * Split the given indices and distances into a near and a far set, returning
+ * the number of points in the near set. The distances must already be
+ * initialized. This will order the indices and distances such that the
+ * points in the near set make up the first part of the array and the far set
+ * makes up the rest: [ nearSet | farSet ].
+ *
+ * @param indices List of indices; will be reordered.
+ * @param distances List of distances; will be reordered.
+ * @param bound If the distance is less than or equal to this bound, the point
+ * is placed into the near set.
+ * @param pointSetSize Size of point set (because we may be sorting a smaller
+ * list than the indices vector will hold).
+ */
+ size_t SplitNearFar(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t pointSetSize);
+
+ /**
+ * Assuming that the list of indices and distances is sorted as
+ * [ childFarSet | childUsedSet | farSet | usedSet ],
+ * resort the sets so the organization is
+ * [ childFarSet | farSet | childUsedSet | usedSet ].
+ *
+ * The size_t parameters specify the sizes of each set in the array. Only the
+ * ordering of the indices and distances arrays will be modified (not their
+ * actual contents).
+ *
+ * The size of any of the four sets can be zero and this method will handle
+ * that case accordingly.
+ *
+ * @param indices List of indices to sort.
+ * @param distances List of distances to sort.
+ * @param childFarSetSize Number of points in child far set (childFarSet).
+ * @param childUsedSetSize Number of points in child used set (childUsedSet).
+ * @param farSetSize Number of points in far set (farSet).
+ */
+ size_t SortPointSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t childFarSetSize,
+ const size_t childUsedSetSize,
+ const size_t farSetSize);
+
+ void MoveToUsedSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t& nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ arma::Col<size_t>& childIndices,
+ const size_t childFarSetSize,
+ const size_t childUsedSetSize);
+ size_t PruneFarSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t nearSetSize,
+ const size_t pointSetSize);
+ public:
+ /**
+ * Returns a string representation of this object.
+ */
+ std::string ToString() const;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "cover_tree_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,854 +0,0 @@
-/**
- * @file cover_tree_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of CoverTree class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "cover_tree.hpp"
-
-#include <mlpack/core/util/string_util.hpp>
-#include <string>
-
-namespace mlpack {
-namespace tree {
-
-// Create the cover tree.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
- const double base,
- MetricType* metric) :
- dataset(dataset),
- point(RootPointPolicy::ChooseRoot(dataset)),
- scale(INT_MAX),
- base(base),
- parent(NULL),
- parentDistance(0),
- furthestDescendantDistance(0),
- localMetric(metric == NULL),
- metric(metric)
-{
- // If we need to create a metric, do that. We'll just do it on the heap.
- if (localMetric)
- this->metric = new MetricType();
-
- // If there is only one point in the dataset... uh, we're done.
- if (dataset.n_cols == 1)
- return;
-
- // Kick off the building. Create the indices array and the distances array.
- arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
- dataset.n_cols - 1, dataset.n_cols - 1);
- // This is now [1 2 3 4 ... n]. We must be sure that our point does not
- // occur.
- if (point != 0)
- indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
-
- arma::vec distances(dataset.n_cols - 1);
-
- // Build the initial distances.
- ComputeDistances(point, indices, distances, dataset.n_cols - 1);
-
- // Create the children.
- size_t farSetSize = 0;
- size_t usedSetSize = 0;
- CreateChildren(indices, distances, dataset.n_cols - 1, farSetSize,
- usedSetSize);
-
- // Use the furthest descendant distance to determine the scale of the root
- // node.
- scale = (int) ceil(log(furthestDescendantDistance) / log(base));
-
- // Initialize statistic.
- stat = StatisticType(*this);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
- MetricType& metric,
- const double base) :
- dataset(dataset),
- point(RootPointPolicy::ChooseRoot(dataset)),
- scale(INT_MAX),
- base(base),
- parent(NULL),
- parentDistance(0),
- furthestDescendantDistance(0),
- localMetric(false),
- metric(&metric)
-{
- // If there is only one point in the dataset, uh, we're done.
- if (dataset.n_cols == 1)
- return;
-
- // Kick off the building. Create the indices array and the distances array.
- arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
- dataset.n_cols - 1, dataset.n_cols - 1);
- // This is now [1 2 3 4 ... n]. We must be sure that our point does not
- // occur.
- if (point != 0)
- indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
-
- arma::vec distances(dataset.n_cols - 1);
-
- // Build the initial distances.
- ComputeDistances(point, indices, distances, dataset.n_cols - 1);
-
- // Create the children.
- size_t farSetSize = 0;
- size_t usedSetSize = 0;
- CreateChildren(indices, distances, dataset.n_cols - 1, farSetSize,
- usedSetSize);
-
- // Use the furthest descendant distance to determine the scale of the root
- // node.
- scale = (int) ceil(log(furthestDescendantDistance) / log(base));
-
- // Initialize statistic.
- stat = StatisticType(*this);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
- const double base,
- const size_t pointIndex,
- const int scale,
- CoverTree* parent,
- const double parentDistance,
- arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize,
- MetricType& metric) :
- dataset(dataset),
- point(pointIndex),
- scale(scale),
- base(base),
- parent(parent),
- parentDistance(parentDistance),
- furthestDescendantDistance(0),
- localMetric(false),
- metric(&metric)
-{
- // If the size of the near set is 0, this is a leaf.
- if (nearSetSize == 0)
- {
- this->scale = INT_MIN;
- stat = StatisticType(*this);
- return;
- }
-
- // Otherwise, create the children.
- CreateChildren(indices, distances, nearSetSize, farSetSize, usedSetSize);
-
- // Initialize statistic.
- stat = StatisticType(*this);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-inline void
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CreateChildren(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize)
-{
- // Determine the next scale level. This should be the first level where there
- // are any points in the far set. So, if we know the maximum distance in the
- // distances array, this will be the largest i such that
- // maxDistance > pow(base, i)
- // and using this for the scale factor should guarantee we are not creating an
- // implicit node. If the maximum distance is 0, every point in the near set
- // will be created as a leaf, and a child to this node. We also do not need
- // to change the furthestChildDistance or furthestDescendantDistance.
- const double maxDistance = max(distances.rows(0,
- nearSetSize + farSetSize - 1));
- if (maxDistance == 0)
- {
- // Make the self child at the lowest possible level.
- // This should not modify farSetSize or usedSetSize.
- size_t tempSize = 0;
- children.push_back(new CoverTree(dataset, base, point, INT_MIN, this, 0,
- indices, distances, 0, tempSize, usedSetSize, *metric));
-
- // Every point in the near set should be a leaf.
- for (size_t i = 0; i < nearSetSize; ++i)
- {
- // farSetSize and usedSetSize will not be modified.
- children.push_back(new CoverTree(dataset, base, indices[i],
- INT_MIN, this, distances[i], indices, distances, 0, tempSize,
- usedSetSize, *metric));
- usedSetSize++;
- }
-
- // Re-sort the dataset. We have
- // [ used | far | other used ]
- // and we want
- // [ far | all used ].
- SortPointSet(indices, distances, 0, usedSetSize, farSetSize);
-
- // Initialize the statistic.
- stat = StatisticType(*this);
-
- return;
- }
-
- const int nextScale = std::min(scale,
- (int) ceil(log(maxDistance) / log(base))) - 1;
- const double bound = pow(base, nextScale);
-
- // First, make the self child. We must split the given near set into the near
- // set and far set for the self child.
- size_t childNearSetSize =
- SplitNearFar(indices, distances, bound, nearSetSize);
-
- // Build the self child (recursively).
- size_t childFarSetSize = nearSetSize - childNearSetSize;
- size_t childUsedSetSize = 0;
- children.push_back(new CoverTree(dataset, base, point, nextScale, this, 0,
- indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize,
- *metric));
-
- // The self-child can't modify the furthestChildDistance away from 0, but it
- // can modify the furthestDescendantDistance.
- furthestDescendantDistance = children[0]->FurthestDescendantDistance();
-
- // If we created an implicit node, take its self-child instead (this could
- // happen multiple times).
- while (children[children.size() - 1]->NumChildren() == 1)
- {
- CoverTree* old = children[children.size() - 1];
- children.erase(children.begin() + children.size() - 1);
-
- // Now take its child.
- children.push_back(&(old->Child(0)));
-
- // Set its parent correctly.
- old->Child(0).Parent() = this;
-
- // Remove its child (so it doesn't delete it).
- old->Children().erase(old->Children().begin() + old->Children().size() - 1);
-
- // Now delete it.
- delete old;
- }
-
- // Now the arrays, in memory, look like this:
- // [ childFar | childUsed | far | used ]
- // but we need to move the used points past our far set:
- // [ childFar | far | childUsed + used ]
- // and keeping in mind that childFar = our near set,
- // [ near | far | childUsed + used ]
- // is what we are trying to make.
- SortPointSet(indices, distances, childFarSetSize, childUsedSetSize,
- farSetSize);
-
- // Update size of near set and used set.
- nearSetSize -= childUsedSetSize;
- usedSetSize += childUsedSetSize;
-
- // Now for each point in the near set, we need to make children. To save
- // computation later, we'll create an array holding the points in the near
- // set, and then after each run we'll check which of those (if any) were used
- // and we will remove them. ...if that's faster. I think it is.
- while (nearSetSize > 0)
- {
- size_t newPointIndex = nearSetSize - 1;
-
- // Swap to front if necessary.
- if (newPointIndex != 0)
- {
- const size_t tempIndex = indices[newPointIndex];
- const double tempDist = distances[newPointIndex];
-
- indices[newPointIndex] = indices[0];
- distances[newPointIndex] = distances[0];
-
- indices[0] = tempIndex;
- distances[0] = tempDist;
- }
-
- // Will this be a new furthest child?
- if (distances[0] > furthestDescendantDistance)
- furthestDescendantDistance = distances[0];
-
- // If there's only one point left, we don't need this crap.
- if ((nearSetSize == 1) && (farSetSize == 0))
- {
- size_t childNearSetSize = 0;
- children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
- this, distances[0], indices, distances, childNearSetSize, farSetSize,
- usedSetSize, *metric));
-
- // Because the far set size is 0, we don't have to do any swapping to
- // move the point into the used set.
- ++usedSetSize;
- --nearSetSize;
-
- // And we're done.
- break;
- }
-
- // Create the near and far set indices and distance vectors. We don't fill
- // in the self-point, yet.
- arma::Col<size_t> childIndices(nearSetSize + farSetSize);
- childIndices.rows(0, (nearSetSize + farSetSize - 2)) = indices.rows(1,
- nearSetSize + farSetSize - 1);
- arma::vec childDistances(nearSetSize + farSetSize);
-
- // Build distances for the child.
- ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
- + farSetSize - 1);
-
- // Split into near and far sets for this point.
- childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
- nearSetSize + farSetSize - 1);
- childFarSetSize = PruneFarSet(childIndices, childDistances,
- base * bound, childNearSetSize,
- (nearSetSize + farSetSize - 1));
-
- // Now that we know the near and far set sizes, we can put the used point
- // (the self point) in the correct place; now, when we call
- // MoveToUsedSet(), it will move the self-point correctly. The distance
- // does not matter.
- childIndices(childNearSetSize + childFarSetSize) = indices[0];
- childDistances(childNearSetSize + childFarSetSize) = 0;
-
- // Build this child (recursively).
- childUsedSetSize = 1; // Mark self point as used.
- children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
- this, distances[0], childIndices, childDistances, childNearSetSize,
- childFarSetSize, childUsedSetSize, *metric));
-
- // If we created an implicit node, take its self-child instead (this could
- // happen multiple times).
- while (children[children.size() - 1]->NumChildren() == 1)
- {
- CoverTree* old = children[children.size() - 1];
- children.erase(children.begin() + children.size() - 1);
-
- // Now take its child.
- children.push_back(&(old->Child(0)));
-
- // Set its parent correctly.
- old->Child(0).Parent() = this;
-
- // Remove its child (so it doesn't delete it).
- old->Children().erase(old->Children().begin() + old->Children().size()
- - 1);
-
- // Now delete it.
- delete old;
- }
-
- // Now with the child created, it returns the childIndices and
- // childDistances vectors in this form:
- // [ childFar | childUsed ]
- // For each point in the childUsed set, we must move that point to the used
- // set in our own vector.
- MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
- childIndices, childFarSetSize, childUsedSetSize);
- }
-
- // Calculate furthest descendant.
- for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize +
- usedSetSize); ++i)
- if (distances[i] > furthestDescendantDistance)
- furthestDescendantDistance = distances[i];
-}
-
-// Manually create a cover tree node.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
- const double base,
- const size_t pointIndex,
- const int scale,
- CoverTree* parent,
- const double parentDistance,
- const double furthestDescendantDistance,
- MetricType* metric) :
- dataset(dataset),
- point(pointIndex),
- scale(scale),
- base(base),
- parent(parent),
- parentDistance(parentDistance),
- furthestDescendantDistance(furthestDescendantDistance),
- localMetric(metric == NULL),
- metric(metric)
-{
- // If necessary, create a local metric.
- if (localMetric)
- this->metric = new MetricType();
-
- // Initialize the statistic.
- stat = StatisticType(*this);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const CoverTree& other) :
- dataset(other.dataset),
- point(other.point),
- scale(other.scale),
- base(other.base),
- stat(other.stat),
- parent(other.parent),
- parentDistance(other.parentDistance),
- furthestDescendantDistance(other.furthestDescendantDistance),
- localMetric(false),
- metric(other.metric)
-{
- // Copy each child by hand.
- for (size_t i = 0; i < other.NumChildren(); ++i)
- {
- children.push_back(new CoverTree(other.Child(i)));
- children[i]->Parent() = this;
- }
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
-{
- // Delete each child.
- for (size_t i = 0; i < children.size(); ++i)
- delete children[i];
-
- // Delete the local metric, if necessary.
- if (localMetric)
- delete metric;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
-{
- // Every cover tree node will contain points up to EC^(scale + 1) away.
- return std::max(metric->Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point())) -
- furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
- const double distance) const
-{
- // We already have the distance as evaluated by the metric.
- return std::max(distance - furthestDescendantDistance -
- other->FurthestDescendantDistance(), 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const arma::vec& other) const
-{
- return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
- furthestDescendantDistance, 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const arma::vec& /* other */,
- const double distance) const
-{
- return std::max(distance - furthestDescendantDistance, 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
-{
- return metric->Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point())) +
- furthestDescendantDistance + other->FurthestDescendantDistance();
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
- const double distance) const
-{
- // We already have the distance as evaluated by the metric.
- return distance + furthestDescendantDistance +
- other->FurthestDescendantDistance();
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const arma::vec& other) const
-{
- return metric->Evaluate(dataset.unsafe_col(point), other) +
- furthestDescendantDistance;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const arma::vec& /* other */,
- const double distance) const
-{
- return distance + furthestDescendantDistance;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t pointSetSize)
-{
- // Sanity check; there is no guarantee that this condition will not be true.
- // ...or is there?
- if (pointSetSize <= 1)
- return 0;
-
- // We'll traverse from both left and right.
- size_t left = 0;
- size_t right = pointSetSize - 1;
-
- // A modification of quicksort, with the pivot value set to the bound.
- // Everything on the left of the pivot will be less than or equal to the
- // bound; everything on the right will be greater than the bound.
- while ((distances[left] <= bound) && (left != right))
- ++left;
- while ((distances[right] > bound) && (left != right))
- --right;
-
- while (left != right)
- {
- // Now swap the values and indices.
- const size_t tempPoint = indices[left];
- const double tempDist = distances[left];
-
- indices[left] = indices[right];
- distances[left] = distances[right];
-
- indices[right] = tempPoint;
- distances[right] = tempDist;
-
- // Traverse the left, seeing how many points are correctly on that side.
- // When we encounter an incorrect point, stop. We will switch it later.
- while ((distances[left] <= bound) && (left != right))
- ++left;
-
- // Traverse the right, seeing how many points are correctly on that side.
- // When we encounter an incorrect point, stop. We will switch it with the
- // wrong point from the left side.
- while ((distances[right] > bound) && (left != right))
- --right;
- }
-
- // The final left value is the index of the first far value.
- return left;
-}
-
-// Returns the maximum distance between points.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::ComputeDistances(
- const size_t pointIndex,
- const arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t pointSetSize)
-{
- // For each point, rebuild the distances. The indices do not need to be
- // modified.
- for (size_t i = 0; i < pointSetSize; ++i)
- {
- distances[i] = metric->Evaluate(dataset.unsafe_col(pointIndex),
- dataset.unsafe_col(indices[i]));
- }
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t childFarSetSize,
- const size_t childUsedSetSize,
- const size_t farSetSize)
-{
- // We'll use low-level memcpy calls ourselves, just to ensure it's done
- // quickly and the way we want it to be. Unfortunately this takes up more
- // memory than one-element swaps, but there's not a great way around that.
- const size_t bufferSize = std::min(farSetSize, childUsedSetSize);
- const size_t bigCopySize = std::max(farSetSize, childUsedSetSize);
-
- // Sanity check: there is no need to sort if the buffer size is going to be
- // zero.
- if (bufferSize == 0)
- return (childFarSetSize + farSetSize);
-
- size_t* indicesBuffer = new size_t[bufferSize];
- double* distancesBuffer = new double[bufferSize];
-
- // The start of the memory region to copy to the buffer.
- const size_t bufferFromLocation = ((bufferSize == farSetSize) ?
- (childFarSetSize + childUsedSetSize) : childFarSetSize);
- // The start of the memory region to move directly to the new place.
- const size_t directFromLocation = ((bufferSize == farSetSize) ?
- childFarSetSize : (childFarSetSize + childUsedSetSize));
- // The destination to copy the buffer back to.
- const size_t bufferToLocation = ((bufferSize == farSetSize) ?
- childFarSetSize : (childFarSetSize + farSetSize));
- // The destination of the directly moved memory region.
- const size_t directToLocation = ((bufferSize == farSetSize) ?
- (childFarSetSize + farSetSize) : childFarSetSize);
-
- // Copy the smaller piece to the buffer.
- memcpy(indicesBuffer, indices.memptr() + bufferFromLocation,
- sizeof(size_t) * bufferSize);
- memcpy(distancesBuffer, distances.memptr() + bufferFromLocation,
- sizeof(double) * bufferSize);
-
- // Now move the other memory.
- memmove(indices.memptr() + directToLocation,
- indices.memptr() + directFromLocation, sizeof(size_t) * bigCopySize);
- memmove(distances.memptr() + directToLocation,
- distances.memptr() + directFromLocation, sizeof(double) * bigCopySize);
-
- // Now copy the temporary memory to the right place.
- memcpy(indices.memptr() + bufferToLocation, indicesBuffer,
- sizeof(size_t) * bufferSize);
- memcpy(distances.memptr() + bufferToLocation, distancesBuffer,
- sizeof(double) * bufferSize);
-
- delete[] indicesBuffer;
- delete[] distancesBuffer;
-
- // This returns the complete size of the far set.
- return (childFarSetSize + farSetSize);
-}
-
-// Scan this node's near and far sets for points that the just-built child
-// consumed (listed in the used region of childIndices) and migrate each such
-// point into this node's used region, preserving the [ near | far | used ]
-// layout of indices/distances.  nearSetSize, farSetSize, and usedSetSize are
-// updated in place; their sum is invariant (checked by the final assert).
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
-    arma::Col<size_t>& indices,
-    arma::vec& distances,
-    size_t& nearSetSize,
-    size_t& farSetSize,
-    size_t& usedSetSize,
-    arma::Col<size_t>& childIndices,
-    const size_t childFarSetSize, // childNearSetSize is 0 in this case.
-    const size_t childUsedSetSize)
-{
-  const size_t originalSum = nearSetSize + farSetSize + usedSetSize;
-
-  // Loop across the set. We will swap points as we need. It should be noted
-  // that farSetSize and nearSetSize may change with each iteration of this loop
-  // (depending on if we make a swap or not).
-  size_t startChildUsedSet = 0; // Where to start in the child set.
-  for (size_t i = 0; i < nearSetSize; ++i)
-  {
-    // Discover if this point was in the child's used set.
-    for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
-    {
-      if (childIndices[childFarSetSize + j] == indices[i])
-      {
-        // We have found a point; a swap is necessary.
-
-        // Since this point is from the near set, to preserve the near set, we
-        // must do a swap.
-        if (farSetSize > 0)
-        {
-          if ((nearSetSize - 1) != i)
-          {
-            // In this case it must be a three-way swap.
-            size_t tempIndex = indices[nearSetSize + farSetSize - 1];
-            double tempDist = distances[nearSetSize + farSetSize - 1];
-
-            size_t tempNearIndex = indices[nearSetSize - 1];
-            double tempNearDist = distances[nearSetSize - 1];
-
-            indices[nearSetSize + farSetSize - 1] = indices[i];
-            distances[nearSetSize + farSetSize - 1] = distances[i];
-
-            indices[nearSetSize - 1] = tempIndex;
-            distances[nearSetSize - 1] = tempDist;
-
-            indices[i] = tempNearIndex;
-            distances[i] = tempNearDist;
-          }
-          else
-          {
-            // We can do a two-way swap.
-            size_t tempIndex = indices[nearSetSize + farSetSize - 1];
-            double tempDist = distances[nearSetSize + farSetSize - 1];
-
-            indices[nearSetSize + farSetSize - 1] = indices[i];
-            distances[nearSetSize + farSetSize - 1] = distances[i];
-
-            indices[i] = tempIndex;
-            distances[i] = tempDist;
-          }
-        }
-        else if ((nearSetSize - 1) != i)
-        {
-          // A two-way swap is possible.
-          size_t tempIndex = indices[nearSetSize + farSetSize - 1];
-          double tempDist = distances[nearSetSize + farSetSize - 1];
-
-          indices[nearSetSize + farSetSize - 1] = indices[i];
-          distances[nearSetSize + farSetSize - 1] = distances[i];
-
-          indices[i] = tempIndex;
-          distances[i] = tempDist;
-        }
-        else
-        {
-          // No swap is necessary.
-        }
-
-        // We don't need to do a complete preservation of the child index set,
-        // but we want to make sure we only loop over points we haven't seen.
-        // So increment the child counter by 1 and move a point if we need.
-        if (j != startChildUsedSet)
-        {
-          childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
-              startChildUsedSet];
-        }
-
-        // Update all counters from the swaps we have done.
-        ++startChildUsedSet;
-        --nearSetSize;
-        // NOTE: when i is 0 the decrement below wraps (size_t); the loop's
-        // ++i wraps it back to 0, so element 0 is correctly re-examined.
-        --i; // Since we moved a point out of the near set we must step back.
-
-        break; // Break out of this for loop; back to the first one.
-      }
-    }
-  }
-
-  // Now loop over the far set. This loop is different because we only require
-  // a normal two-way swap instead of the three-way swap to preserve the near
-  // set / far set ordering.
-  for (size_t i = 0; i < farSetSize; ++i)
-  {
-    // Discover if this point was in the child's used set.
-    for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
-    {
-      if (childIndices[childFarSetSize + j] == indices[i + nearSetSize])
-      {
-        // We have found a point to swap.
-
-        // Perform the swap.
-        size_t tempIndex = indices[nearSetSize + farSetSize - 1];
-        double tempDist = distances[nearSetSize + farSetSize - 1];
-
-        indices[nearSetSize + farSetSize - 1] = indices[nearSetSize + i];
-        distances[nearSetSize + farSetSize - 1] = distances[nearSetSize + i];
-
-        indices[nearSetSize + i] = tempIndex;
-        distances[nearSetSize + i] = tempDist;
-
-        if (j != startChildUsedSet)
-        {
-          childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
-              startChildUsedSet];
-        }
-
-        // Update all counters from the swaps we have done.
-        ++startChildUsedSet;
-        --farSetSize;
-        // Step back so the point just swapped into slot i is examined next
-        // (same size_t wrap-around trick as in the near-set loop above).
-        --i;
-
-        break; // Break out of this for loop; back to the first one.
-      }
-    }
-  }
-
-  // Update used set size.
-  usedSetSize += childUsedSetSize;
-
-  Log::Assert(originalSum == (nearSetSize + farSetSize + usedSetSize));
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::PruneFarSet(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t nearSetSize,
- const size_t pointSetSize)
-{
- // What we are trying to do is remove any points greater than the bound from
- // the far set. We don't care what happens to those indices and distances...
- // so, we don't need to properly swap points -- just drop new ones in place.
- size_t left = nearSetSize;
- size_t right = pointSetSize - 1;
- while ((distances[left] <= bound) && (left != right))
- ++left;
- while ((distances[right] > bound) && (left != right))
- --right;
-
- while (left != right)
- {
- // We don't care what happens to the point which should be on the right.
- indices[left] = indices[right];
- distances[left] = distances[right];
- --right; // Since we aren't changing the right.
-
- // Advance to next location which needs to switch.
- while ((distances[left] <= bound) && (left != right))
- ++left;
- while ((distances[right] > bound) && (left != right))
- --right;
- }
-
- // The far set size is the left pointer, with the near set size subtracted
- // from it.
- return (left - nearSetSize);
-}
-
-/**
- * Returns a string representation of this object.
- */
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-std::string CoverTree<MetricType, RootPointPolicy, StatisticType>::ToString() const
-{
- std::ostringstream convert;
- convert << "CoverTree [" << this << "]" << std::endl;
- convert << "dataset: " << &dataset << std::endl;
- convert << "point: " << point << std::endl;
- convert << "scale: " << scale << std::endl;
- convert << "base: " << base << std::endl;
-// convert << "StatisticType: " << stat << std::endl;
- convert << "parent distance : " << parentDistance << std::endl;
- convert << "furthest child distance: " << furthestDescendantDistance;
- convert << std::endl;
- convert << "children:";
-
- if (IsLeaf() == false)
- {
- for (int i = 0; i < children.size(); i++)
- {
- convert << std::endl << mlpack::util::Indent(children.at(i)->ToString());
- }
- }
- return convert.str();
-}
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,854 @@
+/**
+ * @file cover_tree_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of CoverTree class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "cover_tree.hpp"
+
+#include <mlpack/core/util/string_util.hpp>
+#include <string>
+
+namespace mlpack {
+namespace tree {
+
+// Create the cover tree.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+ const arma::mat& dataset,
+ const double base,
+ MetricType* metric) :
+ dataset(dataset),
+ point(RootPointPolicy::ChooseRoot(dataset)),
+ scale(INT_MAX),
+ base(base),
+ parent(NULL),
+ parentDistance(0),
+ furthestDescendantDistance(0),
+ localMetric(metric == NULL),
+ metric(metric)
+{
+ // If we need to create a metric, do that. We'll just do it on the heap.
+ if (localMetric)
+ this->metric = new MetricType();
+
+ // If there is only one point in the dataset... uh, we're done.
+ if (dataset.n_cols == 1)
+ return;
+
+ // Kick off the building. Create the indices array and the distances array.
+ arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
+ dataset.n_cols - 1, dataset.n_cols - 1);
+ // This is now [1 2 3 4 ... n]. We must be sure that our point does not
+ // occur.
+ if (point != 0)
+ indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
+
+ arma::vec distances(dataset.n_cols - 1);
+
+ // Build the initial distances.
+ ComputeDistances(point, indices, distances, dataset.n_cols - 1);
+
+ // Create the children.
+ size_t farSetSize = 0;
+ size_t usedSetSize = 0;
+ CreateChildren(indices, distances, dataset.n_cols - 1, farSetSize,
+ usedSetSize);
+
+ // Use the furthest descendant distance to determine the scale of the root
+ // node.
+ scale = (int) ceil(log(furthestDescendantDistance) / log(base));
+
+ // Initialize statistic.
+ stat = StatisticType(*this);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+ const arma::mat& dataset,
+ MetricType& metric,
+ const double base) :
+ dataset(dataset),
+ point(RootPointPolicy::ChooseRoot(dataset)),
+ scale(INT_MAX),
+ base(base),
+ parent(NULL),
+ parentDistance(0),
+ furthestDescendantDistance(0),
+ localMetric(false),
+ metric(&metric)
+{
+ // If there is only one point in the dataset, uh, we're done.
+ if (dataset.n_cols == 1)
+ return;
+
+ // Kick off the building. Create the indices array and the distances array.
+ arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
+ dataset.n_cols - 1, dataset.n_cols - 1);
+ // This is now [1 2 3 4 ... n]. We must be sure that our point does not
+ // occur.
+ if (point != 0)
+ indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
+
+ arma::vec distances(dataset.n_cols - 1);
+
+ // Build the initial distances.
+ ComputeDistances(point, indices, distances, dataset.n_cols - 1);
+
+ // Create the children.
+ size_t farSetSize = 0;
+ size_t usedSetSize = 0;
+ CreateChildren(indices, distances, dataset.n_cols - 1, farSetSize,
+ usedSetSize);
+
+ // Use the furthest descendant distance to determine the scale of the root
+ // node.
+ scale = (int) ceil(log(furthestDescendantDistance) / log(base));
+
+ // Initialize statistic.
+ stat = StatisticType(*this);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+ const arma::mat& dataset,
+ const double base,
+ const size_t pointIndex,
+ const int scale,
+ CoverTree* parent,
+ const double parentDistance,
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ MetricType& metric) :
+ dataset(dataset),
+ point(pointIndex),
+ scale(scale),
+ base(base),
+ parent(parent),
+ parentDistance(parentDistance),
+ furthestDescendantDistance(0),
+ localMetric(false),
+ metric(&metric)
+{
+ // If the size of the near set is 0, this is a leaf.
+ if (nearSetSize == 0)
+ {
+ this->scale = INT_MIN;
+ stat = StatisticType(*this);
+ return;
+ }
+
+ // Otherwise, create the children.
+ CreateChildren(indices, distances, nearSetSize, farSetSize, usedSetSize);
+
+ // Initialize statistic.
+ stat = StatisticType(*this);
+}
+
+// Create the children of this node from the current point sets.  On entry
+// indices/distances hold [ near | far | used ] relative to this node's point;
+// on return all children are built, any implicit single-child nodes have been
+// collapsed away, and farSetSize/usedSetSize reflect the points consumed by
+// this subtree.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+inline void
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CreateChildren(
+    arma::Col<size_t>& indices,
+    arma::vec& distances,
+    size_t nearSetSize,
+    size_t& farSetSize,
+    size_t& usedSetSize)
+{
+  // Determine the next scale level. This should be the first level where there
+  // are any points in the far set. So, if we know the maximum distance in the
+  // distances array, this will be the largest i such that
+  // maxDistance > pow(base, i)
+  // and using this for the scale factor should guarantee we are not creating an
+  // implicit node. If the maximum distance is 0, every point in the near set
+  // will be created as a leaf, and a child to this node. We also do not need
+  // to change the furthestChildDistance or furthestDescendantDistance.
+  // (The unqualified max() here presumably resolves to Armadillo's max() for
+  // the subvector — verify against the includes.)
+  const double maxDistance = max(distances.rows(0,
+      nearSetSize + farSetSize - 1));
+  if (maxDistance == 0)
+  {
+    // Make the self child at the lowest possible level.
+    // This should not modify farSetSize or usedSetSize.
+    size_t tempSize = 0;
+    children.push_back(new CoverTree(dataset, base, point, INT_MIN, this, 0,
+        indices, distances, 0, tempSize, usedSetSize, *metric));
+
+    // Every point in the near set should be a leaf.
+    for (size_t i = 0; i < nearSetSize; ++i)
+    {
+      // farSetSize and usedSetSize will not be modified.
+      children.push_back(new CoverTree(dataset, base, indices[i],
+          INT_MIN, this, distances[i], indices, distances, 0, tempSize,
+          usedSetSize, *metric));
+      usedSetSize++;
+    }
+
+    // Re-sort the dataset. We have
+    // [ used | far | other used ]
+    // and we want
+    // [ far | all used ].
+    SortPointSet(indices, distances, 0, usedSetSize, farSetSize);
+
+    // Initialize the statistic.
+    stat = StatisticType(*this);
+
+    return;
+  }
+
+  const int nextScale = std::min(scale,
+      (int) ceil(log(maxDistance) / log(base))) - 1;
+  const double bound = pow(base, nextScale);
+
+  // First, make the self child. We must split the given near set into the near
+  // set and far set for the self child.
+  size_t childNearSetSize =
+      SplitNearFar(indices, distances, bound, nearSetSize);
+
+  // Build the self child (recursively).
+  size_t childFarSetSize = nearSetSize - childNearSetSize;
+  size_t childUsedSetSize = 0;
+  children.push_back(new CoverTree(dataset, base, point, nextScale, this, 0,
+      indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize,
+      *metric));
+
+  // The self-child can't modify the furthestChildDistance away from 0, but it
+  // can modify the furthestDescendantDistance.
+  furthestDescendantDistance = children[0]->FurthestDescendantDistance();
+
+  // If we created an implicit node, take its self-child instead (this could
+  // happen multiple times).
+  while (children[children.size() - 1]->NumChildren() == 1)
+  {
+    CoverTree* old = children[children.size() - 1];
+    children.erase(children.begin() + children.size() - 1);
+
+    // Now take its child.
+    children.push_back(&(old->Child(0)));
+
+    // Set its parent correctly.
+    old->Child(0).Parent() = this;
+
+    // Remove its child (so it doesn't delete it).
+    old->Children().erase(old->Children().begin() + old->Children().size() - 1);
+
+    // Now delete it.
+    delete old;
+  }
+
+  // Now the arrays, in memory, look like this:
+  // [ childFar | childUsed | far | used ]
+  // but we need to move the used points past our far set:
+  // [ childFar | far | childUsed + used ]
+  // and keeping in mind that childFar = our near set,
+  // [ near | far | childUsed + used ]
+  // is what we are trying to make.
+  SortPointSet(indices, distances, childFarSetSize, childUsedSetSize,
+      farSetSize);
+
+  // Update size of near set and used set.
+  nearSetSize -= childUsedSetSize;
+  usedSetSize += childUsedSetSize;
+
+  // Now for each point in the near set, we need to make children. To save
+  // computation later, we'll create an array holding the points in the near
+  // set, and then after each run we'll check which of those (if any) were used
+  // and we will remove them. ...if that's faster. I think it is.
+  while (nearSetSize > 0)
+  {
+    size_t newPointIndex = nearSetSize - 1;
+
+    // Swap to front if necessary.
+    if (newPointIndex != 0)
+    {
+      const size_t tempIndex = indices[newPointIndex];
+      const double tempDist = distances[newPointIndex];
+
+      indices[newPointIndex] = indices[0];
+      distances[newPointIndex] = distances[0];
+
+      indices[0] = tempIndex;
+      distances[0] = tempDist;
+    }
+
+    // Will this be a new furthest child?
+    if (distances[0] > furthestDescendantDistance)
+      furthestDescendantDistance = distances[0];
+
+    // If there's only one point left, we don't need this crap.
+    if ((nearSetSize == 1) && (farSetSize == 0))
+    {
+      size_t childNearSetSize = 0;
+      children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
+          this, distances[0], indices, distances, childNearSetSize, farSetSize,
+          usedSetSize, *metric));
+
+      // Because the far set size is 0, we don't have to do any swapping to
+      // move the point into the used set.
+      ++usedSetSize;
+      --nearSetSize;
+
+      // And we're done.
+      break;
+    }
+
+    // Create the near and far set indices and distance vectors. We don't fill
+    // in the self-point, yet.
+    arma::Col<size_t> childIndices(nearSetSize + farSetSize);
+    childIndices.rows(0, (nearSetSize + farSetSize - 2)) = indices.rows(1,
+        nearSetSize + farSetSize - 1);
+    arma::vec childDistances(nearSetSize + farSetSize);
+
+    // Build distances for the child.
+    ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
+        + farSetSize - 1);
+
+    // Split into near and far sets for this point.
+    childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
+        nearSetSize + farSetSize - 1);
+    childFarSetSize = PruneFarSet(childIndices, childDistances,
+        base * bound, childNearSetSize,
+        (nearSetSize + farSetSize - 1));
+
+    // Now that we know the near and far set sizes, we can put the used point
+    // (the self point) in the correct place; now, when we call
+    // MoveToUsedSet(), it will move the self-point correctly. The distance
+    // does not matter.
+    childIndices(childNearSetSize + childFarSetSize) = indices[0];
+    childDistances(childNearSetSize + childFarSetSize) = 0;
+
+    // Build this child (recursively).
+    childUsedSetSize = 1; // Mark self point as used.
+    children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
+        this, distances[0], childIndices, childDistances, childNearSetSize,
+        childFarSetSize, childUsedSetSize, *metric));
+
+    // If we created an implicit node, take its self-child instead (this could
+    // happen multiple times).
+    while (children[children.size() - 1]->NumChildren() == 1)
+    {
+      CoverTree* old = children[children.size() - 1];
+      children.erase(children.begin() + children.size() - 1);
+
+      // Now take its child.
+      children.push_back(&(old->Child(0)));
+
+      // Set its parent correctly.
+      old->Child(0).Parent() = this;
+
+      // Remove its child (so it doesn't delete it).
+      old->Children().erase(old->Children().begin() + old->Children().size()
+          - 1);
+
+      // Now delete it.
+      delete old;
+    }
+
+    // Now with the child created, it returns the childIndices and
+    // childDistances vectors in this form:
+    // [ childFar | childUsed ]
+    // For each point in the childUsed set, we must move that point to the used
+    // set in our own vector.
+    MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
+        childIndices, childFarSetSize, childUsedSetSize);
+  }
+
+  // Calculate furthest descendant.
+  for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize +
+      usedSetSize); ++i)
+    if (distances[i] > furthestDescendantDistance)
+      furthestDescendantDistance = distances[i];
+}
+
+// Manually create a cover tree node.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+ const arma::mat& dataset,
+ const double base,
+ const size_t pointIndex,
+ const int scale,
+ CoverTree* parent,
+ const double parentDistance,
+ const double furthestDescendantDistance,
+ MetricType* metric) :
+ dataset(dataset),
+ point(pointIndex),
+ scale(scale),
+ base(base),
+ parent(parent),
+ parentDistance(parentDistance),
+ furthestDescendantDistance(furthestDescendantDistance),
+ localMetric(metric == NULL),
+ metric(metric)
+{
+ // If necessary, create a local metric.
+ if (localMetric)
+ this->metric = new MetricType();
+
+ // Initialize the statistic.
+ stat = StatisticType(*this);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+ const CoverTree& other) :
+ dataset(other.dataset),
+ point(other.point),
+ scale(other.scale),
+ base(other.base),
+ stat(other.stat),
+ parent(other.parent),
+ parentDistance(other.parentDistance),
+ furthestDescendantDistance(other.furthestDescendantDistance),
+ localMetric(false),
+ metric(other.metric)
+{
+ // Copy each child by hand.
+ for (size_t i = 0; i < other.NumChildren(); ++i)
+ {
+ children.push_back(new CoverTree(other.Child(i)));
+ children[i]->Parent() = this;
+ }
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
+{
+ // Delete each child.
+ for (size_t i = 0; i < children.size(); ++i)
+ delete children[i];
+
+ // Delete the local metric, if necessary.
+ if (localMetric)
+ delete metric;
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+{
+ // Every cover tree node will contain points up to EC^(scale + 1) away.
+ return std::max(metric->Evaluate(dataset.unsafe_col(point),
+ other->Dataset().unsafe_col(other->Point())) -
+ furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
+ const double distance) const
+{
+ // We already have the distance as evaluated by the metric.
+ return std::max(distance - furthestDescendantDistance -
+ other->FurthestDescendantDistance(), 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const arma::vec& other) const
+{
+ return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
+ furthestDescendantDistance, 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const arma::vec& /* other */,
+ const double distance) const
+{
+ return std::max(distance - furthestDescendantDistance, 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+{
+ return metric->Evaluate(dataset.unsafe_col(point),
+ other->Dataset().unsafe_col(other->Point())) +
+ furthestDescendantDistance + other->FurthestDescendantDistance();
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
+ const double distance) const
+{
+ // We already have the distance as evaluated by the metric.
+ return distance + furthestDescendantDistance +
+ other->FurthestDescendantDistance();
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const arma::vec& other) const
+{
+ return metric->Evaluate(dataset.unsafe_col(point), other) +
+ furthestDescendantDistance;
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const arma::vec& /* other */,
+ const double distance) const
+{
+ return distance + furthestDescendantDistance;
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t pointSetSize)
+{
+ // Sanity check; there is no guarantee that this condition will not be true.
+ // ...or is there?
+ if (pointSetSize <= 1)
+ return 0;
+
+ // We'll traverse from both left and right.
+ size_t left = 0;
+ size_t right = pointSetSize - 1;
+
+ // A modification of quicksort, with the pivot value set to the bound.
+ // Everything on the left of the pivot will be less than or equal to the
+ // bound; everything on the right will be greater than the bound.
+ while ((distances[left] <= bound) && (left != right))
+ ++left;
+ while ((distances[right] > bound) && (left != right))
+ --right;
+
+ while (left != right)
+ {
+ // Now swap the values and indices.
+ const size_t tempPoint = indices[left];
+ const double tempDist = distances[left];
+
+ indices[left] = indices[right];
+ distances[left] = distances[right];
+
+ indices[right] = tempPoint;
+ distances[right] = tempDist;
+
+ // Traverse the left, seeing how many points are correctly on that side.
+ // When we encounter an incorrect point, stop. We will switch it later.
+ while ((distances[left] <= bound) && (left != right))
+ ++left;
+
+ // Traverse the right, seeing how many points are correctly on that side.
+ // When we encounter an incorrect point, stop. We will switch it with the
+ // wrong point from the left side.
+ while ((distances[right] > bound) && (left != right))
+ --right;
+ }
+
+ // The final left value is the index of the first far value.
+ return left;
+}
+
+// Returns the maximum distance between points.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::ComputeDistances(
+ const size_t pointIndex,
+ const arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t pointSetSize)
+{
+ // For each point, rebuild the distances. The indices do not need to be
+ // modified.
+ for (size_t i = 0; i < pointSetSize; ++i)
+ {
+ distances[i] = metric->Evaluate(dataset.unsafe_col(pointIndex),
+ dataset.unsafe_col(indices[i]));
+ }
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t childFarSetSize,
+ const size_t childUsedSetSize,
+ const size_t farSetSize)
+{
+ // We'll use low-level memcpy calls ourselves, just to ensure it's done
+ // quickly and the way we want it to be. Unfortunately this takes up more
+ // memory than one-element swaps, but there's not a great way around that.
+ const size_t bufferSize = std::min(farSetSize, childUsedSetSize);
+ const size_t bigCopySize = std::max(farSetSize, childUsedSetSize);
+
+ // Sanity check: there is no need to sort if the buffer size is going to be
+ // zero.
+ if (bufferSize == 0)
+ return (childFarSetSize + farSetSize);
+
+ size_t* indicesBuffer = new size_t[bufferSize];
+ double* distancesBuffer = new double[bufferSize];
+
+ // The start of the memory region to copy to the buffer.
+ const size_t bufferFromLocation = ((bufferSize == farSetSize) ?
+ (childFarSetSize + childUsedSetSize) : childFarSetSize);
+ // The start of the memory region to move directly to the new place.
+ const size_t directFromLocation = ((bufferSize == farSetSize) ?
+ childFarSetSize : (childFarSetSize + childUsedSetSize));
+ // The destination to copy the buffer back to.
+ const size_t bufferToLocation = ((bufferSize == farSetSize) ?
+ childFarSetSize : (childFarSetSize + farSetSize));
+ // The destination of the directly moved memory region.
+ const size_t directToLocation = ((bufferSize == farSetSize) ?
+ (childFarSetSize + farSetSize) : childFarSetSize);
+
+ // Copy the smaller piece to the buffer.
+ memcpy(indicesBuffer, indices.memptr() + bufferFromLocation,
+ sizeof(size_t) * bufferSize);
+ memcpy(distancesBuffer, distances.memptr() + bufferFromLocation,
+ sizeof(double) * bufferSize);
+
+ // Now move the other memory.
+ memmove(indices.memptr() + directToLocation,
+ indices.memptr() + directFromLocation, sizeof(size_t) * bigCopySize);
+ memmove(distances.memptr() + directToLocation,
+ distances.memptr() + directFromLocation, sizeof(double) * bigCopySize);
+
+ // Now copy the temporary memory to the right place.
+ memcpy(indices.memptr() + bufferToLocation, indicesBuffer,
+ sizeof(size_t) * bufferSize);
+ memcpy(distances.memptr() + bufferToLocation, distancesBuffer,
+ sizeof(double) * bufferSize);
+
+ delete[] indicesBuffer;
+ delete[] distancesBuffer;
+
+ // This returns the complete size of the far set.
+ return (childFarSetSize + farSetSize);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t& nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ arma::Col<size_t>& childIndices,
+ const size_t childFarSetSize, // childNearSetSize is 0 in this case.
+ const size_t childUsedSetSize)
+{
+ const size_t originalSum = nearSetSize + farSetSize + usedSetSize;
+
+ // Loop across the set. We will swap points as we need. It should be noted
+ // that farSetSize and nearSetSize may change with each iteration of this loop
+ // (depending on if we make a swap or not).
+ size_t startChildUsedSet = 0; // Where to start in the child set.
+ for (size_t i = 0; i < nearSetSize; ++i)
+ {
+ // Discover if this point was in the child's used set.
+ for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
+ {
+ if (childIndices[childFarSetSize + j] == indices[i])
+ {
+ // We have found a point; a swap is necessary.
+
+ // Since this point is from the near set, to preserve the near set, we
+ // must do a swap.
+ if (farSetSize > 0)
+ {
+ if ((nearSetSize - 1) != i)
+ {
+ // In this case it must be a three-way swap.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ size_t tempNearIndex = indices[nearSetSize - 1];
+ double tempNearDist = distances[nearSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[i];
+ distances[nearSetSize + farSetSize - 1] = distances[i];
+
+ indices[nearSetSize - 1] = tempIndex;
+ distances[nearSetSize - 1] = tempDist;
+
+ indices[i] = tempNearIndex;
+ distances[i] = tempNearDist;
+ }
+ else
+ {
+ // We can do a two-way swap.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[i];
+ distances[nearSetSize + farSetSize - 1] = distances[i];
+
+ indices[i] = tempIndex;
+ distances[i] = tempDist;
+ }
+ }
+ else if ((nearSetSize - 1) != i)
+ {
+ // A two-way swap is possible.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[i];
+ distances[nearSetSize + farSetSize - 1] = distances[i];
+
+ indices[i] = tempIndex;
+ distances[i] = tempDist;
+ }
+ else
+ {
+ // No swap is necessary.
+ }
+
+ // We don't need to do a complete preservation of the child index set,
+ // but we want to make sure we only loop over points we haven't seen.
+ // So increment the child counter by 1 and move a point if we need.
+ if (j != startChildUsedSet)
+ {
+ childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
+ startChildUsedSet];
+ }
+
+ // Update all counters from the swaps we have done.
+ ++startChildUsedSet;
+ --nearSetSize;
+ --i; // Since we moved a point out of the near set we must step back.
+
+ break; // Break out of this for loop; back to the first one.
+ }
+ }
+ }
+
+ // Now loop over the far set. This loop is different because we only require
+ // a normal two-way swap instead of the three-way swap to preserve the near
+ // set / far set ordering.
+ for (size_t i = 0; i < farSetSize; ++i)
+ {
+ // Discover if this point was in the child's used set.
+ for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
+ {
+ if (childIndices[childFarSetSize + j] == indices[i + nearSetSize])
+ {
+ // We have found a point to swap.
+
+ // Perform the swap.
+ size_t tempIndex = indices[nearSetSize + farSetSize - 1];
+ double tempDist = distances[nearSetSize + farSetSize - 1];
+
+ indices[nearSetSize + farSetSize - 1] = indices[nearSetSize + i];
+ distances[nearSetSize + farSetSize - 1] = distances[nearSetSize + i];
+
+ indices[nearSetSize + i] = tempIndex;
+ distances[nearSetSize + i] = tempDist;
+
+ if (j != startChildUsedSet)
+ {
+ childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
+ startChildUsedSet];
+ }
+
+ // Update all counters from the swaps we have done.
+ ++startChildUsedSet;
+ --farSetSize;
+ --i;
+
+ break; // Break out of this for loop; back to the first one.
+ }
+ }
+ }
+
+ // Update used set size.
+ usedSetSize += childUsedSetSize;
+
+ Log::Assert(originalSum == (nearSetSize + farSetSize + usedSetSize));
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::PruneFarSet(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t nearSetSize,
+ const size_t pointSetSize)
+{
+ // What we are trying to do is remove any points greater than the bound from
+ // the far set. We don't care what happens to those indices and distances...
+ // so, we don't need to properly swap points -- just drop new ones in place.
+ size_t left = nearSetSize;
+ size_t right = pointSetSize - 1;
+ while ((distances[left] <= bound) && (left != right))
+ ++left;
+ while ((distances[right] > bound) && (left != right))
+ --right;
+
+ while (left != right)
+ {
+ // We don't care what happens to the point which should be on the right.
+ indices[left] = indices[right];
+ distances[left] = distances[right];
+ --right; // Since we aren't changing the right.
+
+ // Advance to next location which needs to switch.
+ while ((distances[left] <= bound) && (left != right))
+ ++left;
+ while ((distances[right] > bound) && (left != right))
+ --right;
+ }
+
+ // The far set size is the left pointer, with the near set size subtracted
+ // from it.
+ return (left - nearSetSize);
+}
+
+/**
+ * Returns a string representation of this object.
+ */
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+std::string CoverTree<MetricType, RootPointPolicy, StatisticType>::ToString() const
+{
+ std::ostringstream convert;
+ convert << "CoverTree [" << this << "]" << std::endl;
+ convert << "dataset: " << &dataset << std::endl;
+ convert << "point: " << point << std::endl;
+ convert << "scale: " << scale << std::endl;
+ convert << "base: " << base << std::endl;
+// convert << "StatisticType: " << stat << std::endl;
+ convert << "parent distance : " << parentDistance << std::endl;
+ convert << "furthest child distance: " << furthestDescendantDistance;
+ convert << std::endl;
+ convert << "children:";
+
+ if (IsLeaf() == false)
+ {
+ for (int i = 0; i < children.size(); i++)
+ {
+ convert << std::endl << mlpack::util::Indent(children.at(i)->ToString());
+ }
+ }
+ return convert.str();
+}
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,99 +0,0 @@
-/**
- * @file dual_tree_traverser.hpp
- * @author Ryan Curtin
- *
- * A dual-tree traverser for the cover tree.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_HPP
-
-#include <mlpack/core.hpp>
-#include <queue>
-
-namespace mlpack {
-namespace tree {
-
-//! Forward declaration of struct to be used for traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-struct DualCoverTreeMapEntry;
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
-{
- public:
- /**
- * Initialize the dual tree traverser with the given rule type.
- */
- DualTreeTraverser(RuleType& rule);
-
- /**
- * Traverse the two specified trees.
- *
- * @param queryNode Root of query tree.
- * @param referenceNode Root of reference tree.
- */
- void Traverse(CoverTree& queryNode, CoverTree& referenceNode);
-
- /**
- * Helper function for traversal of the two trees.
- */
- void Traverse(CoverTree& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >&
- referenceMap);
-
- //! Get the number of pruned nodes.
- size_t NumPrunes() const { return numPrunes; }
- //! Modify the number of pruned nodes.
- size_t& NumPrunes() { return numPrunes; }
-
- private:
- //! The instantiated rule set for pruning branches.
- RuleType& rule;
-
- //! The number of pruned nodes.
- size_t numPrunes;
-
- //! Prepare map for recursion.
- void PruneMap(CoverTree& queryNode,
- CoverTree& candidateQueryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >&
- referenceMap,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >& childMap);
-
- void PruneMapForSelfChild(CoverTree& candidateQueryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >&
- referenceMap);
-
- void ReferenceRecursion(CoverTree& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >&
- referenceMap);
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-// Include implementation.
-#include "dual_tree_traverser_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,99 @@
+/**
+ * @file dual_tree_traverser.hpp
+ * @author Ryan Curtin
+ *
+ * A dual-tree traverser for the cover tree.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+#include <queue>
+
+namespace mlpack {
+namespace tree {
+
+//! Forward declaration of struct to be used for traversal.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+struct DualCoverTreeMapEntry;
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
+{
+ public:
+ /**
+ * Initialize the dual tree traverser with the given rule type.
+ */
+ DualTreeTraverser(RuleType& rule);
+
+ /**
+ * Traverse the two specified trees.
+ *
+ * @param queryNode Root of query tree.
+ * @param referenceNode Root of reference tree.
+ */
+ void Traverse(CoverTree& queryNode, CoverTree& referenceNode);
+
+ /**
+ * Helper function for traversal of the two trees.
+ */
+ void Traverse(CoverTree& queryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<
+ MetricType, RootPointPolicy, StatisticType> > >&
+ referenceMap);
+
+ //! Get the number of pruned nodes.
+ size_t NumPrunes() const { return numPrunes; }
+ //! Modify the number of pruned nodes.
+ size_t& NumPrunes() { return numPrunes; }
+
+ private:
+ //! The instantiated rule set for pruning branches.
+ RuleType& rule;
+
+ //! The number of pruned nodes.
+ size_t numPrunes;
+
+ //! Prepare map for recursion.
+ void PruneMap(CoverTree& queryNode,
+ CoverTree& candidateQueryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<
+ MetricType, RootPointPolicy, StatisticType> > >&
+ referenceMap,
+ std::map<int, std::vector<DualCoverTreeMapEntry<
+ MetricType, RootPointPolicy, StatisticType> > >& childMap);
+
+ void PruneMapForSelfChild(CoverTree& candidateQueryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<
+ MetricType, RootPointPolicy, StatisticType> > >&
+ referenceMap);
+
+ void ReferenceRecursion(CoverTree& queryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<
+ MetricType, RootPointPolicy, StatisticType> > >&
+ referenceMap);
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "dual_tree_traverser_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,542 +0,0 @@
-/**
- * @file dual_tree_traverser_impl.hpp
- * @author Ryan Curtin
- *
- * A dual-tree traverser for the cover tree.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
-
-#include <mlpack/core.hpp>
-#include <queue>
-
-namespace mlpack {
-namespace tree {
-
-//! The object placed in the map for tree traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-struct DualCoverTreeMapEntry
-{
- //! The node this entry refers to.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
- //! The score of the node.
- double score;
- //! The index of the reference node used for the base case evaluation.
- size_t referenceIndex;
- //! The index of the query node used for the base case evaluation.
- size_t queryIndex;
- //! The base case evaluation.
- double baseCase;
-
- //! Comparison operator, for sorting within the map.
- bool operator<(const DualCoverTreeMapEntry& other) const
- {
- return (score < other.score);
- }
-};
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
- rule(rule),
- numPrunes(0)
-{ /* Nothing to do. */ }
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::Traverse(
- CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
- CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
-{
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- // Start by creating a map and adding the reference node to it.
- std::map<int, std::vector<MapEntryType> > refMap;
-
- MapEntryType rootRefEntry;
-
- rootRefEntry.referenceNode = &referenceNode;
- rootRefEntry.score = 0.0; // Must recurse into.
- rootRefEntry.referenceIndex = referenceNode.Point();
- rootRefEntry.queryIndex = queryNode.Point();
- rootRefEntry.baseCase = rule.BaseCase(queryNode.Point(),
- referenceNode.Point());
-// rule.UpdateAfterRecursion(queryNode, referenceNode);
-
- refMap[referenceNode.Scale()].push_back(rootRefEntry);
-
- Traverse(queryNode, refMap);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::Traverse(
- CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
-{
-// Log::Debug << "Recursed into query node " << queryNode.Point() << ", scale "
-// << queryNode.Scale() << "\n";
-
- // Convenience typedef.
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- if (referenceMap.size() == 0)
- return; // Nothing to do!
-
- // First recurse down the reference nodes as necessary.
- ReferenceRecursion(queryNode, referenceMap);
-
- // Now, reduce the scale of the query node by recursing. But we can't recurse
- // if the query node is a leaf node.
- if ((queryNode.Scale() != INT_MIN) &&
- (queryNode.Scale() >= (*referenceMap.rbegin()).first))
- {
- // Recurse into the non-self-children first. The recursion order cannot
- // affect the runtime of the algorithm, because each query child recursion's
- // results are separate and independent.
- for (size_t i = 1; i < queryNode.NumChildren(); ++i)
- {
- std::map<int, std::vector<MapEntryType> > childMap;
- PruneMap(queryNode, queryNode.Child(i), referenceMap, childMap);
-
-// Log::Debug << "Recurse into query child " << i << ": " <<
-// queryNode.Child(i).Point() << " scale " << queryNode.Child(i).Scale()
-// << "; this parent is " << queryNode.Point() << " scale " <<
-// queryNode.Scale() << std::endl;
- Traverse(queryNode.Child(i), childMap);
- }
-
- PruneMapForSelfChild(queryNode.Child(0), referenceMap);
-
- // Now we can use the existing map (without a copy) for the self-child.
-// Log::Warn << "Recurse into query self-child " << queryNode.Child(0).Point()
-// << " scale " << queryNode.Child(0).Scale() << "; this parent is "
-// << queryNode.Point() << " scale " << queryNode.Scale() << std::endl;
- Traverse(queryNode.Child(0), referenceMap);
- }
-
- if (queryNode.Scale() != INT_MIN)
- return; // No need to evaluate base cases at this level. It's all done.
-
- // If we have made it this far, all we have is a bunch of base case
- // evaluations to do.
- Log::Assert((*referenceMap.begin()).first == INT_MIN);
- Log::Assert(queryNode.Scale() == INT_MIN);
- std::vector<MapEntryType>& pointVector = (*referenceMap.begin()).second;
-// Log::Debug << "Onto base case evaluations\n";
-
- for (size_t i = 0; i < pointVector.size(); ++i)
- {
- // Get a reference to the frame.
- const MapEntryType& frame = pointVector[i];
-
- CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
- frame.referenceNode;
- const double oldScore = frame.score;
- const size_t refIndex = frame.referenceIndex;
- const size_t queryIndex = frame.queryIndex;
-// Log::Debug << "Consider query " << queryNode.Point() << ", reference "
-// << refNode->Point() << "\n";
-// Log::Debug << "Old score " << oldScore << " with refParent " << refIndex
-// << " and parent query node " << queryIndex << "\n";
-
- // First, ensure that we have not already calculated the base case.
- if ((refIndex == refNode->Point()) && (queryIndex == queryNode.Point()))
- {
-// Log::Debug << "Pruned because we already did the base case and its value "
-// << " was " << frame.baseCase << std::endl;
- ++numPrunes;
- continue;
- }
-
- // Now, check if we can prune it.
- const double rescore = rule.Rescore(queryNode, *refNode, oldScore);
-
- if (rescore == DBL_MAX)
- {
-// Log::Debug << "Pruned after rescoring\n";
- ++numPrunes;
- continue;
- }
-
- // If not, compute the base case.
-// Log::Debug << "Not pruned, performing base case " << queryNode.Point() <<
-// " " << pointVector[i].referenceNode->Point() << "\n";
- rule.BaseCase(queryNode.Point(), pointVector[i].referenceNode->Point());
- }
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::PruneMap(
- CoverTree& /* queryNode */,
- CoverTree& candidateQueryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
- RootPointPolicy, StatisticType> > >& referenceMap,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
- RootPointPolicy, StatisticType> > >& childMap)
-{
-// Log::Debug << "Prep for recurse into query child point " <<
-// candidateQueryNode.Point() << " scale " <<
-// candidateQueryNode.Scale() << std::endl;
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
- referenceMap.rbegin();
-
- while ((it != referenceMap.rend()) && ((*it).first != INT_MIN))
- {
- // Get a reference to the vector representing the entries at this scale.
- const std::vector<MapEntryType>& scaleVector = (*it).second;
-
- std::vector<MapEntryType>& newScaleVector = childMap[(*it).first];
- newScaleVector.reserve(scaleVector.size()); // Maximum possible size.
-
- // Loop over each entry in the vector.
- for (size_t j = 0; j < scaleVector.size(); ++j)
- {
- const MapEntryType& frame = scaleVector[j];
-
- // First evaluate if we can prune without performing the base case.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
- frame.referenceNode;
- const double oldScore = frame.score;
-
- // Try to prune based on shell(). This is hackish and will need to be
- // refined or cleaned at some point.
-// double score = rule.PrescoreQ(queryNode, candidateQueryNode, *refNode,
-// frame.baseCase);
-
-// if (score == DBL_MAX)
-// {
-// ++numPrunes;
-// continue;
-// }
-
-// Log::Debug << "Recheck reference node " << refNode->Point() <<
-// " scale " << refNode->Scale() << " which has old score " <<
-// oldScore << " with old reference index " << frame.referenceIndex
-// << " and old query index " << frame.queryIndex << std::endl;
-
- double score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
-
-// Log::Debug << "Rescored as " << score << std::endl;
-
- if (score == DBL_MAX)
- {
- // Pruned. Move on.
- ++numPrunes;
- continue;
- }
-
- // Evaluate base case.
-// Log::Debug << "Must evaluate base case "
-// << candidateQueryNode.Point() << " " << refNode->Point()
-// << "\n";
- double baseCase = rule.BaseCase(candidateQueryNode.Point(),
- refNode->Point());
-// Log::Debug << "Base case was " << baseCase << std::endl;
-
- score = rule.Score(candidateQueryNode, *refNode, baseCase);
-
- if (score == DBL_MAX)
- {
- // Pruned. Move on.
- ++numPrunes;
- continue;
- }
-
- // Add to child map.
- newScaleVector.push_back(frame);
- newScaleVector.back().score = score;
- newScaleVector.back().baseCase = baseCase;
- newScaleVector.back().referenceIndex = refNode->Point();
- newScaleVector.back().queryIndex = candidateQueryNode.Point();
- }
-
- // If we didn't add anything, then strike this vector from the map.
- if (newScaleVector.size() == 0)
- childMap.erase((*it).first);
-
- ++it; // Advance to next scale.
- }
-
- childMap[INT_MIN] = referenceMap[INT_MIN];
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::PruneMapForSelfChild(
- CoverTree& candidateQueryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
-{
-// Log::Debug << "Prep for recurse into query self-child point " <<
-// candidateQueryNode.Point() << " scale " <<
-// candidateQueryNode.Scale() << std::endl;
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- // Create the child reference map. We will do this by recursing through
- // every entry in the reference map and evaluating (or pruning) it. But
- // in this setting we do not recurse into any children of the reference
- // entries.
- typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
- referenceMap.rbegin();
-
- while (it != referenceMap.rend() && (*it).first != INT_MIN)
- {
- // Get a reference to the vector representing the entries at this scale.
- std::vector<MapEntryType>& newScaleVector = (*it).second;
- const std::vector<MapEntryType> scaleVector = newScaleVector;
-
- newScaleVector.clear();
- newScaleVector.reserve(scaleVector.size());
-
- // Loop over each entry in the vector.
- for (size_t j = 0; j < scaleVector.size(); ++j)
- {
- const MapEntryType& frame = scaleVector[j];
-
- // First evaluate if we can prune without performing the base case.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
- frame.referenceNode;
- const double oldScore = frame.score;
- double baseCase = frame.baseCase;
- const size_t queryIndex = frame.queryIndex;
- const size_t refIndex = frame.referenceIndex;
-
-// Log::Debug << "Recheck reference node " << refNode->Point() << " scale "
-// << refNode->Scale() << " which has old score " << oldScore
-// << std::endl;
-
- // Have we performed the base case yet?
- double score;
- if ((refIndex != refNode->Point()) ||
- (queryIndex != candidateQueryNode.Point()))
- {
- // Attempt to rescore before performing the base case.
- score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
-
- if (score == DBL_MAX)
- {
- ++numPrunes;
- continue;
- }
-
- // If not pruned, we have to perform the base case.
- baseCase = rule.BaseCase(candidateQueryNode.Point(), refNode->Point());
- }
-
- score = rule.Score(candidateQueryNode, *refNode, baseCase);
-
-// Log::Debug << "Rescored as " << score << std::endl;
-
- if (score == DBL_MAX)
- {
- // Pruned. Move on.
- ++numPrunes;
- continue;
- }
-
-// Log::Debug << "Kept in map\n";
-
- // Add to child map.
- newScaleVector.push_back(frame);
- newScaleVector.back().score = score;
- newScaleVector.back().baseCase = baseCase;
- newScaleVector.back().queryIndex = candidateQueryNode.Point();
- newScaleVector.back().referenceIndex = refNode->Point();
- }
-
- // If we didn't add anything, then strike this vector from the map.
- if (newScaleVector.size() == 0)
- referenceMap.erase((*it).first);
-
- ++it; // Advance to next scale.
- }
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::ReferenceRecursion(
- CoverTree& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
-{
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- // First, reduce the maximum scale in the reference map down to the scale of
- // the query node.
- while ((*referenceMap.rbegin()).first > queryNode.Scale())
- {
- // If the query node's scale is INT_MIN and the reference map's maximum
- // scale is INT_MIN, don't try to recurse...
- if ((queryNode.Scale() == INT_MIN) &&
- ((*referenceMap.rbegin()).first == INT_MIN))
- break;
-
- // Get a reference to the current largest scale.
- std::vector<MapEntryType>& scaleVector = (*referenceMap.rbegin()).second;
-
- // Before traversing all the points in this scale, sort by score.
- std::sort(scaleVector.begin(), scaleVector.end());
-
- // Now loop over each element.
- for (size_t i = 0; i < scaleVector.size(); ++i)
- {
- // Get a reference to the current element.
- const MapEntryType& frame = scaleVector.at(i);
-
- CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
- frame.referenceNode;
- const double score = frame.score;
- const size_t refIndex = frame.referenceIndex;
- const size_t refPoint = refNode->Point();
- const size_t queryIndex = frame.queryIndex;
- const size_t queryPoint = queryNode.Point();
- double baseCase = frame.baseCase;
-
-// Log::Debug << "Currently inspecting reference node " << refNode->Point()
-// << " scale " << refNode->Scale() << " earlier query index " <<
-// queryIndex << std::endl;
-
-// Log::Debug << "Old score " << score << " with refParent " << refIndex
-// << " and queryIndex " << queryIndex << "\n";
-
- // First we recalculate the score of this node to find if we can prune it.
- if (rule.Rescore(queryNode, *refNode, score) == DBL_MAX)
- {
-// Log::Warn << "Pruned after rescore\n";
- ++numPrunes;
- continue;
- }
-
- // If this is a self-child, the base case has already been evaluated.
- // We also must ensure that the base case was evaluated with this query
- // point.
- if ((refPoint != refIndex) || (queryPoint != queryIndex))
- {
-// Log::Warn << "Must evaluate base case " << queryNode.Point() << " "
-// << refPoint << "\n";
- baseCase = rule.BaseCase(queryPoint, refPoint);
-// Log::Debug << "Base case " << baseCase << std::endl;
- }
-
- // Create the score for the children.
- double childScore = rule.Score(queryNode, *refNode, baseCase);
-
- // Now if this childScore is DBL_MAX we can prune all children. In this
- // recursion setup pruning is all or nothing for children.
- if (childScore == DBL_MAX)
- {
-// Log::Warn << "Pruned all children.\n";
- numPrunes += refNode->NumChildren();
- continue;
- }
-
- // We must treat the self-leaf differently. The base case has already
- // been performed.
- childScore = rule.Score(queryNode, refNode->Child(0), baseCase);
-
- if (childScore != DBL_MAX)
- {
- MapEntryType newFrame;
- newFrame.referenceNode = &refNode->Child(0);
- newFrame.score = childScore;
- newFrame.baseCase = baseCase;
- newFrame.referenceIndex = refPoint;
- newFrame.queryIndex = queryNode.Point();
-
- referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
- }
- else
- {
- ++numPrunes;
- }
-
- // Add the non-self-leaf children.
- for (size_t j = 1; j < refNode->NumChildren(); ++j)
- {
- const size_t queryIndex = queryNode.Point();
- const size_t refIndex = refNode->Child(j).Point();
-
- // We need to incorporate shell() here to try and avoid base case
- // computations. TODO
-// Log::Debug << "Prescore query " << queryNode.Point() << " scale "
-// << queryNode.Scale() << ", reference " << refNode->Point() <<
-// " scale " << refNode->Scale() << ", reference child " <<
-// refNode->Child(j).Point() << " scale " << refNode->Child(j).Scale()
-// << " with base case " << baseCase;
-// childScore = rule.Prescore(queryNode, *refNode, refNode->Child(j),
-// frame.baseCase);
-// Log::Debug << " and result " << childScore << ".\n";
-
-// if (childScore == DBL_MAX)
-// {
-// ++numPrunes;
-// continue;
-// }
-
- // Calculate the base case of each child.
- baseCase = rule.BaseCase(queryIndex, refIndex);
-
- // See if we can prune it.
- double childScore = rule.Score(queryNode, refNode->Child(j), baseCase);
-
- if (childScore == DBL_MAX)
- {
- ++numPrunes;
- continue;
- }
-
- MapEntryType newFrame;
- newFrame.referenceNode = &refNode->Child(j);
- newFrame.score = childScore;
- newFrame.baseCase = baseCase;
- newFrame.referenceIndex = refIndex;
- newFrame.queryIndex = queryIndex;
-
-// Log::Debug << "Push onto map child " << refNode->Child(j).Point() <<
-// " scale " << refNode->Child(j).Scale() << std::endl;
-
- referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
- }
- }
-
- // Now clear the memory for this scale; it isn't needed anymore.
- referenceMap.erase((*referenceMap.rbegin()).first);
- }
-
-}
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,542 @@
+/**
+ * @file dual_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * A dual-tree traverser for the cover tree.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+#include <mlpack/core.hpp>
+#include <queue>
+
+namespace mlpack {
+namespace tree {
+
+//! The object placed in the map for tree traversal.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+struct DualCoverTreeMapEntry
+{
+ //! The node this entry refers to.
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
+ //! The score of the node.
+ double score;
+ //! The index of the reference node used for the base case evaluation.
+ size_t referenceIndex;
+ //! The index of the query node used for the base case evaluation.
+ size_t queryIndex;
+ //! The base case evaluation.
+ double baseCase;
+
+ //! Comparison operator, for sorting within the map.
+ bool operator<(const DualCoverTreeMapEntry& other) const
+ {
+ return (score < other.score);
+ }
+};
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
+ rule(rule),
+ numPrunes(0)
+{ /* Nothing to do. */ }
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::Traverse(
+ CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
+ CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
+{
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ // Start by creating a map and adding the reference node to it.
+ std::map<int, std::vector<MapEntryType> > refMap;
+
+ MapEntryType rootRefEntry;
+
+ rootRefEntry.referenceNode = &referenceNode;
+ rootRefEntry.score = 0.0; // Must recurse into.
+ rootRefEntry.referenceIndex = referenceNode.Point();
+ rootRefEntry.queryIndex = queryNode.Point();
+ rootRefEntry.baseCase = rule.BaseCase(queryNode.Point(),
+ referenceNode.Point());
+// rule.UpdateAfterRecursion(queryNode, referenceNode);
+
+ refMap[referenceNode.Scale()].push_back(rootRefEntry);
+
+ Traverse(queryNode, refMap);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::Traverse(
+ CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
+ StatisticType> > >& referenceMap)
+{
+// Log::Debug << "Recursed into query node " << queryNode.Point() << ", scale "
+// << queryNode.Scale() << "\n";
+
+ // Convenience typedef.
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ if (referenceMap.size() == 0)
+ return; // Nothing to do!
+
+ // First recurse down the reference nodes as necessary.
+ ReferenceRecursion(queryNode, referenceMap);
+
+ // Now, reduce the scale of the query node by recursing. But we can't recurse
+ // if the query node is a leaf node.
+ if ((queryNode.Scale() != INT_MIN) &&
+ (queryNode.Scale() >= (*referenceMap.rbegin()).first))
+ {
+ // Recurse into the non-self-children first. The recursion order cannot
+ // affect the runtime of the algorithm, because each query child recursion's
+ // results are separate and independent.
+ for (size_t i = 1; i < queryNode.NumChildren(); ++i)
+ {
+ std::map<int, std::vector<MapEntryType> > childMap;
+ PruneMap(queryNode, queryNode.Child(i), referenceMap, childMap);
+
+// Log::Debug << "Recurse into query child " << i << ": " <<
+// queryNode.Child(i).Point() << " scale " << queryNode.Child(i).Scale()
+// << "; this parent is " << queryNode.Point() << " scale " <<
+// queryNode.Scale() << std::endl;
+ Traverse(queryNode.Child(i), childMap);
+ }
+
+ PruneMapForSelfChild(queryNode.Child(0), referenceMap);
+
+ // Now we can use the existing map (without a copy) for the self-child.
+// Log::Warn << "Recurse into query self-child " << queryNode.Child(0).Point()
+// << " scale " << queryNode.Child(0).Scale() << "; this parent is "
+// << queryNode.Point() << " scale " << queryNode.Scale() << std::endl;
+ Traverse(queryNode.Child(0), referenceMap);
+ }
+
+ if (queryNode.Scale() != INT_MIN)
+ return; // No need to evaluate base cases at this level. It's all done.
+
+ // If we have made it this far, all we have is a bunch of base case
+ // evaluations to do.
+ Log::Assert((*referenceMap.begin()).first == INT_MIN);
+ Log::Assert(queryNode.Scale() == INT_MIN);
+ std::vector<MapEntryType>& pointVector = (*referenceMap.begin()).second;
+// Log::Debug << "Onto base case evaluations\n";
+
+ for (size_t i = 0; i < pointVector.size(); ++i)
+ {
+ // Get a reference to the frame.
+ const MapEntryType& frame = pointVector[i];
+
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+ frame.referenceNode;
+ const double oldScore = frame.score;
+ const size_t refIndex = frame.referenceIndex;
+ const size_t queryIndex = frame.queryIndex;
+// Log::Debug << "Consider query " << queryNode.Point() << ", reference "
+// << refNode->Point() << "\n";
+// Log::Debug << "Old score " << oldScore << " with refParent " << refIndex
+// << " and parent query node " << queryIndex << "\n";
+
+ // First, ensure that we have not already calculated the base case.
+ if ((refIndex == refNode->Point()) && (queryIndex == queryNode.Point()))
+ {
+// Log::Debug << "Pruned because we already did the base case and its value "
+// << " was " << frame.baseCase << std::endl;
+ ++numPrunes;
+ continue;
+ }
+
+ // Now, check if we can prune it.
+ const double rescore = rule.Rescore(queryNode, *refNode, oldScore);
+
+ if (rescore == DBL_MAX)
+ {
+// Log::Debug << "Pruned after rescoring\n";
+ ++numPrunes;
+ continue;
+ }
+
+ // If not, compute the base case.
+// Log::Debug << "Not pruned, performing base case " << queryNode.Point() <<
+// " " << pointVector[i].referenceNode->Point() << "\n";
+ rule.BaseCase(queryNode.Point(), pointVector[i].referenceNode->Point());
+ }
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::PruneMap(
+ CoverTree& /* queryNode */,
+ CoverTree& candidateQueryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
+ RootPointPolicy, StatisticType> > >& referenceMap,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
+ RootPointPolicy, StatisticType> > >& childMap)
+{
+// Log::Debug << "Prep for recurse into query child point " <<
+// candidateQueryNode.Point() << " scale " <<
+// candidateQueryNode.Scale() << std::endl;
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
+ referenceMap.rbegin();
+
+ while ((it != referenceMap.rend()) && ((*it).first != INT_MIN))
+ {
+ // Get a reference to the vector representing the entries at this scale.
+ const std::vector<MapEntryType>& scaleVector = (*it).second;
+
+ std::vector<MapEntryType>& newScaleVector = childMap[(*it).first];
+ newScaleVector.reserve(scaleVector.size()); // Maximum possible size.
+
+ // Loop over each entry in the vector.
+ for (size_t j = 0; j < scaleVector.size(); ++j)
+ {
+ const MapEntryType& frame = scaleVector[j];
+
+ // First evaluate if we can prune without performing the base case.
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+ frame.referenceNode;
+ const double oldScore = frame.score;
+
+ // Try to prune based on shell(). This is hackish and will need to be
+ // refined or cleaned at some point.
+// double score = rule.PrescoreQ(queryNode, candidateQueryNode, *refNode,
+// frame.baseCase);
+
+// if (score == DBL_MAX)
+// {
+// ++numPrunes;
+// continue;
+// }
+
+// Log::Debug << "Recheck reference node " << refNode->Point() <<
+// " scale " << refNode->Scale() << " which has old score " <<
+// oldScore << " with old reference index " << frame.referenceIndex
+// << " and old query index " << frame.queryIndex << std::endl;
+
+ double score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
+
+// Log::Debug << "Rescored as " << score << std::endl;
+
+ if (score == DBL_MAX)
+ {
+ // Pruned. Move on.
+ ++numPrunes;
+ continue;
+ }
+
+ // Evaluate base case.
+// Log::Debug << "Must evaluate base case "
+// << candidateQueryNode.Point() << " " << refNode->Point()
+// << "\n";
+ double baseCase = rule.BaseCase(candidateQueryNode.Point(),
+ refNode->Point());
+// Log::Debug << "Base case was " << baseCase << std::endl;
+
+ score = rule.Score(candidateQueryNode, *refNode, baseCase);
+
+ if (score == DBL_MAX)
+ {
+ // Pruned. Move on.
+ ++numPrunes;
+ continue;
+ }
+
+ // Add to child map.
+ newScaleVector.push_back(frame);
+ newScaleVector.back().score = score;
+ newScaleVector.back().baseCase = baseCase;
+ newScaleVector.back().referenceIndex = refNode->Point();
+ newScaleVector.back().queryIndex = candidateQueryNode.Point();
+ }
+
+ // If we didn't add anything, then strike this vector from the map.
+ if (newScaleVector.size() == 0)
+ childMap.erase((*it).first);
+
+ ++it; // Advance to next scale.
+ }
+
+ childMap[INT_MIN] = referenceMap[INT_MIN];
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::PruneMapForSelfChild(
+ CoverTree& candidateQueryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
+ StatisticType> > >& referenceMap)
+{
+// Log::Debug << "Prep for recurse into query self-child point " <<
+// candidateQueryNode.Point() << " scale " <<
+// candidateQueryNode.Scale() << std::endl;
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ // Create the child reference map. We will do this by recursing through
+ // every entry in the reference map and evaluating (or pruning) it. But
+ // in this setting we do not recurse into any children of the reference
+ // entries.
+ typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
+ referenceMap.rbegin();
+
+ while (it != referenceMap.rend() && (*it).first != INT_MIN)
+ {
+ // Get a reference to the vector representing the entries at this scale.
+ std::vector<MapEntryType>& newScaleVector = (*it).second;
+ const std::vector<MapEntryType> scaleVector = newScaleVector;
+
+ newScaleVector.clear();
+ newScaleVector.reserve(scaleVector.size());
+
+ // Loop over each entry in the vector.
+ for (size_t j = 0; j < scaleVector.size(); ++j)
+ {
+ const MapEntryType& frame = scaleVector[j];
+
+ // First evaluate if we can prune without performing the base case.
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+ frame.referenceNode;
+ const double oldScore = frame.score;
+ double baseCase = frame.baseCase;
+ const size_t queryIndex = frame.queryIndex;
+ const size_t refIndex = frame.referenceIndex;
+
+// Log::Debug << "Recheck reference node " << refNode->Point() << " scale "
+// << refNode->Scale() << " which has old score " << oldScore
+// << std::endl;
+
+ // Have we performed the base case yet?
+ double score;
+ if ((refIndex != refNode->Point()) ||
+ (queryIndex != candidateQueryNode.Point()))
+ {
+ // Attempt to rescore before performing the base case.
+ score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
+
+ if (score == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
+
+ // If not pruned, we have to perform the base case.
+ baseCase = rule.BaseCase(candidateQueryNode.Point(), refNode->Point());
+ }
+
+ score = rule.Score(candidateQueryNode, *refNode, baseCase);
+
+// Log::Debug << "Rescored as " << score << std::endl;
+
+ if (score == DBL_MAX)
+ {
+ // Pruned. Move on.
+ ++numPrunes;
+ continue;
+ }
+
+// Log::Debug << "Kept in map\n";
+
+ // Add to child map.
+ newScaleVector.push_back(frame);
+ newScaleVector.back().score = score;
+ newScaleVector.back().baseCase = baseCase;
+ newScaleVector.back().queryIndex = candidateQueryNode.Point();
+ newScaleVector.back().referenceIndex = refNode->Point();
+ }
+
+ // If we didn't add anything, then strike this vector from the map.
+ if (newScaleVector.size() == 0)
+ referenceMap.erase((*it).first);
+
+ ++it; // Advance to next scale.
+ }
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+DualTreeTraverser<RuleType>::ReferenceRecursion(
+ CoverTree& queryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
+ StatisticType> > >& referenceMap)
+{
+ typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ // First, reduce the maximum scale in the reference map down to the scale of
+ // the query node.
+ while ((*referenceMap.rbegin()).first > queryNode.Scale())
+ {
+ // If the query node's scale is INT_MIN and the reference map's maximum
+ // scale is INT_MIN, don't try to recurse...
+ if ((queryNode.Scale() == INT_MIN) &&
+ ((*referenceMap.rbegin()).first == INT_MIN))
+ break;
+
+ // Get a reference to the current largest scale.
+ std::vector<MapEntryType>& scaleVector = (*referenceMap.rbegin()).second;
+
+ // Before traversing all the points in this scale, sort by score.
+ std::sort(scaleVector.begin(), scaleVector.end());
+
+ // Now loop over each element.
+ for (size_t i = 0; i < scaleVector.size(); ++i)
+ {
+ // Get a reference to the current element.
+ const MapEntryType& frame = scaleVector.at(i);
+
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+ frame.referenceNode;
+ const double score = frame.score;
+ const size_t refIndex = frame.referenceIndex;
+ const size_t refPoint = refNode->Point();
+ const size_t queryIndex = frame.queryIndex;
+ const size_t queryPoint = queryNode.Point();
+ double baseCase = frame.baseCase;
+
+// Log::Debug << "Currently inspecting reference node " << refNode->Point()
+// << " scale " << refNode->Scale() << " earlier query index " <<
+// queryIndex << std::endl;
+
+// Log::Debug << "Old score " << score << " with refParent " << refIndex
+// << " and queryIndex " << queryIndex << "\n";
+
+ // First we recalculate the score of this node to find if we can prune it.
+ if (rule.Rescore(queryNode, *refNode, score) == DBL_MAX)
+ {
+// Log::Warn << "Pruned after rescore\n";
+ ++numPrunes;
+ continue;
+ }
+
+ // If this is a self-child, the base case has already been evaluated.
+ // We also must ensure that the base case was evaluated with this query
+ // point.
+ if ((refPoint != refIndex) || (queryPoint != queryIndex))
+ {
+// Log::Warn << "Must evaluate base case " << queryNode.Point() << " "
+// << refPoint << "\n";
+ baseCase = rule.BaseCase(queryPoint, refPoint);
+// Log::Debug << "Base case " << baseCase << std::endl;
+ }
+
+ // Create the score for the children.
+ double childScore = rule.Score(queryNode, *refNode, baseCase);
+
+ // Now if this childScore is DBL_MAX we can prune all children. In this
+ // recursion setup pruning is all or nothing for children.
+ if (childScore == DBL_MAX)
+ {
+// Log::Warn << "Pruned all children.\n";
+ numPrunes += refNode->NumChildren();
+ continue;
+ }
+
+ // We must treat the self-leaf differently. The base case has already
+ // been performed.
+ childScore = rule.Score(queryNode, refNode->Child(0), baseCase);
+
+ if (childScore != DBL_MAX)
+ {
+ MapEntryType newFrame;
+ newFrame.referenceNode = &refNode->Child(0);
+ newFrame.score = childScore;
+ newFrame.baseCase = baseCase;
+ newFrame.referenceIndex = refPoint;
+ newFrame.queryIndex = queryNode.Point();
+
+ referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
+ }
+ else
+ {
+ ++numPrunes;
+ }
+
+ // Add the non-self-leaf children.
+ for (size_t j = 1; j < refNode->NumChildren(); ++j)
+ {
+ const size_t queryIndex = queryNode.Point();
+ const size_t refIndex = refNode->Child(j).Point();
+
+ // We need to incorporate shell() here to try and avoid base case
+ // computations. TODO
+// Log::Debug << "Prescore query " << queryNode.Point() << " scale "
+// << queryNode.Scale() << ", reference " << refNode->Point() <<
+// " scale " << refNode->Scale() << ", reference child " <<
+// refNode->Child(j).Point() << " scale " << refNode->Child(j).Scale()
+// << " with base case " << baseCase;
+// childScore = rule.Prescore(queryNode, *refNode, refNode->Child(j),
+// frame.baseCase);
+// Log::Debug << " and result " << childScore << ".\n";
+
+// if (childScore == DBL_MAX)
+// {
+// ++numPrunes;
+// continue;
+// }
+
+ // Calculate the base case of each child.
+ baseCase = rule.BaseCase(queryIndex, refIndex);
+
+ // See if we can prune it.
+ double childScore = rule.Score(queryNode, refNode->Child(j), baseCase);
+
+ if (childScore == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
+
+ MapEntryType newFrame;
+ newFrame.referenceNode = &refNode->Child(j);
+ newFrame.score = childScore;
+ newFrame.baseCase = baseCase;
+ newFrame.referenceIndex = refIndex;
+ newFrame.queryIndex = queryIndex;
+
+// Log::Debug << "Push onto map child " << refNode->Child(j).Point() <<
+// " scale " << refNode->Child(j).Scale() << std::endl;
+
+ referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
+ }
+ }
+
+ // Now clear the memory for this scale; it isn't needed anymore.
+ referenceMap.erase((*referenceMap.rbegin()).first);
+ }
+
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,52 +0,0 @@
-/**
- * @file first_point_is_root.hpp
- * @author Ryan Curtin
- *
- * A very simple policy for the cover tree; the first point in the dataset is
- * chosen as the root of the cover tree.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
-#define __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace tree {
-
-/**
- * This class is meant to be used as a choice for the policy class
- * RootPointPolicy of the CoverTree class. This policy determines which point
- * is used for the root node of the cover tree. This particular implementation
- * simply chooses the first point in the dataset as the root. A more complex
- * implementation might choose, for instance, the point with least maximum
- * distance to other points (the closest to the "middle").
- */
-class FirstPointIsRoot
-{
- public:
- /**
- * Return the point to be used as the root point of the cover tree. This just
- * returns 0.
- */
- static size_t ChooseRoot(const arma::mat& /* dataset */) { return 0; }
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,52 @@
+/**
+ * @file first_point_is_root.hpp
+ * @author Ryan Curtin
+ *
+ * A very simple policy for the cover tree; the first point in the dataset is
+ * chosen as the root of the cover tree.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
+#define __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * This class is meant to be used as a choice for the policy class
+ * RootPointPolicy of the CoverTree class. This policy determines which point
+ * is used for the root node of the cover tree. This particular implementation
+ * simply chooses the first point in the dataset as the root. A more complex
+ * implementation might choose, for instance, the point with least maximum
+ * distance to other points (the closest to the "middle").
+ */
+class FirstPointIsRoot
+{
+ public:
+ /**
+ * Return the point to be used as the root point of the cover tree. This just
+ * returns 0.
+ */
+ static size_t ChooseRoot(const arma::mat& /* dataset */) { return 0; }
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,72 +0,0 @@
-/**
- * @file single_tree_traverser.hpp
- * @author Ryan Curtin
- *
- * Defines the SingleTreeTraverser for the cover tree. This implements a
- * single-tree breadth-first recursion with a pruning rule and a base case (two
- * point) rule.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_SINGLE_TREE_TRAVERSER_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_SINGLE_TREE_TRAVERSER_HPP
-
-#include <mlpack/core.hpp>
-
-#include "cover_tree.hpp"
-
-namespace mlpack {
-namespace tree {
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-class CoverTree<MetricType, RootPointPolicy, StatisticType>::SingleTreeTraverser
-{
- public:
- /**
- * Initialize the single tree traverser with the given rule.
- */
- SingleTreeTraverser(RuleType& rule);
-
- /**
- * Traverse the tree with the given point.
- *
- * @param queryIndex The index of the point in the query set which is used as
- * the query point.
- * @param referenceNode The tree node to be traversed.
- */
- void Traverse(const size_t queryIndex, CoverTree& referenceNode);
-
- //! Get the number of prunes so far.
- size_t NumPrunes() const { return numPrunes; }
- //! Set the number of prunes (good for a reset to 0).
- size_t& NumPrunes() { return numPrunes; }
-
- private:
- //! Reference to the rules with which the tree will be traversed.
- RuleType& rule;
-
- //! The number of nodes which have been pruned during traversal.
- size_t numPrunes;
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-// Include implementation.
-#include "single_tree_traverser_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,72 @@
+/**
+ * @file single_tree_traverser.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the SingleTreeTraverser for the cover tree. This implements a
+ * single-tree breadth-first recursion with a pruning rule and a base case (two
+ * point) rule.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_SINGLE_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_SINGLE_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+
+#include "cover_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+class CoverTree<MetricType, RootPointPolicy, StatisticType>::SingleTreeTraverser
+{
+ public:
+ /**
+ * Initialize the single tree traverser with the given rule.
+ */
+ SingleTreeTraverser(RuleType& rule);
+
+ /**
+ * Traverse the tree with the given point.
+ *
+ * @param queryIndex The index of the point in the query set which is used as
+ * the query point.
+ * @param referenceNode The tree node to be traversed.
+ */
+ void Traverse(const size_t queryIndex, CoverTree& referenceNode);
+
+ //! Get the number of prunes so far.
+ size_t NumPrunes() const { return numPrunes; }
+ //! Set the number of prunes (good for a reset to 0).
+ size_t& NumPrunes() { return numPrunes; }
+
+ private:
+ //! Reference to the rules with which the tree will be traversed.
+ RuleType& rule;
+
+ //! The number of nodes which have been pruned during traversal.
+ size_t numPrunes;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "single_tree_traverser_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,213 +0,0 @@
-/**
- * @file single_tree_traverser_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of the single tree traverser for cover trees, which implements
- * a breadth-first traversal.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "single_tree_traverser.hpp"
-
-#include <queue>
-
-namespace mlpack {
-namespace tree {
-
-//! This is the structure the cover tree map will use for traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-struct CoverTreeMapEntry
-{
- //! The node this entry refers to.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* node;
- //! The score of the node.
- double score;
- //! The index of the parent node.
- size_t parent;
- //! The base case evaluation.
- double baseCase;
-
- //! Comparison operator.
- bool operator<(const CoverTreeMapEntry& other) const
- {
- return (score < other.score);
- }
-};
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::
-SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
- rule(rule),
- numPrunes(0)
-{ /* Nothing to do. */ }
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-SingleTreeTraverser<RuleType>::Traverse(
- const size_t queryIndex,
- CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
-{
- // This is a non-recursive implementation (which should be faster than a
- // recursive implementation).
- typedef CoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- // We will use this map as a priority queue. Each key represents the scale,
- // and then the vector is all the nodes in that scale which need to be
- // investigated. Because no point in a scale can add a point in its own
- // scale, we know that the vector for each scale is final when we get to it.
- // In addition, map is organized in such a way that rbegin() will return the
- // largest scale.
- std::map<int, std::vector<MapEntryType> > mapQueue;
-
- // Manually add the children of the first node. These cannot be pruned
- // anyway.
- double rootBaseCase = rule.BaseCase(queryIndex, referenceNode.Point());
-
- // Create the score for the children.
- double rootChildScore = rule.Score(queryIndex, referenceNode, rootBaseCase);
-
- if (rootChildScore == DBL_MAX)
- {
- numPrunes += referenceNode.NumChildren();
- }
- else
- {
- // Don't add the self-leaf.
- size_t i = 0;
- if (referenceNode.Child(0).NumChildren() == 0)
- {
- ++numPrunes;
- i = 1;
- }
-
- for (/* i was set above. */; i < referenceNode.NumChildren(); ++i)
- {
- MapEntryType newFrame;
- newFrame.node = &referenceNode.Child(i);
- newFrame.score = rootChildScore;
- newFrame.baseCase = rootBaseCase;
- newFrame.parent = referenceNode.Point();
-
- // Put it into the map.
- mapQueue[newFrame.node->Scale()].push_back(newFrame);
- }
- }
-
- // Now begin the iteration through the map.
- typename std::map<int, std::vector<MapEntryType> >::reverse_iterator rit =
- mapQueue.rbegin();
-
- // We will treat the leaves differently (below).
- while ((*rit).first != INT_MIN)
- {
- // Get a reference to the current scale.
- std::vector<MapEntryType>& scaleVector = (*rit).second;
-
- // Before traversing all the points in this scale, sort by score.
- std::sort(scaleVector.begin(), scaleVector.end());
-
- // Now loop over each element.
- for (size_t i = 0; i < scaleVector.size(); ++i)
- {
- // Get a reference to the current element.
- const MapEntryType& frame = scaleVector.at(i);
-
- CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
- const double score = frame.score;
- const size_t parent = frame.parent;
- const size_t point = node->Point();
- double baseCase = frame.baseCase;
-
- // First we recalculate the score of this node to find if we can prune it.
- if (rule.Rescore(queryIndex, *node, score) == DBL_MAX)
- {
- ++numPrunes;
- continue;
- }
-
- // If we are a self-child, the base case has already been evaluated.
- if (point != parent)
- baseCase = rule.BaseCase(queryIndex, point);
-
- // Create the score for the children.
- const double childScore = rule.Score(queryIndex, *node, baseCase);
-
- // Now if this childScore is DBL_MAX we can prune all children. In this
- // recursion setup pruning is all or nothing for children.
- if (childScore == DBL_MAX)
- {
- numPrunes += node->NumChildren();
- continue;
- }
-
- // Don't add the self-leaf.
- size_t j = 0;
- if (node->Child(0).NumChildren() == 0)
- {
- ++numPrunes;
- j = 1;
- }
-
- for (/* j is already set. */; j < node->NumChildren(); ++j)
- {
- MapEntryType newFrame;
- newFrame.node = &node->Child(j);
- newFrame.score = childScore;
- newFrame.baseCase = baseCase;
- newFrame.parent = point;
-
- mapQueue[newFrame.node->Scale()].push_back(newFrame);
- }
- }
-
- // Now clear the memory for this scale; it isn't needed anymore.
- mapQueue.erase((*rit).first);
- }
-
- // Now deal with the leaves.
- for (size_t i = 0; i < mapQueue[INT_MIN].size(); ++i)
- {
- const MapEntryType& frame = mapQueue[INT_MIN].at(i);
-
- CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
- const double score = frame.score;
- const size_t point = node->Point();
-
- // First, recalculate the score of this node to find if we can prune it.
- double actualScore = rule.Rescore(queryIndex, *node, score);
-
- if (actualScore == DBL_MAX)
- {
- ++numPrunes;
- continue;
- }
-
- // There are no self-leaves; evaluate the base case.
- rule.BaseCase(queryIndex, point);
- }
-}
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,213 @@
+/**
+ * @file single_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the single tree traverser for cover trees, which implements
+ * a breadth-first traversal.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "single_tree_traverser.hpp"
+
+#include <queue>
+
+namespace mlpack {
+namespace tree {
+
+//! This is the structure the cover tree map will use for traversal.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+struct CoverTreeMapEntry
+{
+ //! The node this entry refers to.
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* node;
+ //! The score of the node.
+ double score;
+ //! The index of the parent node.
+ size_t parent;
+ //! The base case evaluation.
+ double baseCase;
+
+ //! Comparison operator.
+ bool operator<(const CoverTreeMapEntry& other) const
+ {
+ return (score < other.score);
+ }
+};
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::
+SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
+ rule(rule),
+ numPrunes(0)
+{ /* Nothing to do. */ }
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<typename RuleType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+SingleTreeTraverser<RuleType>::Traverse(
+ const size_t queryIndex,
+ CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
+{
+ // This is a non-recursive implementation (which should be faster than a
+ // recursive implementation).
+ typedef CoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ MapEntryType;
+
+ // We will use this map as a priority queue. Each key represents the scale,
+ // and then the vector is all the nodes in that scale which need to be
+ // investigated. Because no point in a scale can add a point in its own
+ // scale, we know that the vector for each scale is final when we get to it.
+ // In addition, map is organized in such a way that rbegin() will return the
+ // largest scale.
+ std::map<int, std::vector<MapEntryType> > mapQueue;
+
+ // Manually add the children of the first node. These cannot be pruned
+ // anyway.
+ double rootBaseCase = rule.BaseCase(queryIndex, referenceNode.Point());
+
+ // Create the score for the children.
+ double rootChildScore = rule.Score(queryIndex, referenceNode, rootBaseCase);
+
+ if (rootChildScore == DBL_MAX)
+ {
+ numPrunes += referenceNode.NumChildren();
+ }
+ else
+ {
+ // Don't add the self-leaf.
+ size_t i = 0;
+ if (referenceNode.Child(0).NumChildren() == 0)
+ {
+ ++numPrunes;
+ i = 1;
+ }
+
+ for (/* i was set above. */; i < referenceNode.NumChildren(); ++i)
+ {
+ MapEntryType newFrame;
+ newFrame.node = &referenceNode.Child(i);
+ newFrame.score = rootChildScore;
+ newFrame.baseCase = rootBaseCase;
+ newFrame.parent = referenceNode.Point();
+
+ // Put it into the map.
+ mapQueue[newFrame.node->Scale()].push_back(newFrame);
+ }
+ }
+
+ // Now begin the iteration through the map.
+ typename std::map<int, std::vector<MapEntryType> >::reverse_iterator rit =
+ mapQueue.rbegin();
+
+ // We will treat the leaves differently (below).
+ while ((*rit).first != INT_MIN)
+ {
+ // Get a reference to the current scale.
+ std::vector<MapEntryType>& scaleVector = (*rit).second;
+
+ // Before traversing all the points in this scale, sort by score.
+ std::sort(scaleVector.begin(), scaleVector.end());
+
+ // Now loop over each element.
+ for (size_t i = 0; i < scaleVector.size(); ++i)
+ {
+ // Get a reference to the current element.
+ const MapEntryType& frame = scaleVector.at(i);
+
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
+ const double score = frame.score;
+ const size_t parent = frame.parent;
+ const size_t point = node->Point();
+ double baseCase = frame.baseCase;
+
+ // First we recalculate the score of this node to find if we can prune it.
+ if (rule.Rescore(queryIndex, *node, score) == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
+
+ // If we are a self-child, the base case has already been evaluated.
+ if (point != parent)
+ baseCase = rule.BaseCase(queryIndex, point);
+
+ // Create the score for the children.
+ const double childScore = rule.Score(queryIndex, *node, baseCase);
+
+ // Now if this childScore is DBL_MAX we can prune all children. In this
+ // recursion setup pruning is all or nothing for children.
+ if (childScore == DBL_MAX)
+ {
+ numPrunes += node->NumChildren();
+ continue;
+ }
+
+ // Don't add the self-leaf.
+ size_t j = 0;
+ if (node->Child(0).NumChildren() == 0)
+ {
+ ++numPrunes;
+ j = 1;
+ }
+
+ for (/* j is already set. */; j < node->NumChildren(); ++j)
+ {
+ MapEntryType newFrame;
+ newFrame.node = &node->Child(j);
+ newFrame.score = childScore;
+ newFrame.baseCase = baseCase;
+ newFrame.parent = point;
+
+ mapQueue[newFrame.node->Scale()].push_back(newFrame);
+ }
+ }
+
+ // Now clear the memory for this scale; it isn't needed anymore.
+ mapQueue.erase((*rit).first);
+ }
+
+ // Now deal with the leaves.
+ for (size_t i = 0; i < mapQueue[INT_MIN].size(); ++i)
+ {
+ const MapEntryType& frame = mapQueue[INT_MIN].at(i);
+
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* node = frame.node;
+ const double score = frame.score;
+ const size_t point = node->Point();
+
+ // First, recalculate the score of this node to find if we can prune it.
+ double actualScore = rule.Rescore(queryIndex, *node, score);
+
+ if (actualScore == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
+
+ // There are no self-leaves; evaluate the base case.
+ rule.BaseCase(queryIndex, point);
+ }
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/traits.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/traits.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/traits.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,71 +0,0 @@
-/**
- * @file traits.hpp
- * @author Ryan Curtin
- *
- * This file contains the specialization of the TreeTraits class for the
- * CoverTree type of tree.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_TRAITS_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_TRAITS_HPP
-
-#include <mlpack/core/tree/tree_traits.hpp>
-
-namespace mlpack {
-namespace tree {
-
-/**
- * The specialization of the TreeTraits class for the CoverTree tree type. It
- * defines characteristics of the cover tree, and is used to help write
- * tree-independent (but still optimized) tree-based algorithms. See
- * mlpack/core/tree/tree_traits.hpp for more information.
- */
-template<typename MetricType,
- typename RootPointPolicy,
- typename StatisticType>
-class TreeTraits<CoverTree<MetricType, RootPointPolicy, StatisticType> >
-{
- public:
- /**
- * The cover tree calculates the distance between parent and child during
- * construction, so that value is saved and CoverTree<...>::ParentDistance()
- * does exist.
- */
- static const bool HasParentDistance = true;
-
- /**
- * The cover tree (or, this implementation of it) does not require that
- * children represent non-overlapping subsets of the parent node.
- */
- static const bool HasOverlappingChildren = true;
-
- /**
- * Each cover tree node contains only one point, and that point is its
- * centroid.
- */
- static const bool FirstPointIsCentroid = true;
-
- /**
- * Cover trees do have self-children.
- */
- static const bool HasSelfChildren = true;
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/traits.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree/traits.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/traits.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree/traits.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,71 @@
+/**
+ * @file traits.hpp
+ * @author Ryan Curtin
+ *
+ * This file contains the specialization of the TreeTraits class for the
+ * CoverTree type of tree.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_TRAITS_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_TRAITS_HPP
+
+#include <mlpack/core/tree/tree_traits.hpp>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * The specialization of the TreeTraits class for the CoverTree tree type. It
+ * defines characteristics of the cover tree, and is used to help write
+ * tree-independent (but still optimized) tree-based algorithms. See
+ * mlpack/core/tree/tree_traits.hpp for more information.
+ */
+template<typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType>
+class TreeTraits<CoverTree<MetricType, RootPointPolicy, StatisticType> >
+{
+ public:
+ /**
+ * The cover tree calculates the distance between parent and child during
+ * construction, so that value is saved and CoverTree<...>::ParentDistance()
+ * does exist.
+ */
+ static const bool HasParentDistance = true;
+
+ /**
+ * The cover tree (or, this implementation of it) does not require that
+ * children represent non-overlapping subsets of the parent node.
+ */
+ static const bool HasOverlappingChildren = true;
+
+ /**
+ * Each cover tree node contains only one point, and that point is its
+ * centroid.
+ */
+ static const bool FirstPointIsCentroid = true;
+
+ /**
+ * Cover trees do have self-children.
+ */
+ static const bool HasSelfChildren = true;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,31 +0,0 @@
-/**
- * @file cover_tree.hpp
- * @author Ryan Curtin
- *
- * Includes all the necessary files to use the CoverTree class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_COVER_TREE_HPP
-#define __MLPACK_CORE_TREE_COVER_TREE_HPP
-
-#include "bounds.hpp"
-#include "cover_tree/cover_tree.hpp"
-#include "cover_tree/single_tree_traverser.hpp"
-#include "cover_tree/dual_tree_traverser.hpp"
-#include "cover_tree/traits.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/cover_tree.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/cover_tree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,31 @@
+/**
+ * @file cover_tree.hpp
+ * @author Ryan Curtin
+ *
+ * Includes all the necessary files to use the CoverTree class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_COVER_TREE_HPP
+#define __MLPACK_CORE_TREE_COVER_TREE_HPP
+
+#include "bounds.hpp"
+#include "cover_tree/cover_tree.hpp"
+#include "cover_tree/single_tree_traverser.hpp"
+#include "cover_tree/dual_tree_traverser.hpp"
+#include "cover_tree/traits.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/hrectbound.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,187 +0,0 @@
-/**
- * @file hrectbound.hpp
- *
- * Bounds that are useful for binary space partitioning trees.
- *
- * This file describes the interface for the HRectBound class, which implements
- * a hyperrectangle bound.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_HRECTBOUND_HPP
-#define __MLPACK_CORE_TREE_HRECTBOUND_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/math/range.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-namespace mlpack {
-namespace bound {
-
-/**
- * Hyper-rectangle bound for an L-metric. This should be used in conjunction
- * with the LMetric class. Be sure to use the same template parameters for
- * LMetric as you do for HRectBound -- otherwise odd results may occur.
- *
- * @tparam Power The metric to use; use 2 for Euclidean (L2).
- * @tparam TakeRoot Whether or not the root should be taken (see LMetric
- * documentation).
- */
-template<int Power = 2, bool TakeRoot = true>
-class HRectBound
-{
- public:
- //! This is the metric type that this bound is using.
- typedef metric::LMetric<Power, TakeRoot> MetricType;
-
- /**
- * Empty constructor; creates a bound of dimensionality 0.
- */
- HRectBound();
-
- /**
- * Initializes to specified dimensionality with each dimension the empty
- * set.
- */
- HRectBound(const size_t dimension);
-
- //! Copy constructor; necessary to prevent memory leaks.
- HRectBound(const HRectBound& other);
- //! Same as copy constructor; necessary to prevent memory leaks.
- HRectBound& operator=(const HRectBound& other);
-
- //! Destructor: clean up memory.
- ~HRectBound();
-
- /**
- * Resets all dimensions to the empty set (so that this bound contains
- * nothing).
- */
- void Clear();
-
- //! Gets the dimensionality.
- size_t Dim() const { return dim; }
-
- //! Get the range for a particular dimension. No bounds checking.
- math::Range& operator[](const size_t i) { return bounds[i]; }
- //! Modify the range for a particular dimension. No bounds checking.
- const math::Range& operator[](const size_t i) const { return bounds[i]; }
-
- /**
- * Calculates the centroid of the range, placing it into the given vector.
- *
- * @param centroid Vector which the centroid will be written to.
- */
- void Centroid(arma::vec& centroid) const;
-
- /**
- * Calculates minimum bound-to-point distance.
- *
- * @param point Point to which the minimum distance is requested.
- */
- template<typename VecType>
- double MinDistance(const VecType& point) const;
-
- /**
- * Calculates minimum bound-to-bound distance.
- *
- * @param other Bound to which the minimum distance is requested.
- */
- double MinDistance(const HRectBound& other) const;
-
- /**
- * Calculates maximum bound-to-point squared distance.
- *
- * @param point Point to which the maximum distance is requested.
- */
- template<typename VecType>
- double MaxDistance(const VecType& point) const;
-
- /**
- * Computes maximum distance.
- *
- * @param other Bound to which the maximum distance is requested.
- */
- double MaxDistance(const HRectBound& other) const;
-
- /**
- * Calculates minimum and maximum bound-to-bound distance.
- *
- * @param other Bound to which the minimum and maximum distances are
- * requested.
- */
- math::Range RangeDistance(const HRectBound& other) const;
-
- /**
- * Calculates minimum and maximum bound-to-point distance.
- *
- * @param point Point to which the minimum and maximum distances are
- * requested.
- */
- template<typename VecType>
- math::Range RangeDistance(const VecType& point) const;
-
- /**
- * Expands this region to include new points.
- *
- * @tparam MatType Type of matrix; could be Mat, SpMat, a subview, or just a
- * vector.
- * @param data Data points to expand this region to include.
- */
- template<typename MatType>
- HRectBound& operator|=(const MatType& data);
-
- /**
- * Expands this region to encompass another bound.
- */
- HRectBound& operator|=(const HRectBound& other);
-
- /**
- * Determines if a point is within this bound.
- */
- template<typename VecType>
- bool Contains(const VecType& point) const;
-
- /**
- * Returns the diameter of the hyperrectangle (that is, the longest diagonal).
- */
- double Diameter() const;
-
- /**
- * Returns a string representation of this object.
- */
- std::string ToString() const;
-
- /**
- * Return the metric associated with this bound. Because it is an LMetric, it
- * cannot store state, so we can make it on the fly. It is also static
- * because the metric is only dependent on the template arguments.
- */
- static MetricType Metric() { return metric::LMetric<Power, TakeRoot>(); }
-
- private:
- //! The dimensionality of the bound.
- size_t dim;
- //! The bounds for each dimension.
- math::Range* bounds;
-};
-
-}; // namespace bound
-}; // namespace mlpack
-
-#include "hrectbound_impl.hpp"
-
-#endif // __MLPACK_CORE_TREE_HRECTBOUND_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/hrectbound.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,187 @@
+/**
+ * @file hrectbound.hpp
+ *
+ * Bounds that are useful for binary space partitioning trees.
+ *
+ * This file describes the interface for the HRectBound class, which implements
+ * a hyperrectangle bound.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_HRECTBOUND_HPP
+#define __MLPACK_CORE_TREE_HRECTBOUND_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/math/range.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace bound {
+
+/**
+ * Hyper-rectangle bound for an L-metric. This should be used in conjunction
+ * with the LMetric class. Be sure to use the same template parameters for
+ * LMetric as you do for HRectBound -- otherwise odd results may occur.
+ *
+ * @tparam Power The metric to use; use 2 for Euclidean (L2).
+ * @tparam TakeRoot Whether or not the root should be taken (see LMetric
+ * documentation).
+ */
+template<int Power = 2, bool TakeRoot = true>
+class HRectBound
+{
+ public:
+ //! This is the metric type that this bound is using.
+ typedef metric::LMetric<Power, TakeRoot> MetricType;
+
+ /**
+ * Empty constructor; creates a bound of dimensionality 0.
+ */
+ HRectBound();
+
+ /**
+ * Initializes to specified dimensionality with each dimension the empty
+ * set.
+ */
+ HRectBound(const size_t dimension);
+
+ //! Copy constructor; necessary to prevent memory leaks.
+ HRectBound(const HRectBound& other);
+ //! Same as copy constructor; necessary to prevent memory leaks.
+ HRectBound& operator=(const HRectBound& other);
+
+ //! Destructor: clean up memory.
+ ~HRectBound();
+
+ /**
+ * Resets all dimensions to the empty set (so that this bound contains
+ * nothing).
+ */
+ void Clear();
+
+ //! Gets the dimensionality.
+ size_t Dim() const { return dim; }
+
+ //! Modify the range for a particular dimension. No bounds checking.
+ math::Range& operator[](const size_t i) { return bounds[i]; }
+ //! Get the range for a particular dimension. No bounds checking.
+ const math::Range& operator[](const size_t i) const { return bounds[i]; }
+
+ /**
+ * Calculates the centroid of the range, placing it into the given vector.
+ *
+ * @param centroid Vector which the centroid will be written to.
+ */
+ void Centroid(arma::vec& centroid) const;
+
+ /**
+ * Calculates minimum bound-to-point distance.
+ *
+ * @param point Point to which the minimum distance is requested.
+ */
+ template<typename VecType>
+ double MinDistance(const VecType& point) const;
+
+ /**
+ * Calculates minimum bound-to-bound distance.
+ *
+ * @param other Bound to which the minimum distance is requested.
+ */
+ double MinDistance(const HRectBound& other) const;
+
+ /**
+ * Calculates maximum bound-to-point squared distance.
+ *
+ * @param point Point to which the maximum distance is requested.
+ */
+ template<typename VecType>
+ double MaxDistance(const VecType& point) const;
+
+ /**
+ * Computes maximum distance.
+ *
+ * @param other Bound to which the maximum distance is requested.
+ */
+ double MaxDistance(const HRectBound& other) const;
+
+ /**
+ * Calculates minimum and maximum bound-to-bound distance.
+ *
+ * @param other Bound to which the minimum and maximum distances are
+ * requested.
+ */
+ math::Range RangeDistance(const HRectBound& other) const;
+
+ /**
+ * Calculates minimum and maximum bound-to-point distance.
+ *
+ * @param point Point to which the minimum and maximum distances are
+ * requested.
+ */
+ template<typename VecType>
+ math::Range RangeDistance(const VecType& point) const;
+
+ /**
+ * Expands this region to include new points.
+ *
+ * @tparam MatType Type of matrix; could be Mat, SpMat, a subview, or just a
+ * vector.
+ * @param data Data points to expand this region to include.
+ */
+ template<typename MatType>
+ HRectBound& operator|=(const MatType& data);
+
+ /**
+ * Expands this region to encompass another bound.
+ */
+ HRectBound& operator|=(const HRectBound& other);
+
+ /**
+ * Determines if a point is within this bound.
+ */
+ template<typename VecType>
+ bool Contains(const VecType& point) const;
+
+ /**
+ * Returns the diameter of the hyperrectangle (that is, the longest diagonal).
+ */
+ double Diameter() const;
+
+ /**
+ * Returns a string representation of this object.
+ */
+ std::string ToString() const;
+
+ /**
+ * Return the metric associated with this bound. Because it is an LMetric, it
+ * cannot store state, so we can make it on the fly. It is also static
+ * because the metric is only dependent on the template arguments.
+ */
+ static MetricType Metric() { return metric::LMetric<Power, TakeRoot>(); }
+
+ private:
+ //! The dimensionality of the bound.
+ size_t dim;
+ //! The bounds for each dimension.
+ math::Range* bounds;
+};
+
+}; // namespace bound
+}; // namespace mlpack
+
+#include "hrectbound_impl.hpp"
+
+#endif // __MLPACK_CORE_TREE_HRECTBOUND_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/hrectbound_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,418 +0,0 @@
-/**
- * @file hrectbound_impl.hpp
- *
- * Implementation of hyper-rectangle bound policy class.
- * Template parameter Power is the metric to use; use 2 for Euclidean (L2).
- *
- * @experimental
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
-#define __MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
-
-#include <math.h>
-
-// In case it has not been included yet.
-#include "hrectbound.hpp"
-
-namespace mlpack {
-namespace bound {
-
-/**
- * Empty constructor.
- */
-template<int Power, bool TakeRoot>
-HRectBound<Power, TakeRoot>::HRectBound() :
- dim(0),
- bounds(NULL)
-{ /* Nothing to do. */ }
-
-/**
- * Initializes to specified dimensionality with each dimension the empty
- * set.
- */
-template<int Power, bool TakeRoot>
-HRectBound<Power, TakeRoot>::HRectBound(const size_t dimension) :
- dim(dimension),
- bounds(new math::Range[dim])
-{ /* Nothing to do. */ }
-
-/***
- * Copy constructor necessary to prevent memory leaks.
- */
-template<int Power, bool TakeRoot>
-HRectBound<Power, TakeRoot>::HRectBound(const HRectBound& other) :
- dim(other.Dim()),
- bounds(new math::Range[dim])
-{
- // Copy other bounds over.
- for (size_t i = 0; i < dim; i++)
- bounds[i] = other[i];
-}
-
-/***
- * Same as the copy constructor.
- */
-template<int Power, bool TakeRoot>
-HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator=(
- const HRectBound& other)
-{
- if (dim != other.Dim())
- {
- // Reallocation is necessary.
- if (bounds)
- delete[] bounds;
-
- dim = other.Dim();
- bounds = new math::Range[dim];
- }
-
- // Now copy each of the bound values.
- for (size_t i = 0; i < dim; i++)
- bounds[i] = other[i];
-
- return *this;
-}
-
-/**
- * Destructor: clean up memory.
- */
-template<int Power, bool TakeRoot>
-HRectBound<Power, TakeRoot>::~HRectBound()
-{
- if (bounds)
- delete[] bounds;
-}
-
-/**
- * Resets all dimensions to the empty set.
- */
-template<int Power, bool TakeRoot>
-void HRectBound<Power, TakeRoot>::Clear()
-{
- for (size_t i = 0; i < dim; i++)
- bounds[i] = math::Range();
-}
-
-/***
- * Calculates the centroid of the range, placing it into the given vector.
- *
- * @param centroid Vector which the centroid will be written to.
- */
-template<int Power, bool TakeRoot>
-void HRectBound<Power, TakeRoot>::Centroid(arma::vec& centroid) const
-{
- // Set size correctly if necessary.
- if (!(centroid.n_elem == dim))
- centroid.set_size(dim);
-
- for (size_t i = 0; i < dim; i++)
- centroid(i) = bounds[i].Mid();
-}
-
-/**
- * Calculates minimum bound-to-point squared distance.
- */
-template<int Power, bool TakeRoot>
-template<typename VecType>
-double HRectBound<Power, TakeRoot>::MinDistance(const VecType& point) const
-{
- Log::Assert(point.n_elem == dim);
-
- double sum = 0;
-
- double lower, higher;
- for (size_t d = 0; d < dim; d++)
- {
- lower = bounds[d].Lo() - point[d];
- higher = point[d] - bounds[d].Hi();
-
- // Since only one of 'lower' or 'higher' is negative, if we add each's
- // absolute value to itself and then sum those two, our result is the
- // nonnegative half of the equation times two; then we raise to power Power.
- sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) Power);
- }
-
- // Now take the Power'th root (but make sure our result is squared if it needs
- // to be); then cancel out the constant of 2 (which may have been squared now)
- // that was introduced earlier. The compiler should optimize out the if
- // statement entirely.
- if (TakeRoot)
- return pow(sum, 1.0 / (double) Power) / 2.0;
- else
- return sum / pow(2.0, Power);
-}
-
-/**
- * Calculates minimum bound-to-bound squared distance.
- */
-template<int Power, bool TakeRoot>
-double HRectBound<Power, TakeRoot>::MinDistance(const HRectBound& other) const
-{
- Log::Assert(dim == other.dim);
-
- double sum = 0;
- const math::Range* mbound = bounds;
- const math::Range* obound = other.bounds;
-
- double lower, higher;
- for (size_t d = 0; d < dim; d++)
- {
- lower = obound->Lo() - mbound->Hi();
- higher = mbound->Lo() - obound->Hi();
- // We invoke the following:
- // x + fabs(x) = max(x * 2, 0)
- // (x * 2)^2 / 4 = x^2
- sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) Power);
-
- // Move bound pointers.
- mbound++;
- obound++;
- }
-
- // The compiler should optimize out this if statement entirely.
- if (TakeRoot)
- return pow(sum, 1.0 / (double) Power) / 2.0;
- else
- return sum / pow(2.0, Power);
-}
-
-/**
- * Calculates maximum bound-to-point squared distance.
- */
-template<int Power, bool TakeRoot>
-template<typename VecType>
-double HRectBound<Power, TakeRoot>::MaxDistance(const VecType& point) const
-{
- double sum = 0;
-
- Log::Assert(point.n_elem == dim);
-
- for (size_t d = 0; d < dim; d++)
- {
- double v = std::max(fabs(point[d] - bounds[d].Lo()),
- fabs(bounds[d].Hi() - point[d]));
- sum += pow(v, (double) Power);
- }
-
- // The compiler should optimize out this if statement entirely.
- if (TakeRoot)
- return pow(sum, 1.0 / (double) Power);
- else
- return sum;
-}
-
-/**
- * Computes maximum distance.
- */
-template<int Power, bool TakeRoot>
-double HRectBound<Power, TakeRoot>::MaxDistance(const HRectBound& other) const
-{
- double sum = 0;
-
- Log::Assert(dim == other.dim);
-
- double v;
- for (size_t d = 0; d < dim; d++)
- {
- v = std::max(fabs(other.bounds[d].Hi() - bounds[d].Lo()),
- fabs(bounds[d].Hi() - other.bounds[d].Lo()));
- sum += pow(v, (double) Power); // v is non-negative.
- }
-
- // The compiler should optimize out this if statement entirely.
- if (TakeRoot)
- return pow(sum, 1.0 / (double) Power);
- else
- return sum;
-}
-
-/**
- * Calculates minimum and maximum bound-to-bound squared distance.
- */
-template<int Power, bool TakeRoot>
-math::Range HRectBound<Power, TakeRoot>::RangeDistance(const HRectBound& other)
- const
-{
- double loSum = 0;
- double hiSum = 0;
-
- Log::Assert(dim == other.dim);
-
- double v1, v2, vLo, vHi;
- for (size_t d = 0; d < dim; d++)
- {
- v1 = other.bounds[d].Lo() - bounds[d].Hi();
- v2 = bounds[d].Lo() - other.bounds[d].Hi();
- // One of v1 or v2 is negative.
- if (v1 >= v2)
- {
- vHi = -v2; // Make it nonnegative.
- vLo = (v1 > 0) ? v1 : 0; // Force to be 0 if negative.
- }
- else
- {
- vHi = -v1; // Make it nonnegative.
- vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
- }
-
- loSum += pow(vLo, (double) Power);
- hiSum += pow(vHi, (double) Power);
- }
-
- if (TakeRoot)
- return math::Range(pow(loSum, 1.0 / (double) Power),
- pow(hiSum, 1.0 / (double) Power));
- else
- return math::Range(loSum, hiSum);
-}
-
-/**
- * Calculates minimum and maximum bound-to-point squared distance.
- */
-template<int Power, bool TakeRoot>
-template<typename VecType>
-math::Range HRectBound<Power, TakeRoot>::RangeDistance(const VecType& point)
- const
-{
- double loSum = 0;
- double hiSum = 0;
-
- Log::Assert(point.n_elem == dim);
-
- double v1, v2, vLo, vHi;
- for (size_t d = 0; d < dim; d++)
- {
- v1 = bounds[d].Lo() - point[d]; // Negative if point[d] > lo.
- v2 = point[d] - bounds[d].Hi(); // Negative if point[d] < hi.
- // One of v1 or v2 (or both) is negative.
- if (v1 >= 0) // point[d] <= bounds_[d].Lo().
- {
- vHi = -v2; // v2 will be larger but must be negated.
- vLo = v1;
- }
- else // point[d] is between lo and hi, or greater than hi.
- {
- if (v2 >= 0)
- {
- vHi = -v1; // v1 will be larger, but must be negated.
- vLo = v2;
- }
- else
- {
- vHi = -std::min(v1, v2); // Both are negative, but we need the larger.
- vLo = 0;
- }
- }
-
- loSum += pow(vLo, (double) Power);
- hiSum += pow(vHi, (double) Power);
- }
-
- if (TakeRoot)
- return math::Range(pow(loSum, 1.0 / (double) Power),
- pow(hiSum, 1.0 / (double) Power));
- else
- return math::Range(loSum, hiSum);
-}
-
-/**
- * Expands this region to include a new point.
- */
-template<int Power, bool TakeRoot>
-template<typename MatType>
-HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
- const MatType& data)
-{
- Log::Assert(data.n_rows == dim);
-
- arma::vec mins(min(data, 1));
- arma::vec maxs(max(data, 1));
-
- for (size_t i = 0; i < dim; i++)
- bounds[i] |= math::Range(mins[i], maxs[i]);
-
- return *this;
-}
-
-/**
- * Expands this region to encompass another bound.
- */
-template<int Power, bool TakeRoot>
-HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
- const HRectBound& other)
-{
- assert(other.dim == dim);
-
- for (size_t i = 0; i < dim; i++)
- bounds[i] |= other.bounds[i];
-
- return *this;
-}
-
-/**
- * Determines if a point is within this bound.
- */
-template<int Power, bool TakeRoot>
-template<typename VecType>
-bool HRectBound<Power, TakeRoot>::Contains(const VecType& point) const
-{
- for (size_t i = 0; i < point.n_elem; i++)
- {
- if (!bounds[i].Contains(point(i)))
- return false;
- }
-
- return true;
-}
-
-/**
- * Returns the diameter of the hyperrectangle (that is, the longest diagonal).
- */
-template<int Power, bool TakeRoot>
-double HRectBound<Power, TakeRoot>::Diameter() const
-{
- double d = 0;
- for (size_t i = 0; i < dim; ++i)
- d += std::pow(bounds[i].Hi() - bounds[i].Lo(), (double) Power);
-
- if (TakeRoot)
- return std::pow(d, 1.0 / (double) Power);
- else
- return d;
-}
-
-/**
- * Returns a string representation of this object.
- */
-template<int Power, bool TakeRoot>
-std::string HRectBound<Power, TakeRoot>::ToString() const
-{
- std::ostringstream convert;
- convert << "HRectBound [" << this << "]" << std::endl;
- convert << "dim: " << dim << std::endl;
- convert << "bounds: " << std::endl;
- for (size_t i = 0; i < dim; ++i)
- convert << util::Indent(bounds[i].ToString()) << std::endl;
-
- return convert.str();
-}
-
-}; // namespace bound
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/hrectbound_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/hrectbound_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,418 @@
+/**
+ * @file hrectbound_impl.hpp
+ *
+ * Implementation of hyper-rectangle bound policy class.
+ * Template parameter Power is the metric to use; use 2 for Euclidean (L2).
+ *
+ * @experimental
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
+#define __MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
+
+#include <math.h>
+
+// In case it has not been included yet.
+#include "hrectbound.hpp"
+
+namespace mlpack {
+namespace bound {
+
+/**
+ * Empty constructor.
+ */
+template<int Power, bool TakeRoot>
+HRectBound<Power, TakeRoot>::HRectBound() :
+ dim(0),
+ bounds(NULL)
+{ /* Nothing to do. */ }
+
+/**
+ * Initializes to specified dimensionality with each dimension the empty
+ * set.
+ */
+template<int Power, bool TakeRoot>
+HRectBound<Power, TakeRoot>::HRectBound(const size_t dimension) :
+ dim(dimension),
+ bounds(new math::Range[dim])
+{ /* Nothing to do. */ }
+
+/***
+ * Copy constructor necessary to prevent memory leaks.
+ */
+template<int Power, bool TakeRoot>
+HRectBound<Power, TakeRoot>::HRectBound(const HRectBound& other) :
+ dim(other.Dim()),
+ bounds(new math::Range[dim])
+{
+ // Copy other bounds over.
+ for (size_t i = 0; i < dim; i++)
+ bounds[i] = other[i];
+}
+
+/***
+ * Same as the copy constructor.
+ */
+template<int Power, bool TakeRoot>
+HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator=(
+ const HRectBound& other)
+{
+ if (dim != other.Dim())
+ {
+ // Reallocation is necessary.
+ if (bounds)
+ delete[] bounds;
+
+ dim = other.Dim();
+ bounds = new math::Range[dim];
+ }
+
+ // Now copy each of the bound values.
+ for (size_t i = 0; i < dim; i++)
+ bounds[i] = other[i];
+
+ return *this;
+}
+
+/**
+ * Destructor: clean up memory.
+ */
+template<int Power, bool TakeRoot>
+HRectBound<Power, TakeRoot>::~HRectBound()
+{
+ if (bounds)
+ delete[] bounds;
+}
+
+/**
+ * Resets all dimensions to the empty set.
+ */
+template<int Power, bool TakeRoot>
+void HRectBound<Power, TakeRoot>::Clear()
+{
+ for (size_t i = 0; i < dim; i++)
+ bounds[i] = math::Range();
+}
+
+/***
+ * Calculates the centroid of the range, placing it into the given vector.
+ *
+ * @param centroid Vector which the centroid will be written to.
+ */
+template<int Power, bool TakeRoot>
+void HRectBound<Power, TakeRoot>::Centroid(arma::vec& centroid) const
+{
+ // Set size correctly if necessary.
+ if (!(centroid.n_elem == dim))
+ centroid.set_size(dim);
+
+ for (size_t i = 0; i < dim; i++)
+ centroid(i) = bounds[i].Mid();
+}
+
+/**
+ * Calculates minimum bound-to-point squared distance.
+ */
+template<int Power, bool TakeRoot>
+template<typename VecType>
+double HRectBound<Power, TakeRoot>::MinDistance(const VecType& point) const
+{
+ Log::Assert(point.n_elem == dim);
+
+ double sum = 0;
+
+ double lower, higher;
+ for (size_t d = 0; d < dim; d++)
+ {
+ lower = bounds[d].Lo() - point[d];
+ higher = point[d] - bounds[d].Hi();
+
+ // Since only one of 'lower' or 'higher' is negative, if we add each's
+ // absolute value to itself and then sum those two, our result is the
+ // nonnegative half of the equation times two; then we raise to power Power.
+ sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) Power);
+ }
+
+ // Now take the Power'th root (but make sure our result is squared if it needs
+ // to be); then cancel out the constant of 2 (which may have been squared now)
+ // that was introduced earlier. The compiler should optimize out the if
+ // statement entirely.
+ if (TakeRoot)
+ return pow(sum, 1.0 / (double) Power) / 2.0;
+ else
+ return sum / pow(2.0, Power);
+}
+
+/**
+ * Calculates minimum bound-to-bound squared distance.
+ */
+template<int Power, bool TakeRoot>
+double HRectBound<Power, TakeRoot>::MinDistance(const HRectBound& other) const
+{
+ Log::Assert(dim == other.dim);
+
+ double sum = 0;
+ const math::Range* mbound = bounds;
+ const math::Range* obound = other.bounds;
+
+ double lower, higher;
+ for (size_t d = 0; d < dim; d++)
+ {
+ lower = obound->Lo() - mbound->Hi();
+ higher = mbound->Lo() - obound->Hi();
+ // We invoke the following:
+ // x + fabs(x) = max(x * 2, 0)
+ // (x * 2)^2 / 4 = x^2
+ sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) Power);
+
+ // Move bound pointers.
+ mbound++;
+ obound++;
+ }
+
+ // The compiler should optimize out this if statement entirely.
+ if (TakeRoot)
+ return pow(sum, 1.0 / (double) Power) / 2.0;
+ else
+ return sum / pow(2.0, Power);
+}
+
+/**
+ * Calculates maximum bound-to-point squared distance.
+ */
+template<int Power, bool TakeRoot>
+template<typename VecType>
+double HRectBound<Power, TakeRoot>::MaxDistance(const VecType& point) const
+{
+ double sum = 0;
+
+ Log::Assert(point.n_elem == dim);
+
+ for (size_t d = 0; d < dim; d++)
+ {
+ double v = std::max(fabs(point[d] - bounds[d].Lo()),
+ fabs(bounds[d].Hi() - point[d]));
+ sum += pow(v, (double) Power);
+ }
+
+ // The compiler should optimize out this if statement entirely.
+ if (TakeRoot)
+ return pow(sum, 1.0 / (double) Power);
+ else
+ return sum;
+}
+
+/**
+ * Computes maximum distance.
+ */
+template<int Power, bool TakeRoot>
+double HRectBound<Power, TakeRoot>::MaxDistance(const HRectBound& other) const
+{
+ double sum = 0;
+
+ Log::Assert(dim == other.dim);
+
+ double v;
+ for (size_t d = 0; d < dim; d++)
+ {
+ v = std::max(fabs(other.bounds[d].Hi() - bounds[d].Lo()),
+ fabs(bounds[d].Hi() - other.bounds[d].Lo()));
+ sum += pow(v, (double) Power); // v is non-negative.
+ }
+
+ // The compiler should optimize out this if statement entirely.
+ if (TakeRoot)
+ return pow(sum, 1.0 / (double) Power);
+ else
+ return sum;
+}
+
+/**
+ * Calculates minimum and maximum bound-to-bound squared distance.
+ */
+template<int Power, bool TakeRoot>
+math::Range HRectBound<Power, TakeRoot>::RangeDistance(const HRectBound& other)
+ const
+{
+ double loSum = 0;
+ double hiSum = 0;
+
+ Log::Assert(dim == other.dim);
+
+ double v1, v2, vLo, vHi;
+ for (size_t d = 0; d < dim; d++)
+ {
+ v1 = other.bounds[d].Lo() - bounds[d].Hi();
+ v2 = bounds[d].Lo() - other.bounds[d].Hi();
+ // One of v1 or v2 is negative.
+ if (v1 >= v2)
+ {
+ vHi = -v2; // Make it nonnegative.
+ vLo = (v1 > 0) ? v1 : 0; // Force to be 0 if negative.
+ }
+ else
+ {
+ vHi = -v1; // Make it nonnegative.
+ vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
+ }
+
+ loSum += pow(vLo, (double) Power);
+ hiSum += pow(vHi, (double) Power);
+ }
+
+ if (TakeRoot)
+ return math::Range(pow(loSum, 1.0 / (double) Power),
+ pow(hiSum, 1.0 / (double) Power));
+ else
+ return math::Range(loSum, hiSum);
+}
+
+/**
+ * Calculates minimum and maximum bound-to-point squared distance.
+ */
+template<int Power, bool TakeRoot>
+template<typename VecType>
+math::Range HRectBound<Power, TakeRoot>::RangeDistance(const VecType& point)
+ const
+{
+ double loSum = 0;
+ double hiSum = 0;
+
+ Log::Assert(point.n_elem == dim);
+
+ double v1, v2, vLo, vHi;
+ for (size_t d = 0; d < dim; d++)
+ {
+ v1 = bounds[d].Lo() - point[d]; // Negative if point[d] > lo.
+ v2 = point[d] - bounds[d].Hi(); // Negative if point[d] < hi.
+ // One of v1 or v2 (or both) is negative.
+ if (v1 >= 0) // point[d] <= bounds_[d].Lo().
+ {
+ vHi = -v2; // v2 will be larger but must be negated.
+ vLo = v1;
+ }
+ else // point[d] is between lo and hi, or greater than hi.
+ {
+ if (v2 >= 0)
+ {
+ vHi = -v1; // v1 will be larger, but must be negated.
+ vLo = v2;
+ }
+ else
+ {
+ vHi = -std::min(v1, v2); // Both are negative, but we need the larger.
+ vLo = 0;
+ }
+ }
+
+ loSum += pow(vLo, (double) Power);
+ hiSum += pow(vHi, (double) Power);
+ }
+
+ if (TakeRoot)
+ return math::Range(pow(loSum, 1.0 / (double) Power),
+ pow(hiSum, 1.0 / (double) Power));
+ else
+ return math::Range(loSum, hiSum);
+}
+
+/**
+ * Expands this region to include a new point.
+ */
+template<int Power, bool TakeRoot>
+template<typename MatType>
+HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
+ const MatType& data)
+{
+ Log::Assert(data.n_rows == dim);
+
+ arma::vec mins(min(data, 1));
+ arma::vec maxs(max(data, 1));
+
+ for (size_t i = 0; i < dim; i++)
+ bounds[i] |= math::Range(mins[i], maxs[i]);
+
+ return *this;
+}
+
+/**
+ * Expands this region to encompass another bound.
+ */
+template<int Power, bool TakeRoot>
+HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
+ const HRectBound& other)
+{
+ assert(other.dim == dim);
+
+ for (size_t i = 0; i < dim; i++)
+ bounds[i] |= other.bounds[i];
+
+ return *this;
+}
+
+/**
+ * Determines if a point is within this bound.
+ */
+template<int Power, bool TakeRoot>
+template<typename VecType>
+bool HRectBound<Power, TakeRoot>::Contains(const VecType& point) const
+{
+ for (size_t i = 0; i < point.n_elem; i++)
+ {
+ if (!bounds[i].Contains(point(i)))
+ return false;
+ }
+
+ return true;
+}
+
+/**
+ * Returns the diameter of the hyperrectangle (that is, the longest diagonal).
+ */
+template<int Power, bool TakeRoot>
+double HRectBound<Power, TakeRoot>::Diameter() const
+{
+ double d = 0;
+ for (size_t i = 0; i < dim; ++i)
+ d += std::pow(bounds[i].Hi() - bounds[i].Lo(), (double) Power);
+
+ if (TakeRoot)
+ return std::pow(d, 1.0 / (double) Power);
+ else
+ return d;
+}
+
+/**
+ * Returns a string representation of this object.
+ */
+template<int Power, bool TakeRoot>
+std::string HRectBound<Power, TakeRoot>::ToString() const
+{
+ std::ostringstream convert;
+ convert << "HRectBound [" << this << "]" << std::endl;
+ convert << "dim: " << dim << std::endl;
+ convert << "bounds: " << std::endl;
+ for (size_t i = 0; i < dim; ++i)
+ convert << util::Indent(bounds[i].ToString()) << std::endl;
+
+ return convert.str();
+}
+
+}; // namespace bound
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/mrkd_statistic.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,58 +0,0 @@
-/**
- * @file mrkd_statistic.cpp
- * @author James Cline
- *
- * Definition of the statistic for multi-resolution kd-trees.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "mrkd_statistic.hpp"
-
-using namespace mlpack;
-using namespace mlpack::tree;
-
-MRKDStatistic::MRKDStatistic() :
- dataset(NULL),
- begin(0),
- count(0),
- leftStat(NULL),
- rightStat(NULL),
- parentStat(NULL)
-{ }
-
-/**
- * Returns a string representation of this object.
- */
-std::string MRKDStatistic::ToString() const
-{
- std::ostringstream convert;
-
- convert << "MRKDStatistic [" << this << std::endl;
- convert << "begin: " << begin << std::endl;
- convert << "count: " << count << std::endl;
- convert << "sumOfSquaredNorms: " << sumOfSquaredNorms << std::endl;
- if (leftStat != NULL)
- {
- convert << "leftStat:" << std::endl;
- convert << mlpack::util::Indent(leftStat->ToString());
- }
- if (rightStat != NULL)
- {
- convert << "rightStat:" << std::endl;
- convert << mlpack::util::Indent(rightStat->ToString());
- }
- return convert.str();
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/mrkd_statistic.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,58 @@
+/**
+ * @file mrkd_statistic.cpp
+ * @author James Cline
+ *
+ * Definition of the statistic for multi-resolution kd-trees.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "mrkd_statistic.hpp"
+
+using namespace mlpack;
+using namespace mlpack::tree;
+
+MRKDStatistic::MRKDStatistic() :
+ dataset(NULL),
+ begin(0),
+ count(0),
+ leftStat(NULL),
+ rightStat(NULL),
+ parentStat(NULL)
+{ }
+
+/**
+ * Returns a string representation of this object.
+ */
+std::string MRKDStatistic::ToString() const
+{
+ std::ostringstream convert;
+
+ convert << "MRKDStatistic [" << this << std::endl;
+ convert << "begin: " << begin << std::endl;
+ convert << "count: " << count << std::endl;
+ convert << "sumOfSquaredNorms: " << sumOfSquaredNorms << std::endl;
+ if (leftStat != NULL)
+ {
+ convert << "leftStat:" << std::endl;
+ convert << mlpack::util::Indent(leftStat->ToString());
+ }
+ if (rightStat != NULL)
+ {
+ convert << "rightStat:" << std::endl;
+ convert << mlpack::util::Indent(rightStat->ToString());
+ }
+ return convert.str();
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/mrkd_statistic.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,113 +0,0 @@
-/**
- * @file mrkd_statistic.hpp
- * @author James Cline
- *
- * Definition of the statistic for multi-resolution kd-trees.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_MRKD_STATISTIC_HPP
-#define __MLPACK_CORE_TREE_MRKD_STATISTIC_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace tree {
-
-/**
- * Statistic for multi-resolution kd-trees.
- */
-class MRKDStatistic
-{
- public:
- //! Initialize an empty statistic.
- MRKDStatistic();
-
- /**
- * This constructor is called when a node is finished initializing.
- *
- * @param node The node that has been finished.
- */
- template<typename TreeType>
- MRKDStatistic(const TreeType& /* node */);
-
- /**
- * Returns a string representation of this object.
- */
- std::string ToString() const;
-
- //! Get the index of the initial item in the dataset.
- size_t Begin() const { return begin; }
- //! Modify the index of the initial item in the dataset.
- size_t& Begin() { return begin; }
-
- //! Get the number of items in the dataset.
- size_t Count() const { return count; }
- //! Modify the number of items in the dataset.
- size_t& Count() { return count; }
-
- //! Get the center of mass.
- const arma::colvec& CenterOfMass() const { return centerOfMass; }
- //! Modify the center of mass.
- arma::colvec& CenterOfMass() { return centerOfMass; }
-
- //! Get the index of the dominating centroid.
- size_t DominatingCentroid() const { return dominatingCentroid; }
- //! Modify the index of the dominating centroid.
- size_t& DominatingCentroid() { return dominatingCentroid; }
-
- //! Access the whitelist.
- const std::vector<size_t>& Whitelist() const { return whitelist; }
- //! Modify the whitelist.
- std::vector<size_t>& Whitelist() { return whitelist; }
-
- private:
- //! The data points this object contains.
- const arma::mat* dataset;
- //! The initial item in the dataset, so we don't have to make a copy.
- size_t begin;
- //! The number of items in the dataset.
- size_t count;
- //! The left child.
- const MRKDStatistic* leftStat;
- //! The right child.
- const MRKDStatistic* rightStat;
- //! A link to the parent node; NULL if this is the root.
- const MRKDStatistic* parentStat;
-
- // Computed statistics.
- //! The center of mass for this dataset.
- arma::colvec centerOfMass;
- //! The sum of the squared Euclidean norms for this dataset.
- double sumOfSquaredNorms;
-
- // There may be a better place to store this -- HRectBound?
- //! The index of the dominating centroid of the associated hyperrectangle.
- size_t dominatingCentroid;
-
- //! The list of centroids that cannot own this hyperrectangle.
- std::vector<size_t> whitelist;
- //! Whether or not the whitelist is valid.
- bool isWhitelistValid;
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-// Include implementation.
-#include "mrkd_statistic_impl.hpp"
-
-#endif // __MLPACK_CORE_TREE_MRKD_STATISTIC_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/mrkd_statistic.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,113 @@
+/**
+ * @file mrkd_statistic.hpp
+ * @author James Cline
+ *
+ * Definition of the statistic for multi-resolution kd-trees.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_MRKD_STATISTIC_HPP
+#define __MLPACK_CORE_TREE_MRKD_STATISTIC_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Statistic for multi-resolution kd-trees.
+ */
+class MRKDStatistic
+{
+ public:
+ //! Initialize an empty statistic.
+ MRKDStatistic();
+
+ /**
+ * This constructor is called when a node is finished initializing.
+ *
+ * @param node The node that has been finished.
+ */
+ template<typename TreeType>
+ MRKDStatistic(const TreeType& /* node */);
+
+ /**
+ * Returns a string representation of this object.
+ */
+ std::string ToString() const;
+
+ //! Get the index of the initial item in the dataset.
+ size_t Begin() const { return begin; }
+ //! Modify the index of the initial item in the dataset.
+ size_t& Begin() { return begin; }
+
+ //! Get the number of items in the dataset.
+ size_t Count() const { return count; }
+ //! Modify the number of items in the dataset.
+ size_t& Count() { return count; }
+
+ //! Get the center of mass.
+ const arma::colvec& CenterOfMass() const { return centerOfMass; }
+ //! Modify the center of mass.
+ arma::colvec& CenterOfMass() { return centerOfMass; }
+
+ //! Get the index of the dominating centroid.
+ size_t DominatingCentroid() const { return dominatingCentroid; }
+ //! Modify the index of the dominating centroid.
+ size_t& DominatingCentroid() { return dominatingCentroid; }
+
+ //! Access the whitelist.
+ const std::vector<size_t>& Whitelist() const { return whitelist; }
+ //! Modify the whitelist.
+ std::vector<size_t>& Whitelist() { return whitelist; }
+
+ private:
+ //! The data points this object contains.
+ const arma::mat* dataset;
+ //! The initial item in the dataset, so we don't have to make a copy.
+ size_t begin;
+ //! The number of items in the dataset.
+ size_t count;
+ //! The left child.
+ const MRKDStatistic* leftStat;
+ //! The right child.
+ const MRKDStatistic* rightStat;
+ //! A link to the parent node; NULL if this is the root.
+ const MRKDStatistic* parentStat;
+
+ // Computed statistics.
+ //! The center of mass for this dataset.
+ arma::colvec centerOfMass;
+ //! The sum of the squared Euclidean norms for this dataset.
+ double sumOfSquaredNorms;
+
+ // There may be a better place to store this -- HRectBound?
+ //! The index of the dominating centroid of the associated hyperrectangle.
+ size_t dominatingCentroid;
+
+ //! The list of centroids that cannot own this hyperrectangle.
+ std::vector<size_t> whitelist;
+ //! Whether or not the whitelist is valid.
+ bool isWhitelistValid;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "mrkd_statistic_impl.hpp"
+
+#endif // __MLPACK_CORE_TREE_MRKD_STATISTIC_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/mrkd_statistic_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,108 +0,0 @@
-/**
- * @file mrkd_statistic_impl.hpp
- * @author James Cline
- *
- * Definition of the statistic for multi-resolution kd-trees.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_MRKD_STATISTIC_IMPL_HPP
-#define __MLPACK_CORE_TREE_MRKD_STATISTIC_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "mrkd_statistic.hpp"
-
-namespace mlpack {
-namespace tree {
-
-template<typename TreeType>
-MRKDStatistic::MRKDStatistic(const TreeType& /* node */) :
- dataset(NULL),
- begin(0),
- count(0),
- leftStat(NULL),
- rightStat(NULL),
- parentStat(NULL)
-{ }
-
-/**
- * This constructor is called when a leaf is created.
- *
- * @param dataset Matrix that the tree is being built on.
- * @param begin Starting index corresponding to this leaf.
- * @param count Number of points held in this leaf.
- *
-template<typename MatType>
-MRKDStatistic::MRKDStatistic(const TreeType& node) :
- dataset(&dataset),
- begin(begin),
- count(count),
- leftStat(NULL),
- rightStat(NULL),
- parentStat(NULL)
-{
- centerOfMass = dataset.col(begin);
- for (size_t i = begin + 1; i < begin + count; ++i)
- centerOfMass += dataset.col(i);
-
- sumOfSquaredNorms = 0.0;
- for (size_t i = begin; i < begin + count; ++i)
- sumOfSquaredNorms += arma::norm(dataset.col(i), 2);
-}
-
- **
- * This constructor is called when a non-leaf node is created.
- * This lets you build fast bottom-up statistics when building trees.
- *
- * @param dataset Matrix that the tree is being built on.
- * @param begin Starting index corresponding to this leaf.
- * @param count Number of points held in this leaf.
- * @param leftStat MRKDStatistic object of the left child node.
- * @param rightStat MRKDStatistic object of the right child node.
- *
-template<typename MatType>
-MRKDStatistic::MRKDStatistic(const MatType& dataset,
- const size_t begin,
- const size_t count,
- MRKDStatistic& leftStat,
- MRKDStatistic& rightStat) :
- dataset(&dataset),
- begin(begin),
- count(count),
- leftStat(&leftStat),
- rightStat(&rightStat),
- parentStat(NULL)
-{
- sumOfSquaredNorms = leftStat.sumOfSquaredNorms + rightStat.sumOfSquaredNorms;
-
- *
- centerOfMass = ((leftStat.centerOfMass * leftStat.count) +
- (rightStat.centerOfMass * rightStat.count)) /
- (leftStat.count + rightStat.count);
- *
- centerOfMass = leftStat.centerOfMass + rightStat.centerOfMass;
-
- isWhitelistValid = false;
-
- leftStat.parentStat = this;
- rightStat.parentStat = this;
-}
-*/
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_TREE_MRKD_STATISTIC_IMPL_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/mrkd_statistic_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/mrkd_statistic_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,108 @@
+/**
+ * @file mrkd_statistic_impl.hpp
+ * @author James Cline
+ *
+ * Definition of the statistic for multi-resolution kd-trees.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_MRKD_STATISTIC_IMPL_HPP
+#define __MLPACK_CORE_TREE_MRKD_STATISTIC_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "mrkd_statistic.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename TreeType>
+MRKDStatistic::MRKDStatistic(const TreeType& /* node */) :
+ dataset(NULL),
+ begin(0),
+ count(0),
+ leftStat(NULL),
+ rightStat(NULL),
+ parentStat(NULL)
+{ }
+
+/**
+ * This constructor is called when a leaf is created.
+ *
+ * @param dataset Matrix that the tree is being built on.
+ * @param begin Starting index corresponding to this leaf.
+ * @param count Number of points held in this leaf.
+ *
+template<typename MatType>
+MRKDStatistic::MRKDStatistic(const TreeType& node) :
+ dataset(&dataset),
+ begin(begin),
+ count(count),
+ leftStat(NULL),
+ rightStat(NULL),
+ parentStat(NULL)
+{
+ centerOfMass = dataset.col(begin);
+ for (size_t i = begin + 1; i < begin + count; ++i)
+ centerOfMass += dataset.col(i);
+
+ sumOfSquaredNorms = 0.0;
+ for (size_t i = begin; i < begin + count; ++i)
+ sumOfSquaredNorms += arma::norm(dataset.col(i), 2);
+}
+
+ **
+ * This constructor is called when a non-leaf node is created.
+ * This lets you build fast bottom-up statistics when building trees.
+ *
+ * @param dataset Matrix that the tree is being built on.
+ * @param begin Starting index corresponding to this leaf.
+ * @param count Number of points held in this leaf.
+ * @param leftStat MRKDStatistic object of the left child node.
+ * @param rightStat MRKDStatistic object of the right child node.
+ *
+template<typename MatType>
+MRKDStatistic::MRKDStatistic(const MatType& dataset,
+ const size_t begin,
+ const size_t count,
+ MRKDStatistic& leftStat,
+ MRKDStatistic& rightStat) :
+ dataset(&dataset),
+ begin(begin),
+ count(count),
+ leftStat(&leftStat),
+ rightStat(&rightStat),
+ parentStat(NULL)
+{
+ sumOfSquaredNorms = leftStat.sumOfSquaredNorms + rightStat.sumOfSquaredNorms;
+
+ *
+ centerOfMass = ((leftStat.centerOfMass * leftStat.count) +
+ (rightStat.centerOfMass * rightStat.count)) /
+ (leftStat.count + rightStat.count);
+ *
+ centerOfMass = leftStat.centerOfMass + rightStat.centerOfMass;
+
+ isWhitelistValid = false;
+
+ leftStat.parentStat = this;
+ rightStat.parentStat = this;
+}
+*/
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_MRKD_STATISTIC_IMPL_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/periodichrectbound.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,165 +0,0 @@
-/**
- * @file periodichrectbound.hpp
- *
- * Bounds that are useful for binary space partitioning trees.
- *
- * This file describes the interface for the PeriodicHRectBound policy, which
- * implements a hyperrectangle bound in a periodic space.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_PERIODICHRECTBOUND_HPP
-#define __MLPACK_CORE_TREE_PERIODICHRECTBOUND_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace bound {
-
-/**
- * Hyper-rectangle bound for an L-metric.
- *
- * Template parameter t_pow is the metric to use; use 2 for Euclidean (L2).
- */
-template<int t_pow = 2>
-class PeriodicHRectBound
-{
- public:
- /**
- * Empty constructor.
- */
- PeriodicHRectBound();
-
- /**
- * Specifies the box size. The dimensionality is set to the same of the box
- * size, and the bounds are initialized to be empty.
- */
- PeriodicHRectBound(arma::vec box);
-
- /***
- * Copy constructor and copy operator. These are necessary because we do our
- * own memory management.
- */
- PeriodicHRectBound(const PeriodicHRectBound& other);
- PeriodicHRectBound& operator=(const PeriodicHRectBound& other);
-
- /**
- * Destructor: clean up memory.
- */
- ~PeriodicHRectBound();
-
- /**
- * Modifies the box to the desired dimenstions.
- */
- void SetBoxSize(arma::vec box);
-
- /**
- * Returns the box vector.
- */
- const arma::vec& Box() const { return box; }
-
- /**
- * Resets all dimensions to the empty set.
- */
- void Clear();
-
- /** Gets the dimensionality */
- size_t Dim() const { return dim; }
-
- /**
- * Sets and gets the range for a particular dimension.
- */
- math::Range& operator[](size_t i);
- const math::Range operator[](size_t i) const;
-
- /***
- * Calculates the centroid of the range. This does not factor in periodic
- * coordinates, so the centroid may not necessarily be inside the given box.
- *
- * @param centroid Vector to write the centroid to.
- */
- void Centroid(arma::vec& centroid) const;
-
- /**
- * Calculates minimum bound-to-point squared distance in the periodic bound
- * case.
- */
- double MinDistance(const arma::vec& point) const;
-
- /**
- * Calculates minimum bound-to-bound squared distance in the periodic bound
- * case.
- *
- * Example: bound1.MinDistance(other) for minimum squared distance.
- */
- double MinDistance(const PeriodicHRectBound& other) const;
-
- /**
- * Calculates maximum bound-to-point squared distance in the periodic bound
- * case.
- */
- double MaxDistance(const arma::vec& point) const;
-
- /**
- * Computes maximum bound-to-bound squared distance in the periodic bound
- * case.
- */
- double MaxDistance(const PeriodicHRectBound& other) const;
-
- /**
- * Calculates minimum and maximum bound-to-point squared distance in the
- * periodic bound case.
- */
- math::Range RangeDistance(const arma::vec& point) const;
-
- /**
- * Calculates minimum and maximum bound-to-bound squared distance in the
- * periodic bound case.
- */
- math::Range RangeDistance(const PeriodicHRectBound& other) const;
-
- /**
- * Expands this region to include a new point.
- */
- PeriodicHRectBound& operator|=(const arma::vec& vector);
-
- /**
- * Expands this region to encompass another bound.
- */
- PeriodicHRectBound& operator|=(const PeriodicHRectBound& other);
-
- /**
- * Determines if a point is within this bound.
- */
- bool Contains(const arma::vec& point) const;
-
- /**
- * Returns a string representation of an object.
- */
- std::string ToString() const;
-
- private:
- math::Range *bounds;
- size_t dim;
- arma::vec box;
-};
-
-}; // namespace bound
-}; // namespace mlpack
-
-#include "periodichrectbound_impl.hpp"
-
-#endif // __MLPACK_CORE_TREE_PERIODICHRECTBOUND_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/periodichrectbound.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,165 @@
+/**
+ * @file periodichrectbound.hpp
+ *
+ * Bounds that are useful for binary space partitioning trees.
+ *
+ * This file describes the interface for the PeriodicHRectBound policy, which
+ * implements a hyperrectangle bound in a periodic space.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_PERIODICHRECTBOUND_HPP
+#define __MLPACK_CORE_TREE_PERIODICHRECTBOUND_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace bound {
+
+/**
+ * Hyper-rectangle bound for an L-metric.
+ *
+ * Template parameter t_pow is the metric to use; use 2 for Euclidean (L2).
+ */
+template<int t_pow = 2>
+class PeriodicHRectBound
+{
+ public:
+ /**
+ * Empty constructor.
+ */
+ PeriodicHRectBound();
+
+ /**
+ * Specifies the box size. The dimensionality is set to the same as the box
+ * size, and the bounds are initialized to be empty.
+ */
+ PeriodicHRectBound(arma::vec box);
+
+ /***
+ * Copy constructor and copy operator. These are necessary because we do our
+ * own memory management.
+ */
+ PeriodicHRectBound(const PeriodicHRectBound& other);
+ PeriodicHRectBound& operator=(const PeriodicHRectBound& other);
+
+ /**
+ * Destructor: clean up memory.
+ */
+ ~PeriodicHRectBound();
+
+ /**
+ * Modifies the box to the desired dimensions.
+ */
+ void SetBoxSize(arma::vec box);
+
+ /**
+ * Returns the box vector.
+ */
+ const arma::vec& Box() const { return box; }
+
+ /**
+ * Resets all dimensions to the empty set.
+ */
+ void Clear();
+
+ /** Gets the dimensionality */
+ size_t Dim() const { return dim; }
+
+ /**
+ * Sets and gets the range for a particular dimension.
+ */
+ math::Range& operator[](size_t i);
+ const math::Range operator[](size_t i) const;
+
+ /***
+ * Calculates the centroid of the range. This does not factor in periodic
+ * coordinates, so the centroid may not necessarily be inside the given box.
+ *
+ * @param centroid Vector to write the centroid to.
+ */
+ void Centroid(arma::vec& centroid) const;
+
+ /**
+ * Calculates minimum bound-to-point squared distance in the periodic bound
+ * case.
+ */
+ double MinDistance(const arma::vec& point) const;
+
+ /**
+ * Calculates minimum bound-to-bound squared distance in the periodic bound
+ * case.
+ *
+ * Example: bound1.MinDistance(other) for minimum squared distance.
+ */
+ double MinDistance(const PeriodicHRectBound& other) const;
+
+ /**
+ * Calculates maximum bound-to-point squared distance in the periodic bound
+ * case.
+ */
+ double MaxDistance(const arma::vec& point) const;
+
+ /**
+ * Computes maximum bound-to-bound squared distance in the periodic bound
+ * case.
+ */
+ double MaxDistance(const PeriodicHRectBound& other) const;
+
+ /**
+ * Calculates minimum and maximum bound-to-point squared distance in the
+ * periodic bound case.
+ */
+ math::Range RangeDistance(const arma::vec& point) const;
+
+ /**
+ * Calculates minimum and maximum bound-to-bound squared distance in the
+ * periodic bound case.
+ */
+ math::Range RangeDistance(const PeriodicHRectBound& other) const;
+
+ /**
+ * Expands this region to include a new point.
+ */
+ PeriodicHRectBound& operator|=(const arma::vec& vector);
+
+ /**
+ * Expands this region to encompass another bound.
+ */
+ PeriodicHRectBound& operator|=(const PeriodicHRectBound& other);
+
+ /**
+ * Determines if a point is within this bound.
+ */
+ bool Contains(const arma::vec& point) const;
+
+ /**
+ * Returns a string representation of an object.
+ */
+ std::string ToString() const;
+
+ private:
+ math::Range *bounds;
+ size_t dim;
+ arma::vec box;
+};
+
+}; // namespace bound
+}; // namespace mlpack
+
+#include "periodichrectbound_impl.hpp"
+
+#endif // __MLPACK_CORE_TREE_PERIODICHRECTBOUND_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/periodichrectbound_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,600 +0,0 @@
-/**
- * @file periodichrectbound_impl.hpp
- *
- * Implementation of periodic hyper-rectangle bound policy class.
- * Template parameter t_pow is the metric to use; use 2 for Euclidian (L2).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_PERIODICHRECTBOUND_IMPL_HPP
-#define __MLPACK_CORE_TREE_PERIODICHRECTBOUND_IMPL_HPP
-
-// In case it has not already been included.
-#include "periodichrectbound.hpp"
-
-#include <math.h>
-
-namespace mlpack {
-namespace bound {
-
-/**
- * Empty constructor
- */
-template<int t_pow>
-PeriodicHRectBound<t_pow>::PeriodicHRectBound() :
- bounds(NULL),
- dim(0),
- box(/* empty */)
-{ /* nothing to do */ }
-
-/**
- * Specifies the box size, but not dimensionality.
- */
-template<int t_pow>
-PeriodicHRectBound<t_pow>::PeriodicHRectBound(arma::vec box) :
- bounds(new math::Range[box.n_rows]),
- dim(box.n_rows),
- box(box)
-{ /* nothing to do */ }
-
-/***
- * Copy constructor.
- */
-template<int t_pow>
-PeriodicHRectBound<t_pow>::PeriodicHRectBound(const PeriodicHRectBound& other) :
- dim(other.Dim()),
- box(other.Box())
-{
- bounds = new math::Range[other.Dim()];
- for (size_t i = 0; i < dim; i++)
- bounds[i] |= other[i];
-}
-
-/***
- * Copy operator.
- */
-template<int t_pow>
-PeriodicHRectBound<t_pow>& PeriodicHRectBound<t_pow>::operator=(
- const PeriodicHRectBound& other)
-{
- // not done yet
-
- return *this;
-}
-
-/**
- * Destructor: clean up memory
- */
-template<int t_pow>
-PeriodicHRectBound<t_pow>::~PeriodicHRectBound()
-{
- if (bounds)
- delete[] bounds;
-}
-
-/**
- * Modifies the box to the desired dimenstions.
- */
-template<int t_pow>
-void PeriodicHRectBound<t_pow>::SetBoxSize(arma::vec box)
-{
- box = box;
-}
-
-/**
- * Resets all dimensions to the empty set.
- */
-template<int t_pow>
-void PeriodicHRectBound<t_pow>::Clear()
-{
- for (size_t i = 0; i < dim; i++)
- bounds[i] = math::Range();
-}
-
-/**
- * Gets the range for a particular dimension.
- */
-template<int t_pow>
-const math::Range PeriodicHRectBound<t_pow>::operator[](size_t i) const
-{
- return bounds[i];
-}
-
-/**
- * Sets the range for the given dimension.
- */
-template<int t_pow>
-math::Range& PeriodicHRectBound<t_pow>::operator[](size_t i)
-{
- return bounds[i];
-}
-
-/** Calculates the midpoint of the range */
-template<int t_pow>
-void PeriodicHRectBound<t_pow>::Centroid(arma::vec& centroid) const
-{
- // set size correctly if necessary
- if (!(centroid.n_elem == dim))
- centroid.set_size(dim);
-
- for (size_t i = 0; i < dim; i++)
- centroid(i) = bounds[i].Mid();
-}
-
-/**
- * Calculates minimum bound-to-point squared distance.
- *
- */
-
-template<int t_pow>
-double PeriodicHRectBound<t_pow>::MinDistance(const arma::vec& point) const
-{
- arma::vec point2 = point;
- double totalMin = 0;
- // Create the mirrored images. The minimum distance from the bound to a
- // mirrored point is the minimum periodic distance.
- arma::vec box = box;
- for (int i = 0; i < dim; i++)
- {
- point2 = point;
- double min = 100000000;
- // Mod the point within the box.
-
- if (box[i] < 0)
- {
- box[i] = abs(box[i]);
- }
- if (box[i] != 0)
- {
- if (abs(point[i]) > box[i])
- {
- point2[i] = fmod(point2[i],box[i]);
- }
- }
-
- for (int k = 0; k < 3; k++)
- {
- arma::vec point3 = point2;
-
- if (k == 1)
- point3[i] += box[i];
- else if (k == 2)
- point3[i] -= box[i];
-
- double tempMin;
- double sum = 0;
-
- double lower, higher;
- lower = bounds[i].Lo() - point3[i];
- higher = point3[i] - bounds[i].Hi();
-
- sum += pow((lower + fabs(lower)) +
- (higher + fabs(higher)), (double) t_pow);
- tempMin = pow(sum, 2.0 / (double) t_pow) / 4.0;
-
- if (tempMin < min)
- min = tempMin;
- }
-
- totalMin += min;
- }
- return totalMin;
-
-}
-
-/**
- * Calculates minimum bound-to-bound squared distance.
- *
- * Example: bound1.MinDistance(other) for minimum squared distance.
- */
-template<int t_pow>
-double PeriodicHRectBound<t_pow>::MinDistance(
- const PeriodicHRectBound& other) const
-{
- double totalMin = 0;
- // Create the mirrored images. The minimum distance from the bound to a
- // mirrored point is the minimum periodic distance.
- arma::vec box = box;
- PeriodicHRectBound<2> a(other);
-
- for (int i = 0; i < dim; i++)
- {
- double min = DBL_MAX;
- if (box[i] < 0)
- box[i] = abs(box[i]);
-
- if (box[i] != 0)
- {
- if (abs(other[i].Lo()) > box[i])
- a[i].Lo() = fmod(a[i].Lo(),box[i]);
-
- if (abs(other[i].Hi()) > box[i])
- a[i].Hi() = fmod(a[i].Hi(),box[i]);
- }
-
- for (int k = 0; k < 3; k++)
- {
- PeriodicHRectBound<2> b = a;
- if (k == 1)
- {
- b[i].Lo() += box[i];
- b[i].Hi() += box[i];
- }
- else if (k == 2)
- {
- b[i].Lo() -= box[i];
- b[i].Hi() -= box[i];
- }
-
- double sum = 0;
- double tempMin;
- double sumLower = 0;
- double sumHigher = 0;
-
- double lower, higher, lowerLower, lowerHigher, higherLower,
- higherHigher;
-
- // If the bound crosses over the box, split ito two seperate bounds and
- // find the minimum distance between them.
- if (b[i].Hi() < b[i].Lo())
- {
- PeriodicHRectBound<2> d(b);
- PeriodicHRectBound<2> c(b);
- d[i].Lo() = 0;
- c[i].Hi() = box[i];
-
- if (k == 1)
- {
- d[i].Lo() += box[i];
- c[i].Hi() += box[i];
- }
- else if (k == 2)
- {
- d[i].Lo() -= box[i];
- c[i].Hi() -= box[i];
- }
-
- d[i].Hi() = b[i].Hi();
- c[i].Lo() = b[i].Lo();
-
- lowerLower = d[i].Lo() - bounds[i].Hi();
- higherLower = bounds[i].Lo() - d[i].Hi();
-
- lowerHigher = c[i].Lo() - bounds[i].Hi();
- higherHigher = bounds[i].Lo() - c[i].Hi();
-
- sumLower += pow((lowerLower + fabs(lowerLower)) +
- (higherLower + fabs(higherLower)), (double) t_pow);
-
- sumHigher += pow((lowerHigher + fabs(lowerHigher)) +
- (higherHigher + fabs(higherHigher)), (double) t_pow);
-
- if (sumLower > sumHigher)
- tempMin = pow(sumHigher, 2.0 / (double) t_pow) / 4.0;
- else
- tempMin = pow(sumLower, 2.0 / (double) t_pow) / 4.0;
- }
- else
- {
- lower = b[i].Lo() - bounds[i].Hi();
- higher = bounds[i].Lo() - b[i].Hi();
- // We invoke the following:
- // x + fabs(x) = max(x * 2, 0)
- // (x * 2)^2 / 4 = x^2
- sum += pow((lower + fabs(lower)) +
- (higher + fabs(higher)), (double) t_pow);
- tempMin = pow(sum, 2.0 / (double) t_pow) / 4.0;
- }
-
- if (tempMin < min)
- min = tempMin;
- }
- totalMin += min;
- }
- return totalMin;
-}
-
-
-/**
- * Calculates maximum bound-to-point squared distance.
- */
-template<int t_pow>
-double PeriodicHRectBound<t_pow>::MaxDistance(const arma::vec& point) const
-{
- arma::vec point2 = point;
- double totalMax = 0;
- //Create the mirrored images. The minimum distance from the bound to a
- //mirrored point is the minimum periodic distance.
- arma::vec box = box;
- for (int i = 0; i < dim; i++)
- {
- point2 = point;
- double max = 0;
- // Mod the point within the box.
-
- if (box[i] < 0)
- box[i] = abs(box[i]);
-
- if (box[i] != 0)
- if (abs(point[i]) > box[i])
- point2[i] = fmod(point2[i],box[i]);
-
- for (int k = 0; k < 3; k++)
- {
- arma::vec point3 = point2;
-
- if (k == 1)
- point3[i] += box[i];
- else if (k == 2)
- point3[i] -= box[i];
-
- double tempMax;
- double sum = 0;
-
- double v = std::max(fabs(point3[i] - bounds[i].Lo()),
- fabs(bounds[i].Hi() - point3[i]));
- sum += pow(v, (double) t_pow);
-
- tempMax = pow(sum, 2.0 / (double) t_pow) / 4.0;
-
- if (tempMax > max)
- max = tempMax;
- }
-
- totalMax += max;
- }
- return totalMax;
-
-}
-
-/**
- * Computes maximum distance.
- */
-template<int t_pow>
-double PeriodicHRectBound<t_pow>::MaxDistance(
- const PeriodicHRectBound& other) const
-{
- double totalMax = 0;
- //Create the mirrored images. The minimum distance from the bound to a
- //mirrored point is the minimum periodic distance.
- arma::vec box = box;
- PeriodicHRectBound<2> a(other);
-
-
- for (int i = 0; i < dim; i++)
- {
- double max = 0;
- if (box[i] < 0)
- box[i] = abs(box[i]);
-
- if (box[i] != 0)
- {
- if (abs(other[i].Lo()) > box[i])
- a[i].Lo() = fmod(a[i].Lo(),box[i]);
-
- if (abs(other[i].Hi()) > box[i])
- a[i].Hi() = fmod(a[i].Hi(),box[i]);
- }
-
- for (int k = 0; k < 3; k++)
- {
- PeriodicHRectBound<2> b = a;
- if (k == 1)
- {
- b[i].Lo() += box[i];
- b[i].Hi() += box[i];
- }
- else if (k == 2)
- {
- b[i].Lo() -= box[i];
- b[i].Hi() -= box[i];
- }
-
- double sum = 0;
- double tempMax;
-
- double sumLower = 0, sumHigher = 0;
-
-
-      // If the bound crosses over the box, split into two separate bounds and
-      // find the minimum distance between them.
- if (b[i].Hi() < b[i].Lo())
- {
- PeriodicHRectBound<2> d(b);
- PeriodicHRectBound<2> c(b);
- a[i].Lo() = 0;
- c[i].Hi() = box[i];
-
- if (k == 1)
- {
- d[i].Lo() += box[i];
- c[i].Hi() += box[i];
- }
- else if (k == 2)
- {
- d[i].Lo() -= box[i];
- c[i].Hi() -= box[i];
- }
-
- d[i].Hi() = b[i].Hi();
- c[i].Lo() = b[i].Lo();
-
- double vLower = std::max(fabs(d.bounds[i].Hi() - bounds[i].Lo()),
- fabs(bounds[i].Hi() - d.bounds[i].Lo()));
-
- double vHigher = std::max(fabs(c.bounds[i].Hi() - bounds[i].Lo()),
- fabs(bounds[i].Hi() - c.bounds[i].Lo()));
-
- sumLower += pow(vLower, (double) t_pow);
- sumHigher += pow(vHigher, (double) t_pow);
-
- if (sumLower > sumHigher)
- tempMax = pow(sumHigher, 2.0 / (double) t_pow) / 4.0;
- else
- tempMax = pow(sumLower, 2.0 / (double) t_pow) / 4.0;
- }
- else
- {
- double v = std::max(fabs(b.bounds[i].Hi() - bounds[i].Lo()),
- fabs(bounds[i].Hi() - b.bounds[i].Lo()));
- sum += pow(v, (double) t_pow); // v is non-negative.
- tempMax = pow(sum, 2.0 / (double) t_pow);
- }
-
-
- if (tempMax > max)
- max = tempMax;
- }
- totalMax += max;
- }
- return totalMax;
-}
-
-/**
- * Calculates minimum and maximum bound-to-point squared distance.
- */
-template<int t_pow>
-math::Range PeriodicHRectBound<t_pow>::RangeDistance(
- const arma::vec& point) const
-{
- double sum_lo = 0;
- double sum_hi = 0;
-
- Log::Assert(point.n_elem == dim);
-
- double v1, v2, v_lo, v_hi;
- for (size_t d = 0; d < dim; d++)
- {
- v1 = bounds[d].Lo() - point[d];
- v2 = point[d] - bounds[d].Hi();
- // One of v1 or v2 is negative.
- if (v1 >= 0)
- {
- v_hi = -v2;
- v_lo = v1;
- }
- else
- {
- v_hi = -v1;
- v_lo = v2;
- }
-
- sum_lo += pow(v_lo, (double) t_pow);
- sum_hi += pow(v_hi, (double) t_pow);
- }
-
- return math::Range(pow(sum_lo, 2.0 / (double) t_pow),
- pow(sum_hi, 2.0 / (double) t_pow));
-}
-
-/**
- * Calculates minimum and maximum bound-to-bound squared distance.
- */
-template<int t_pow>
-math::Range PeriodicHRectBound<t_pow>::RangeDistance(
- const PeriodicHRectBound& other) const
-{
- double sum_lo = 0;
- double sum_hi = 0;
-
- Log::Assert(dim == other.dim);
-
- double v1, v2, v_lo, v_hi;
- for (size_t d = 0; d < dim; d++)
- {
- v1 = other.bounds[d].Lo() - bounds[d].Hi();
- v2 = bounds[d].Lo() - other.bounds[d].Hi();
- // One of v1 or v2 is negative.
- if (v1 >= v2)
- {
- v_hi = -v2; // Make it nonnegative.
- v_lo = (v1 > 0) ? v1 : 0; // Force to be 0 if negative.
- }
- else
- {
- v_hi = -v1; // Make it nonnegative.
- v_lo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
- }
-
- sum_lo += pow(v_lo, (double) t_pow);
- sum_hi += pow(v_hi, (double) t_pow);
- }
-
- return math::Range(pow(sum_lo, 2.0 / (double) t_pow),
- pow(sum_hi, 2.0 / (double) t_pow));
-}
-
-/**
- * Expands this region to include a new point.
- */
-template<int t_pow>
-PeriodicHRectBound<t_pow>& PeriodicHRectBound<t_pow>::operator|=(
- const arma::vec& vector)
-{
- Log::Assert(vector.n_elem == dim);
-
- for (size_t i = 0; i < dim; i++)
- bounds[i] |= vector[i];
-
- return *this;
-}
-
-/**
- * Expands this region to encompass another bound.
- */
-template<int t_pow>
-PeriodicHRectBound<t_pow>& PeriodicHRectBound<t_pow>::operator|=(
- const PeriodicHRectBound& other)
-{
- Log::Assert(other.dim == dim);
-
- for (size_t i = 0; i < dim; i++)
- bounds[i] |= other.bounds[i];
-
- return *this;
-}
-
-/**
- * Determines if a point is within this bound.
- */
-template<int t_pow>
-bool PeriodicHRectBound<t_pow>::Contains(const arma::vec& point) const
-{
- for (size_t i = 0; i < point.n_elem; i++)
- if (!bounds[i].Contains(point(i)))
- return false;
-
- return true;
-}
-
-/**
- * Returns a string representation of this object.
- */
-template<int t_pow>
-std::string PeriodicHRectBound<t_pow>::ToString() const
-{
- std::ostringstream convert;
- convert << "PeriodicHRectBound [" << this << "]" << std::endl;
- convert << "bounds: " << bounds->ToString() << std::endl;
- convert << "dim: " << dim << std::endl;
- convert << "box: " << box;
- return convert.str();
-}
-
-}; // namespace bound
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_TREE_PERIODICHRECTBOUND_IMPL_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/periodichrectbound_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/periodichrectbound_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,600 @@
+/**
+ * @file periodichrectbound_impl.hpp
+ *
+ * Implementation of periodic hyper-rectangle bound policy class.
+ * Template parameter t_pow is the metric to use; use 2 for Euclidean (L2).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_PERIODICHRECTBOUND_IMPL_HPP
+#define __MLPACK_CORE_TREE_PERIODICHRECTBOUND_IMPL_HPP
+
+// In case it has not already been included.
+#include "periodichrectbound.hpp"
+
+#include <math.h>
+
+namespace mlpack {
+namespace bound {
+
+/**
+ * Empty constructor
+ */
+template<int t_pow>
+PeriodicHRectBound<t_pow>::PeriodicHRectBound() :
+ bounds(NULL),
+ dim(0),
+ box(/* empty */)
+{ /* nothing to do */ }
+
+/**
+ * Specifies the box size, but not dimensionality.
+ */
+template<int t_pow>
+PeriodicHRectBound<t_pow>::PeriodicHRectBound(arma::vec box) :
+ bounds(new math::Range[box.n_rows]),
+ dim(box.n_rows),
+ box(box)
+{ /* nothing to do */ }
+
+/***
+ * Copy constructor.
+ */
+template<int t_pow>
+PeriodicHRectBound<t_pow>::PeriodicHRectBound(const PeriodicHRectBound& other) :
+ dim(other.Dim()),
+ box(other.Box())
+{
+ bounds = new math::Range[other.Dim()];
+ for (size_t i = 0; i < dim; i++)
+ bounds[i] |= other[i];
+}
+
+/***
+ * Copy operator.
+ */
+template<int t_pow>
+PeriodicHRectBound<t_pow>& PeriodicHRectBound<t_pow>::operator=(
+ const PeriodicHRectBound& other)
+{
+ // not done yet
+
+ return *this;
+}
+
+/**
+ * Destructor: clean up memory
+ */
+template<int t_pow>
+PeriodicHRectBound<t_pow>::~PeriodicHRectBound()
+{
+ if (bounds)
+ delete[] bounds;
+}
+
+/**
+ * Modifies the box to the desired dimensions.
+ */
+template<int t_pow>
+void PeriodicHRectBound<t_pow>::SetBoxSize(arma::vec box)
+{
+ box = box;
+}
+
+/**
+ * Resets all dimensions to the empty set.
+ */
+template<int t_pow>
+void PeriodicHRectBound<t_pow>::Clear()
+{
+ for (size_t i = 0; i < dim; i++)
+ bounds[i] = math::Range();
+}
+
+/**
+ * Gets the range for a particular dimension.
+ */
+template<int t_pow>
+const math::Range PeriodicHRectBound<t_pow>::operator[](size_t i) const
+{
+ return bounds[i];
+}
+
+/**
+ * Sets the range for the given dimension.
+ */
+template<int t_pow>
+math::Range& PeriodicHRectBound<t_pow>::operator[](size_t i)
+{
+ return bounds[i];
+}
+
+/** Calculates the midpoint of the range */
+template<int t_pow>
+void PeriodicHRectBound<t_pow>::Centroid(arma::vec& centroid) const
+{
+ // set size correctly if necessary
+ if (!(centroid.n_elem == dim))
+ centroid.set_size(dim);
+
+ for (size_t i = 0; i < dim; i++)
+ centroid(i) = bounds[i].Mid();
+}
+
+/**
+ * Calculates minimum bound-to-point squared distance.
+ *
+ */
+
+template<int t_pow>
+double PeriodicHRectBound<t_pow>::MinDistance(const arma::vec& point) const
+{
+ arma::vec point2 = point;
+ double totalMin = 0;
+ // Create the mirrored images. The minimum distance from the bound to a
+ // mirrored point is the minimum periodic distance.
+ arma::vec box = box;
+ for (int i = 0; i < dim; i++)
+ {
+ point2 = point;
+ double min = 100000000;
+ // Mod the point within the box.
+
+ if (box[i] < 0)
+ {
+ box[i] = abs(box[i]);
+ }
+ if (box[i] != 0)
+ {
+ if (abs(point[i]) > box[i])
+ {
+ point2[i] = fmod(point2[i],box[i]);
+ }
+ }
+
+ for (int k = 0; k < 3; k++)
+ {
+ arma::vec point3 = point2;
+
+ if (k == 1)
+ point3[i] += box[i];
+ else if (k == 2)
+ point3[i] -= box[i];
+
+ double tempMin;
+ double sum = 0;
+
+ double lower, higher;
+ lower = bounds[i].Lo() - point3[i];
+ higher = point3[i] - bounds[i].Hi();
+
+ sum += pow((lower + fabs(lower)) +
+ (higher + fabs(higher)), (double) t_pow);
+ tempMin = pow(sum, 2.0 / (double) t_pow) / 4.0;
+
+ if (tempMin < min)
+ min = tempMin;
+ }
+
+ totalMin += min;
+ }
+ return totalMin;
+
+}
+
+/**
+ * Calculates minimum bound-to-bound squared distance.
+ *
+ * Example: bound1.MinDistance(other) for minimum squared distance.
+ */
+template<int t_pow>
+double PeriodicHRectBound<t_pow>::MinDistance(
+ const PeriodicHRectBound& other) const
+{
+ double totalMin = 0;
+ // Create the mirrored images. The minimum distance from the bound to a
+ // mirrored point is the minimum periodic distance.
+ arma::vec box = box;
+ PeriodicHRectBound<2> a(other);
+
+ for (int i = 0; i < dim; i++)
+ {
+ double min = DBL_MAX;
+ if (box[i] < 0)
+ box[i] = abs(box[i]);
+
+ if (box[i] != 0)
+ {
+ if (abs(other[i].Lo()) > box[i])
+ a[i].Lo() = fmod(a[i].Lo(),box[i]);
+
+ if (abs(other[i].Hi()) > box[i])
+ a[i].Hi() = fmod(a[i].Hi(),box[i]);
+ }
+
+ for (int k = 0; k < 3; k++)
+ {
+ PeriodicHRectBound<2> b = a;
+ if (k == 1)
+ {
+ b[i].Lo() += box[i];
+ b[i].Hi() += box[i];
+ }
+ else if (k == 2)
+ {
+ b[i].Lo() -= box[i];
+ b[i].Hi() -= box[i];
+ }
+
+ double sum = 0;
+ double tempMin;
+ double sumLower = 0;
+ double sumHigher = 0;
+
+ double lower, higher, lowerLower, lowerHigher, higherLower,
+ higherHigher;
+
+      // If the bound crosses over the box, split into two separate bounds and
+      // find the minimum distance between them.
+ if (b[i].Hi() < b[i].Lo())
+ {
+ PeriodicHRectBound<2> d(b);
+ PeriodicHRectBound<2> c(b);
+ d[i].Lo() = 0;
+ c[i].Hi() = box[i];
+
+ if (k == 1)
+ {
+ d[i].Lo() += box[i];
+ c[i].Hi() += box[i];
+ }
+ else if (k == 2)
+ {
+ d[i].Lo() -= box[i];
+ c[i].Hi() -= box[i];
+ }
+
+ d[i].Hi() = b[i].Hi();
+ c[i].Lo() = b[i].Lo();
+
+ lowerLower = d[i].Lo() - bounds[i].Hi();
+ higherLower = bounds[i].Lo() - d[i].Hi();
+
+ lowerHigher = c[i].Lo() - bounds[i].Hi();
+ higherHigher = bounds[i].Lo() - c[i].Hi();
+
+ sumLower += pow((lowerLower + fabs(lowerLower)) +
+ (higherLower + fabs(higherLower)), (double) t_pow);
+
+ sumHigher += pow((lowerHigher + fabs(lowerHigher)) +
+ (higherHigher + fabs(higherHigher)), (double) t_pow);
+
+ if (sumLower > sumHigher)
+ tempMin = pow(sumHigher, 2.0 / (double) t_pow) / 4.0;
+ else
+ tempMin = pow(sumLower, 2.0 / (double) t_pow) / 4.0;
+ }
+ else
+ {
+ lower = b[i].Lo() - bounds[i].Hi();
+ higher = bounds[i].Lo() - b[i].Hi();
+ // We invoke the following:
+ // x + fabs(x) = max(x * 2, 0)
+ // (x * 2)^2 / 4 = x^2
+ sum += pow((lower + fabs(lower)) +
+ (higher + fabs(higher)), (double) t_pow);
+ tempMin = pow(sum, 2.0 / (double) t_pow) / 4.0;
+ }
+
+ if (tempMin < min)
+ min = tempMin;
+ }
+ totalMin += min;
+ }
+ return totalMin;
+}
+
+
+/**
+ * Calculates maximum bound-to-point squared distance.
+ */
+template<int t_pow>
+double PeriodicHRectBound<t_pow>::MaxDistance(const arma::vec& point) const
+{
+ arma::vec point2 = point;
+ double totalMax = 0;
+ //Create the mirrored images. The minimum distance from the bound to a
+ //mirrored point is the minimum periodic distance.
+ arma::vec box = box;
+ for (int i = 0; i < dim; i++)
+ {
+ point2 = point;
+ double max = 0;
+ // Mod the point within the box.
+
+ if (box[i] < 0)
+ box[i] = abs(box[i]);
+
+ if (box[i] != 0)
+ if (abs(point[i]) > box[i])
+ point2[i] = fmod(point2[i],box[i]);
+
+ for (int k = 0; k < 3; k++)
+ {
+ arma::vec point3 = point2;
+
+ if (k == 1)
+ point3[i] += box[i];
+ else if (k == 2)
+ point3[i] -= box[i];
+
+ double tempMax;
+ double sum = 0;
+
+ double v = std::max(fabs(point3[i] - bounds[i].Lo()),
+ fabs(bounds[i].Hi() - point3[i]));
+ sum += pow(v, (double) t_pow);
+
+ tempMax = pow(sum, 2.0 / (double) t_pow) / 4.0;
+
+ if (tempMax > max)
+ max = tempMax;
+ }
+
+ totalMax += max;
+ }
+ return totalMax;
+
+}
+
+/**
+ * Computes maximum distance.
+ */
+template<int t_pow>
+double PeriodicHRectBound<t_pow>::MaxDistance(
+ const PeriodicHRectBound& other) const
+{
+ double totalMax = 0;
+ //Create the mirrored images. The minimum distance from the bound to a
+ //mirrored point is the minimum periodic distance.
+ arma::vec box = box;
+ PeriodicHRectBound<2> a(other);
+
+
+ for (int i = 0; i < dim; i++)
+ {
+ double max = 0;
+ if (box[i] < 0)
+ box[i] = abs(box[i]);
+
+ if (box[i] != 0)
+ {
+ if (abs(other[i].Lo()) > box[i])
+ a[i].Lo() = fmod(a[i].Lo(),box[i]);
+
+ if (abs(other[i].Hi()) > box[i])
+ a[i].Hi() = fmod(a[i].Hi(),box[i]);
+ }
+
+ for (int k = 0; k < 3; k++)
+ {
+ PeriodicHRectBound<2> b = a;
+ if (k == 1)
+ {
+ b[i].Lo() += box[i];
+ b[i].Hi() += box[i];
+ }
+ else if (k == 2)
+ {
+ b[i].Lo() -= box[i];
+ b[i].Hi() -= box[i];
+ }
+
+ double sum = 0;
+ double tempMax;
+
+ double sumLower = 0, sumHigher = 0;
+
+
+      // If the bound crosses over the box, split into two separate bounds and
+      // find the minimum distance between them.
+ if (b[i].Hi() < b[i].Lo())
+ {
+ PeriodicHRectBound<2> d(b);
+ PeriodicHRectBound<2> c(b);
+ a[i].Lo() = 0;
+ c[i].Hi() = box[i];
+
+ if (k == 1)
+ {
+ d[i].Lo() += box[i];
+ c[i].Hi() += box[i];
+ }
+ else if (k == 2)
+ {
+ d[i].Lo() -= box[i];
+ c[i].Hi() -= box[i];
+ }
+
+ d[i].Hi() = b[i].Hi();
+ c[i].Lo() = b[i].Lo();
+
+ double vLower = std::max(fabs(d.bounds[i].Hi() - bounds[i].Lo()),
+ fabs(bounds[i].Hi() - d.bounds[i].Lo()));
+
+ double vHigher = std::max(fabs(c.bounds[i].Hi() - bounds[i].Lo()),
+ fabs(bounds[i].Hi() - c.bounds[i].Lo()));
+
+ sumLower += pow(vLower, (double) t_pow);
+ sumHigher += pow(vHigher, (double) t_pow);
+
+ if (sumLower > sumHigher)
+ tempMax = pow(sumHigher, 2.0 / (double) t_pow) / 4.0;
+ else
+ tempMax = pow(sumLower, 2.0 / (double) t_pow) / 4.0;
+ }
+ else
+ {
+ double v = std::max(fabs(b.bounds[i].Hi() - bounds[i].Lo()),
+ fabs(bounds[i].Hi() - b.bounds[i].Lo()));
+ sum += pow(v, (double) t_pow); // v is non-negative.
+ tempMax = pow(sum, 2.0 / (double) t_pow);
+ }
+
+
+ if (tempMax > max)
+ max = tempMax;
+ }
+ totalMax += max;
+ }
+ return totalMax;
+}
+
+/**
+ * Calculates minimum and maximum bound-to-point squared distance.
+ */
+template<int t_pow>
+math::Range PeriodicHRectBound<t_pow>::RangeDistance(
+ const arma::vec& point) const
+{
+ double sum_lo = 0;
+ double sum_hi = 0;
+
+ Log::Assert(point.n_elem == dim);
+
+ double v1, v2, v_lo, v_hi;
+ for (size_t d = 0; d < dim; d++)
+ {
+ v1 = bounds[d].Lo() - point[d];
+ v2 = point[d] - bounds[d].Hi();
+ // One of v1 or v2 is negative.
+ if (v1 >= 0)
+ {
+ v_hi = -v2;
+ v_lo = v1;
+ }
+ else
+ {
+ v_hi = -v1;
+ v_lo = v2;
+ }
+
+ sum_lo += pow(v_lo, (double) t_pow);
+ sum_hi += pow(v_hi, (double) t_pow);
+ }
+
+ return math::Range(pow(sum_lo, 2.0 / (double) t_pow),
+ pow(sum_hi, 2.0 / (double) t_pow));
+}
+
+/**
+ * Calculates minimum and maximum bound-to-bound squared distance.
+ */
+template<int t_pow>
+math::Range PeriodicHRectBound<t_pow>::RangeDistance(
+ const PeriodicHRectBound& other) const
+{
+ double sum_lo = 0;
+ double sum_hi = 0;
+
+ Log::Assert(dim == other.dim);
+
+ double v1, v2, v_lo, v_hi;
+ for (size_t d = 0; d < dim; d++)
+ {
+ v1 = other.bounds[d].Lo() - bounds[d].Hi();
+ v2 = bounds[d].Lo() - other.bounds[d].Hi();
+ // One of v1 or v2 is negative.
+ if (v1 >= v2)
+ {
+ v_hi = -v2; // Make it nonnegative.
+ v_lo = (v1 > 0) ? v1 : 0; // Force to be 0 if negative.
+ }
+ else
+ {
+ v_hi = -v1; // Make it nonnegative.
+ v_lo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
+ }
+
+ sum_lo += pow(v_lo, (double) t_pow);
+ sum_hi += pow(v_hi, (double) t_pow);
+ }
+
+ return math::Range(pow(sum_lo, 2.0 / (double) t_pow),
+ pow(sum_hi, 2.0 / (double) t_pow));
+}
+
+/**
+ * Expands this region to include a new point.
+ */
+template<int t_pow>
+PeriodicHRectBound<t_pow>& PeriodicHRectBound<t_pow>::operator|=(
+ const arma::vec& vector)
+{
+ Log::Assert(vector.n_elem == dim);
+
+ for (size_t i = 0; i < dim; i++)
+ bounds[i] |= vector[i];
+
+ return *this;
+}
+
+/**
+ * Expands this region to encompass another bound.
+ */
+template<int t_pow>
+PeriodicHRectBound<t_pow>& PeriodicHRectBound<t_pow>::operator|=(
+ const PeriodicHRectBound& other)
+{
+ Log::Assert(other.dim == dim);
+
+ for (size_t i = 0; i < dim; i++)
+ bounds[i] |= other.bounds[i];
+
+ return *this;
+}
+
+/**
+ * Determines if a point is within this bound.
+ */
+template<int t_pow>
+bool PeriodicHRectBound<t_pow>::Contains(const arma::vec& point) const
+{
+ for (size_t i = 0; i < point.n_elem; i++)
+ if (!bounds[i].Contains(point(i)))
+ return false;
+
+ return true;
+}
+
+/**
+ * Returns a string representation of this object.
+ */
+template<int t_pow>
+std::string PeriodicHRectBound<t_pow>::ToString() const
+{
+ std::ostringstream convert;
+ convert << "PeriodicHRectBound [" << this << "]" << std::endl;
+ convert << "bounds: " << bounds->ToString() << std::endl;
+ convert << "dim: " << dim << std::endl;
+ convert << "box: " << box;
+ return convert.str();
+}
+
+}; // namespace bound
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_PERIODICHRECTBOUND_IMPL_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/statistic.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/statistic.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/statistic.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,65 +0,0 @@
-/**
- * @file statistic.hpp
- *
- * Definition of the policy type for the statistic class.
- *
- * You should define your own statistic that looks like EmptyStatistic.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef __MLPACK_CORE_TREE_STATISTIC_HPP
-#define __MLPACK_CORE_TREE_STATISTIC_HPP
-
-namespace mlpack {
-namespace tree {
-
-/**
- * Empty statistic if you are not interested in storing statistics in your
- * tree. Use this as a template for your own.
- */
-class EmptyStatistic
-{
- public:
- EmptyStatistic() { }
- ~EmptyStatistic() { }
-
- /**
- * This constructor is called when a node is finished being created. The
- * node is finished, and its children are finished, but it is not
- * necessarily true that the statistics of other nodes are initialized yet.
- *
- * @param node Node which this corresponds to.
- */
- template<typename TreeType>
- EmptyStatistic(TreeType& /* node */) { }
-
- public:
- /**
- * Returns a string representation of this object.
- */
- std::string ToString() const
- {
- std::stringstream convert;
- convert << "EmptyStatistic [" << this << "]" << std::endl;
- return convert.str();
- }
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_TREE_STATISTIC_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/statistic.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/statistic.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/statistic.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/statistic.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,65 @@
+/**
+ * @file statistic.hpp
+ *
+ * Definition of the policy type for the statistic class.
+ *
+ * You should define your own statistic that looks like EmptyStatistic.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef __MLPACK_CORE_TREE_STATISTIC_HPP
+#define __MLPACK_CORE_TREE_STATISTIC_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Empty statistic if you are not interested in storing statistics in your
+ * tree. Use this as a template for your own.
+ */
+class EmptyStatistic
+{
+ public:
+ EmptyStatistic() { }
+ ~EmptyStatistic() { }
+
+ /**
+ * This constructor is called when a node is finished being created. The
+ * node is finished, and its children are finished, but it is not
+ * necessarily true that the statistics of other nodes are initialized yet.
+ *
+ * @param node Node which this corresponds to.
+ */
+ template<typename TreeType>
+ EmptyStatistic(TreeType& /* node */) { }
+
+ public:
+ /**
+ * Returns a string representation of this object.
+ */
+ std::string ToString() const
+ {
+ std::stringstream convert;
+ convert << "EmptyStatistic [" << this << "]" << std::endl;
+ return convert.str();
+ }
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_STATISTIC_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/tree_traits.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/tree/tree_traits.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/tree_traits.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,112 +0,0 @@
-/**
- * @file tree_traits.hpp
- * @author Ryan Curtin
- *
- * This file implements the basic, unspecialized TreeTraits class, which
- * provides information about tree types. If you create a tree class, you
- * should specialize this class with the characteristics of your tree.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_TREE_TREE_TRAITS_HPP
-#define __MLPACK_CORE_TREE_TREE_TRAITS_HPP
-
-namespace mlpack {
-namespace tree {
-
-/**
- * The TreeTraits class provides compile-time information on the characteristics
- * of a given tree type. These include traits such as whether or not a node
- * knows the distance to its parent node, or whether or not the subspaces
- * represented by children can overlap.
- *
- * These traits can be used for static compile-time optimization:
- *
- * @code
- * // This if statement will be optimized out at compile time!
- * if (TreeTraits<TreeType>::HasOverlappingChildren == false)
- * {
- * // Do a simpler computation because no children overlap.
- * }
- * else
- * {
- * // Do the full, complex calculation.
- * }
- * @endcode
- *
- * The traits can also be used in conjunction with SFINAE to write specialized
- * versions of functions:
- *
- * @code
- * template<typename TreeType>
- * void Compute(TreeType& node,
- * boost::enable_if<
- * TreeTraits<TreeType>::HasParentDistance>::type*)
- * {
- * // Computation with TreeType::ParentDistance().
- * }
- *
- * template<typename TreeType>
- * void Compute(TreeType& node,
- * boost::enable_if<
- * !TreeTraits<TreeType>::HasParentDistance>::type*)
- * {
- * // Computation without TreeType::ParentDistance().
- * }
- * @endcode
- *
- * In those two examples, the boost::enable_if<> class takes a boolean template
- * parameter which allows that function to be called when the boolean is true.
- *
- * Each trait must be a static const value and not a function; only const values
- * can be used as template parameters (with the exception of constexprs, which
- * are a C++11 feature; but MLPACK is not using C++11). By default (the
- * unspecialized implementation of TreeTraits), each parameter is set to make as
- * few assumptions about the tree as possible; so, even if TreeTraits is not
- * specialized for a particular tree type, tree-based algorithms should still
- * work.
- *
- * When you write your own tree, you must specialize the TreeTraits class to
- * your tree type and set the corresponding values appropriately. See
- * mlpack/core/tree/binary_space_tree/traits.hpp for an example.
- */
-template<typename TreeType>
-class TreeTraits
-{
- public:
- /**
- * This is true if TreeType::ParentDistance() exists and works. The
- * ParentDistance() function returns the distance between the center of a node
- * and the center of its parent.
- */
- static const bool HasParentDistance = false;
-
- /**
- * This is true if the subspaces represented by the children of a node can
- * overlap.
- */
- static const bool HasOverlappingChildren = true;
-
- /**
- * This is true if Point(0) is the centroid of the node.
- */
- static const bool FirstPointIsCentroid = false;
-};
-
-}; // namespace tree
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/tree_traits.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/tree/tree_traits.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/tree_traits.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/tree/tree_traits.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,112 @@
+/**
+ * @file tree_traits.hpp
+ * @author Ryan Curtin
+ *
+ * This file implements the basic, unspecialized TreeTraits class, which
+ * provides information about tree types. If you create a tree class, you
+ * should specialize this class with the characteristics of your tree.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_TREE_TREE_TRAITS_HPP
+#define __MLPACK_CORE_TREE_TREE_TRAITS_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * The TreeTraits class provides compile-time information on the characteristics
+ * of a given tree type. These include traits such as whether or not a node
+ * knows the distance to its parent node, or whether or not the subspaces
+ * represented by children can overlap.
+ *
+ * These traits can be used for static compile-time optimization:
+ *
+ * @code
+ * // This if statement will be optimized out at compile time!
+ * if (TreeTraits<TreeType>::HasOverlappingChildren == false)
+ * {
+ * // Do a simpler computation because no children overlap.
+ * }
+ * else
+ * {
+ * // Do the full, complex calculation.
+ * }
+ * @endcode
+ *
+ * The traits can also be used in conjunction with SFINAE to write specialized
+ * versions of functions:
+ *
+ * @code
+ * template<typename TreeType>
+ * void Compute(TreeType& node,
+ * boost::enable_if<
+ * TreeTraits<TreeType>::HasParentDistance>::type*)
+ * {
+ * // Computation with TreeType::ParentDistance().
+ * }
+ *
+ * template<typename TreeType>
+ * void Compute(TreeType& node,
+ * boost::enable_if<
+ * !TreeTraits<TreeType>::HasParentDistance>::type*)
+ * {
+ * // Computation without TreeType::ParentDistance().
+ * }
+ * @endcode
+ *
+ * In those two examples, the boost::enable_if<> class takes a boolean template
+ * parameter which allows that function to be called when the boolean is true.
+ *
+ * Each trait must be a static const value and not a function; only const values
+ * can be used as template parameters (with the exception of constexprs, which
+ * are a C++11 feature; but MLPACK is not using C++11). By default (the
+ * unspecialized implementation of TreeTraits), each parameter is set to make as
+ * few assumptions about the tree as possible; so, even if TreeTraits is not
+ * specialized for a particular tree type, tree-based algorithms should still
+ * work.
+ *
+ * When you write your own tree, you must specialize the TreeTraits class to
+ * your tree type and set the corresponding values appropriately. See
+ * mlpack/core/tree/binary_space_tree/traits.hpp for an example.
+ */
+template<typename TreeType>
+class TreeTraits
+{
+ public:
+ /**
+ * This is true if TreeType::ParentDistance() exists and works. The
+ * ParentDistance() function returns the distance between the center of a node
+ * and the center of its parent.
+ */
+ static const bool HasParentDistance = false;
+
+ /**
+ * This is true if the subspaces represented by the children of a node can
+ * overlap.
+ */
+ static const bool HasOverlappingChildren = true;
+
+ /**
+ * This is true if Point(0) is the centroid of the node.
+ */
+ static const bool FirstPointIsCentroid = false;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,750 +0,0 @@
-/**
- * @file cli.cpp
- * @author Matthew Amidon
- *
- * Implementation of the CLI module for parsing parameters.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <list>
-#include <boost/program_options.hpp>
-#include <boost/any.hpp>
-#include <boost/scoped_ptr.hpp>
-#include <iostream>
-#include <string>
-
-#ifndef _WIN32
- #include <sys/time.h> // For Linux.
- #include <execinfo.h>
-#else
- #include <winsock.h> // timeval on Windows.
- #include <windows.h> // GetSystemTimeAsFileTime() on Windows.
-// gettimeofday() has no equivalent; we will need to write extra code for that.
- #if defined(_MSC_VER) || defined(_MSC_EXTENSIONS)
- #define DELTA_EPOCH_IN_MICROSECS 11644473600000000Ui64
- #else
- #define DELTA_EPOCH_IN_MICROSECS 11644473600000000ULL
- #endif
-#endif // _WIN32
-
-#include "cli.hpp"
-#include "log.hpp"
-
-#include "option.hpp"
-
-using namespace mlpack;
-using namespace mlpack::util;
-
-CLI* CLI::singleton = NULL;
-
-/* For clarity, we will alias boost's namespace. */
-namespace po = boost::program_options;
-
-// Fake ProgramDoc in case none is supplied.
-static ProgramDoc emptyProgramDoc = ProgramDoc("", "");
-
-/* Constructors, Destructors, Copy */
-/* Make the constructor private, to preclude unauthorized instances */
-CLI::CLI() : desc("Allowed Options") , didParse(false), doc(&emptyProgramDoc)
-{
- return;
-}
-
-/**
- * Initialize desc with a particular name.
- *
- * @param optionsName Name of the module, as far as boost is concerned.
- */
-CLI::CLI(const std::string& optionsName) :
- desc(optionsName.c_str()), didParse(false), doc(&emptyProgramDoc)
-{
- return;
-}
-
-// Private copy constructor; don't want copies floating around.
-CLI::CLI(const CLI& other) : desc(other.desc),
- didParse(false), doc(&emptyProgramDoc)
-{
- return;
-}
-
-CLI::~CLI()
-{
- // Terminate the program timer.
- Timer::Stop("total_time");
-
- // Did the user ask for verbose output? If so we need to print everything.
- // But only if the user did not ask for help or info.
- if (HasParam("verbose") && !HasParam("help") && !HasParam("info"))
- {
- Log::Info << std::endl << "Execution parameters:" << std::endl;
- Print();
-
- Log::Info << "Program timers:" << std::endl;
- std::map<std::string, timeval>::iterator it;
- for (it = timer.GetAllTimers().begin(); it != timer.GetAllTimers().end();
- ++it)
- {
- std::string i = (*it).first;
- Log::Info << " " << i << ": ";
- timer.PrintTimer((*it).first.c_str());
- }
- }
-
- // Notify the user if we are debugging, but only if we actually parsed the
- // options. This way this output doesn't show up inexplicably for someone who
- // may not have wanted it there (i.e. in Boost unit tests).
- if (didParse)
- Log::Debug << "Compiled with debugging symbols." << std::endl;
-
- return;
-}
-
-/* Methods */
-
-/**
- * Adds a parameter to the hierarchy. Use char* and not std::string since the
- * vast majority of use cases will be literal strings.
- *
- * @param identifier The name of the parameter.
- * @param description Short string description of the parameter.
- * @param alias An alias for the parameter.
- * @param required Indicates if parameter must be set on command line.
- */
-void CLI::Add(const std::string& path,
- const std::string& description,
- const std::string& alias,
- bool required)
-{
- po::options_description& desc = CLI::GetSingleton().desc;
-
- // Must make use of boost option name syntax.
- std::string progOptId = alias.length() ? path + "," + alias : path;
-
- // Deal with a required alias.
- AddAlias(alias, path);
-
- // Add the option to boost::program_options.
- desc.add_options()(progOptId.c_str(), description.c_str());
-
- // Make sure the description, etc. ends up in gmap.
- gmap_t& gmap = GetSingleton().globalValues;
-
- ParamData data;
- data.desc = description;
- data.tname = "";
- data.name = path;
- data.isFlag = false;
- data.wasPassed = false;
-
- gmap[path] = data;
-
- // If the option is required, add it to the required options list.
- if (required)
- GetSingleton().requiredOptions.push_front(path);
-
- return;
-}
-
-/*
- * Adds an alias mapping for a given parameter.
- *
- * @param alias The alias we will use for the parameter.
- * @param original The name of the actual parameter we will be mapping to.
- */
-void CLI::AddAlias(const std::string& alias, const std::string& original)
-{
- //Conduct the mapping
- if (alias.length())
- {
- amap_t& amap = GetSingleton().aliasValues;
- amap[alias] = original;
- }
-}
-
-/*
- * @brief Adds a flag parameter to CLI.
- */
-void CLI::AddFlag(const std::string& identifier,
- const std::string& description,
- const std::string& alias)
-{
- // Reuse functionality from add
- Add(identifier, description, alias, false);
-
- // Insert the proper metadata in gmap.
- gmap_t& gmap = GetSingleton().globalValues;
-
- ParamData data;
- data.desc = description;
- data.tname = TYPENAME(bool);
- data.name = std::string(identifier);
- data.isFlag = true;
- data.wasPassed = false;
-
- gmap[data.name] = data;
-}
-
-std::string CLI::AliasReverseLookup(const std::string& value)
-{
- amap_t& amap = GetSingleton().aliasValues;
- amap_t::iterator iter;
- for (iter = amap.begin(); iter != amap.end(); ++iter)
- if (iter->second == value) // Found our match.
- return iter->first;
-
- return ""; // Nothing found.
-}
-
-/**
- * Parses the parameters for 'help' and 'info'
- * If found, will print out the appropriate information
- * and kill the program.
- */
-void CLI::DefaultMessages()
-{
- // Default help message
- if (HasParam("help"))
- {
- Log::Info.ignoreInput = false;
- PrintHelp();
- exit(0); // The user doesn't want to run the program, he wants help.
- }
-
- if (HasParam("info"))
- {
- Log::Info.ignoreInput = false;
- std::string str = GetParam<std::string>("info");
-
- // The info node should always be there, but the user may not have specified
- // anything.
- if (str != "")
- {
- PrintHelp(str);
- exit(0);
- }
-
- // Otherwise just print the generalized help.
- PrintHelp();
- exit(0);
- }
-
- if (HasParam("verbose"))
- {
- // Give [INFO ] output.
- Log::Info.ignoreInput = false;
- }
-
- // Notify the user if we are debugging. This is not done in the constructor
- // because the output streams may not be set up yet. We also don't want this
- // message twice if the user just asked for help or information.
- Log::Debug << "Compiled with debugging symbols." << std::endl;
-}
-
-/**
- * Destroy the CLI object. This resets the pointer to the singleton, so in case
- * someone tries to access it after destruction, a new one will be made (the
- * program will not fail).
- */
-void CLI::Destroy()
-{
- if (singleton != NULL)
- {
- delete singleton;
- singleton = NULL; // Reset pointer.
- }
-}
-
-/**
- * See if the specified flag was found while parsing.
- *
- * @param identifier The name of the parameter in question.
- */
-bool CLI::HasParam(const std::string& key)
-{
- std::string used_key = key;
- po::variables_map vmap = GetSingleton().vmap;
- gmap_t& gmap = GetSingleton().globalValues;
-
- // Take any possible alias into account.
- amap_t& amap = GetSingleton().aliasValues;
- if (amap.count(key))
- used_key = amap[key];
-
- // Does the parameter exist at all?
- int isInGmap = gmap.count(used_key);
-
- // Check if the parameter is boolean; if it is, we just want to see if it was
- // passed.
- if (isInGmap)
- return gmap[used_key].wasPassed;
-
- // The parameter was not passed in; return false.
- return false;
-}
-
-/**
- * Hyphenate a string or split it onto multiple 80-character lines, with some
- * amount of padding on each line. This is used for option output.
- *
- * @param str String to hyphenate (splits are on ' ').
- * @param padding Amount of padding on the left for each new line.
- */
-std::string CLI::HyphenateString(const std::string& str, int padding)
-{
- size_t margin = 80 - padding;
- if (str.length() < margin)
- return str;
- std::string out("");
- unsigned int pos = 0;
- // First try to look as far as possible.
- while (pos < str.length())
- {
- size_t splitpos;
- // Check that we don't have a newline first.
- splitpos = str.find('\n', pos);
- if (splitpos == std::string::npos || splitpos > (pos + margin))
- {
- // We did not find a newline.
- if (str.length() - pos < margin)
- {
- splitpos = str.length(); // The rest fits on one line.
- }
- else
- {
- splitpos = str.rfind(' ', margin + pos); // Find nearest space.
- if (splitpos <= pos || splitpos == std::string::npos) // Not found.
- splitpos = pos + margin;
- }
- }
- out += str.substr(pos, (splitpos - pos));
- if (splitpos < str.length())
- {
- out += '\n';
- out += std::string(padding, ' ');
- }
-
- pos = splitpos;
- if (str[pos] == ' ' || str[pos] == '\n')
- pos++;
- }
- return out;
-}
-
-/**
- * Grab the description of the specified node.
- *
- * @param identifier Name of the node in question.
- * @return Description of the node in question.
- */
-std::string CLI::GetDescription(const std::string& identifier)
-{
- gmap_t& gmap = GetSingleton().globalValues;
- std::string name = std::string(identifier);
-
- //Take any possible alias into account
- amap_t& amap = GetSingleton().aliasValues;
- if (amap.count(name))
- name = amap[name];
-
-
- if (gmap.count(name))
- return gmap[name].desc;
- else
- return "";
-
-}
-
-// Returns the sole instance of this class.
-CLI& CLI::GetSingleton()
-{
- if (singleton == NULL)
- singleton = new CLI();
-
- return *singleton;
-}
-
-/**
- * Parses the commandline for arguments.
- *
- * @param argc The number of arguments on the commandline.
- * @param argv The array of arguments as strings
- */
-void CLI::ParseCommandLine(int argc, char** line)
-{
- Timer::Start("total_time");
-
- po::variables_map& vmap = GetSingleton().vmap;
- po::options_description& desc = GetSingleton().desc;
-
- // Parse the command line, place the options & values into vmap
- try
- {
- // Get the basic_parsed_options
- po::basic_parsed_options<char> bpo(
- po::parse_command_line(argc, line, desc));
-
- // Look for any duplicate parameters, removing duplicate flags
- RemoveDuplicateFlags(bpo);
-
- // Record the basic_parsed_options
- po::store(bpo, vmap);
- }
- catch (std::exception& ex)
- {
- Log::Fatal << "Caught exception from parsing command line:\t";
- Log::Fatal << ex.what() << std::endl;
- }
-
- // Flush the buffer, make sure changes are propagated to vmap
- po::notify(vmap);
- UpdateGmap();
- DefaultMessages();
- RequiredOptions();
-}
-
-/*
- * Removes duplicate flags.
- *
- * @param bpo The basic_program_options to remove duplicate flags from.
- */
-void CLI::RemoveDuplicateFlags(po::basic_parsed_options<char>& bpo)
-{
- // Iterate over all the program_options, looking for duplicate parameters
- for (unsigned int i = 0; i < bpo.options.size(); i++)
- {
- for (unsigned int j = i + 1; j < bpo.options.size(); j++)
- {
- if (bpo.options[i].string_key == bpo.options[j].string_key)
- {
- // If a duplicate is found, check to see if either one has a value
- if (bpo.options[i].value.size() == 0 &&
- bpo.options[j].value.size() == 0)
- {
- // If neither has a value, consider it a duplicate flag and remove the
- // duplicate. It's important to not break out of this loop because
- // there might be another duplicate later on in the vector.
- bpo.options.erase(bpo.options.begin()+j);
- }
- else
- {
- // If one or both has a value, produce an error and politely
- // terminate. We pull the name from the original_tokens, rather than
- // from the string_key, because the string_key is the parameter after
- // aliases have been expanded.
- Log::Fatal << "\"" << bpo.options[j].original_tokens[0] << "\""
- << " is defined multiple times." << std::endl;
- }
- }
- }
- }
-}
-
-/**
- * Parses a stream for arguments
- *
- * @param stream The stream to be parsed.
- */
-void CLI::ParseStream(std::istream& stream)
-{
- po::variables_map& vmap = GetSingleton().vmap;
- po::options_description& desc = GetSingleton().desc;
-
- // Parse the stream; place options & values into vmap.
- try
- {
- po::store(po::parse_config_file(stream, desc), vmap);
- }
- catch (std::exception& ex)
- {
- Log::Fatal << ex.what() << std::endl;
- }
-
- // Flush the buffer; make sure changes are propagated to vmap.
- po::notify(vmap);
-
- UpdateGmap();
- DefaultMessages();
- RequiredOptions();
-
- Timer::Start("total_time");
-}
-
-/* Prints out the current hierarchy. */
-void CLI::Print()
-{
- gmap_t& gmap = GetSingleton().globalValues;
- gmap_t::iterator iter;
-
- // Print out all the values.
- for (iter = gmap.begin(); iter != gmap.end(); ++iter)
- {
- std::string key = iter->first;
-
- Log::Info << " " << key << ": ";
-
- // Now, figure out what type it is, and print it.
- // We can handle strings, ints, bools, floats, doubles.
- ParamData data = iter->second;
- if (data.tname == TYPENAME(std::string))
- {
- std::string value = GetParam<std::string>(key.c_str());
- if (value == "")
- Log::Info << "\"\"";
- Log::Info << value;
- }
- else if (data.tname == TYPENAME(int))
- {
- int value = GetParam<int>(key.c_str());
- Log::Info << value;
- }
- else if (data.tname == TYPENAME(bool))
- {
- bool value = HasParam(key.c_str());
- Log::Info << (value ? "true" : "false");
- }
- else if (data.tname == TYPENAME(float))
- {
- float value = GetParam<float>(key.c_str());
- Log::Info << value;
- }
- else if (data.tname == TYPENAME(double))
- {
- double value = GetParam<double>(key.c_str());
- Log::Info << value;
- }
- else
- {
- // We don't know how to print this, or it's a timeval which is printed
- // later.
- Log::Info << "(Unknown data type - " << data.tname << ")";
- }
-
- Log::Info << std::endl;
- }
- Log::Info << std::endl;
-}
-
-/* Prints the descriptions of the current hierarchy. */
-void CLI::PrintHelp(const std::string& param)
-{
- std::string used_param = param;
- gmap_t& gmap = GetSingleton().globalValues;
- amap_t& amap = GetSingleton().aliasValues;
- gmap_t::iterator iter;
- ProgramDoc docs = *GetSingleton().doc;
-
- // If we pass a single param, alias it if necessary.
- if (used_param != "" && amap.count(used_param))
- used_param = amap[used_param];
-
- // Do we only want to print out one value?
- if (used_param != "" && gmap.count(used_param))
- {
- ParamData data = gmap[used_param];
- std::string alias = AliasReverseLookup(used_param);
- alias = alias.length() ? " (-" + alias + ")" : alias;
-
- // Figure out the name of the type.
- std::string type = "";
- if (data.tname == TYPENAME(std::string))
- type = " [string]";
- else if (data.tname == TYPENAME(int))
- type = " [int]";
- else if (data.tname == TYPENAME(bool))
- type = ""; // Nothing to pass for a flag.
- else if (data.tname == TYPENAME(float))
- type = " [float]";
- else if (data.tname == TYPENAME(double))
- type = " [double]";
-
- // Now, print the descriptions.
- std::string fullDesc = " --" + used_param + alias + type + " ";
-
- if (fullDesc.length() <= 32) // It all fits on one line.
- std::cout << fullDesc << std::string(32 - fullDesc.length(), ' ');
- else // We need multiple lines.
- std::cout << fullDesc << std::endl << std::string(32, ' ');
-
- std::cout << HyphenateString(data.desc, 32) << std::endl;
- return;
- }
- else if (used_param != "")
- {
- // User passed a single variable, but it doesn't exist.
- std::cerr << "Parameter --" << used_param << " does not exist."
- << std::endl;
- exit(1); // Nothing left to do.
- }
-
- // Print out the descriptions.
- if (docs.programName != "")
- {
- std::cout << docs.programName << std::endl << std::endl;
- std::cout << " " << HyphenateString(docs.documentation, 2) << std::endl
- << std::endl;
- }
- else
- std::cout << "[undocumented program]" << std::endl << std::endl;
-
- for (size_t pass = 0; pass < 2; ++pass)
- {
- if (pass == 0)
- std::cout << "Required options:" << std::endl << std::endl;
- else
- std::cout << "Options: " << std::endl << std::endl;
-
- // Print out the descriptions of everything else.
- for (iter = gmap.begin(); iter != gmap.end(); ++iter)
- {
- std::string key = iter->first;
- ParamData data = iter->second;
- std::string desc = data.desc;
- std::string alias = AliasReverseLookup(key);
- alias = alias.length() ? " (-" + alias + ")" : alias;
-
- // Is the option required or not?
- bool required = false;
- std::list<std::string>::iterator iter;
- std::list<std::string>& rOpt = GetSingleton().requiredOptions;
- for (iter = rOpt.begin(); iter != rOpt.end(); ++iter)
- if ((*iter) == key)
- required = true;
-
- if ((pass == 0) && !required)
- continue; // Don't print this one.
- if ((pass == 1) && required)
- continue; // Don't print this one.
-
- if (pass == 1) // Append default value to description.
- {
- desc += " Default value ";
- std::stringstream tmp;
-
- if (data.tname == TYPENAME(std::string))
- tmp << "'" << boost::any_cast<std::string>(data.value) << "'.";
- else if (data.tname == TYPENAME(int))
- tmp << boost::any_cast<int>(data.value) << '.';
- else if (data.tname == TYPENAME(bool))
- desc = data.desc; // No extra output for that.
- else if (data.tname == TYPENAME(float))
- tmp << boost::any_cast<float>(data.value) << '.';
- else if (data.tname == TYPENAME(double))
- tmp << boost::any_cast<double>(data.value) << '.';
-
- desc += tmp.str();
- }
-
- // Figure out the name of the type.
- std::string type = "";
- if (data.tname == TYPENAME(std::string))
- type = " [string]";
- else if (data.tname == TYPENAME(int))
- type = " [int]";
- else if (data.tname == TYPENAME(bool))
- type = ""; // Nothing to pass for a flag.
- else if (data.tname == TYPENAME(float))
- type = " [float]";
- else if (data.tname == TYPENAME(double))
- type = " [double]";
-
- // Now, print the descriptions.
- std::string fullDesc = " --" + key + alias + type + " ";
-
- if (fullDesc.length() <= 32) // It all fits on one line.
- std::cout << fullDesc << std::string(32 - fullDesc.length(), ' ');
- else // We need multiple lines.
- std::cout << fullDesc << std::endl << std::string(32, ' ');
-
- std::cout << HyphenateString(desc, 32) << std::endl;
- }
-
- std::cout << std::endl;
-
- }
-
- // Helpful information at the bottom of the help output, to point the user to
- // citations and better documentation (if necessary). See ticket #201.
- std::cout << HyphenateString("For further information, including relevant "
- "papers, citations, and theory, consult the documentation found at "
- "http://www.mlpack.org or included with your distribution of MLPACK.", 0)
- << std::endl;
-}
-
-/**
- * Registers a ProgramDoc object, which contains documentation about the
- * program.
- *
- * @param doc Pointer to the ProgramDoc object.
- */
-void CLI::RegisterProgramDoc(ProgramDoc* doc)
-{
- // Only register the doc if it is not the dummy object we created at the
- // beginning of the file (as a default value in case this is never called).
- if (doc != &emptyProgramDoc)
- GetSingleton().doc = doc;
-}
-
-/**
- * Checks that all parameters specified as required have been specified on the
- * command line. If they haven't, prints an error message and kills the program.
- */
-void CLI::RequiredOptions()
-{
- po::variables_map& vmap = GetSingleton().vmap;
- std::list<std::string> rOpt = GetSingleton().requiredOptions;
-
- // Now, warn the user if they missed any required options.
- std::list<std::string>::iterator iter;
- for (iter = rOpt.begin(); iter != rOpt.end(); ++iter)
- {
- std::string str = *iter;
- if (!vmap.count(str))
- { // If a required option isn't there...
- Log::Fatal << "Required option --" << str.c_str() << " is undefined."
- << std::endl;
- }
- }
-}
-
-/**
- * Parses the values given on the command line, overriding any default values.
- */
-void CLI::UpdateGmap()
-{
- gmap_t& gmap = GetSingleton().globalValues;
- po::variables_map& vmap = GetSingleton().vmap;
-
- // Iterate through vmap, and overwrite default values with anything found on
- // command line.
- po::variables_map::iterator i;
- for (i = vmap.begin(); i != vmap.end(); ++i)
- {
- ParamData param;
- if (gmap.count(i->first)) // We need to preserve certain data
- param = gmap[i->first];
-
- param.value = vmap[i->first].value();
- param.wasPassed = true;
- gmap[i->first] = param;
- }
-}
-
-// Add help parameter.
-PARAM_FLAG("help", "Default help info.", "h");
-PARAM_STRING("info", "Get help on a specific module or option.", "", "");
-PARAM_FLAG("verbose", "Display informational messages and the full list of "
- "parameters and timers at the end of execution.", "v");
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,750 @@
+/**
+ * @file cli.cpp
+ * @author Matthew Amidon
+ *
+ * Implementation of the CLI module for parsing parameters.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <list>
+#include <boost/program_options.hpp>
+#include <boost/any.hpp>
+#include <boost/scoped_ptr.hpp>
+#include <iostream>
+#include <string>
+
+#ifndef _WIN32
+ #include <sys/time.h> // For Linux.
+ #include <execinfo.h>
+#else
+ #include <winsock.h> // timeval on Windows.
+ #include <windows.h> // GetSystemTimeAsFileTime() on Windows.
+// gettimeofday() has no equivalent; we will need to write extra code for that.
+ #if defined(_MSC_VER) || defined(_MSC_EXTENSIONS)
+ #define DELTA_EPOCH_IN_MICROSECS 11644473600000000Ui64
+ #else
+ #define DELTA_EPOCH_IN_MICROSECS 11644473600000000ULL
+ #endif
+#endif // _WIN32
+
+#include "cli.hpp"
+#include "log.hpp"
+
+#include "option.hpp"
+
+using namespace mlpack;
+using namespace mlpack::util;
+
+CLI* CLI::singleton = NULL;
+
+/* For clarity, we will alias boost's namespace. */
+namespace po = boost::program_options;
+
+// Fake ProgramDoc in case none is supplied.
+static ProgramDoc emptyProgramDoc = ProgramDoc("", "");
+
+/* Constructors, Destructors, Copy */
+/* Make the constructor private, to preclude unauthorized instances */
+CLI::CLI() : desc("Allowed Options") , didParse(false), doc(&emptyProgramDoc)
+{
+ return;
+}
+
+/**
+ * Initialize desc with a particular name.
+ *
+ * @param optionsName Name of the module, as far as boost is concerned.
+ */
+CLI::CLI(const std::string& optionsName) :
+ desc(optionsName.c_str()), didParse(false), doc(&emptyProgramDoc)
+{
+ return;
+}
+
+// Private copy constructor; don't want copies floating around.
+CLI::CLI(const CLI& other) : desc(other.desc),
+ didParse(false), doc(&emptyProgramDoc)
+{
+ return;
+}
+
+CLI::~CLI()
+{
+ // Terminate the program timer.
+ Timer::Stop("total_time");
+
+ // Did the user ask for verbose output? If so we need to print everything.
+ // But only if the user did not ask for help or info.
+ if (HasParam("verbose") && !HasParam("help") && !HasParam("info"))
+ {
+ Log::Info << std::endl << "Execution parameters:" << std::endl;
+ Print();
+
+ Log::Info << "Program timers:" << std::endl;
+ std::map<std::string, timeval>::iterator it;
+ for (it = timer.GetAllTimers().begin(); it != timer.GetAllTimers().end();
+ ++it)
+ {
+ std::string i = (*it).first;
+ Log::Info << " " << i << ": ";
+ timer.PrintTimer((*it).first.c_str());
+ }
+ }
+
+ // Notify the user if we are debugging, but only if we actually parsed the
+ // options. This way this output doesn't show up inexplicably for someone who
+ // may not have wanted it there (i.e. in Boost unit tests).
+ if (didParse)
+ Log::Debug << "Compiled with debugging symbols." << std::endl;
+
+ return;
+}
+
+/* Methods */
+
+/**
+ * Adds a parameter to the hierarchy. Use char* and not std::string since the
+ * vast majority of use cases will be literal strings.
+ *
+ * @param identifier The name of the parameter.
+ * @param description Short string description of the parameter.
+ * @param alias An alias for the parameter.
+ * @param required Indicates if parameter must be set on command line.
+ */
+void CLI::Add(const std::string& path,
+ const std::string& description,
+ const std::string& alias,
+ bool required)
+{
+ po::options_description& desc = CLI::GetSingleton().desc;
+
+ // Must make use of boost option name syntax.
+ std::string progOptId = alias.length() ? path + "," + alias : path;
+
+ // Deal with a required alias.
+ AddAlias(alias, path);
+
+ // Add the option to boost::program_options.
+ desc.add_options()(progOptId.c_str(), description.c_str());
+
+ // Make sure the description, etc. ends up in gmap.
+ gmap_t& gmap = GetSingleton().globalValues;
+
+ ParamData data;
+ data.desc = description;
+ data.tname = "";
+ data.name = path;
+ data.isFlag = false;
+ data.wasPassed = false;
+
+ gmap[path] = data;
+
+ // If the option is required, add it to the required options list.
+ if (required)
+ GetSingleton().requiredOptions.push_front(path);
+
+ return;
+}
+
+/*
+ * Adds an alias mapping for a given parameter.
+ *
+ * @param alias The alias we will use for the parameter.
+ * @param original The name of the actual parameter we will be mapping to.
+ */
+void CLI::AddAlias(const std::string& alias, const std::string& original)
+{
+ //Conduct the mapping
+ if (alias.length())
+ {
+ amap_t& amap = GetSingleton().aliasValues;
+ amap[alias] = original;
+ }
+}
+
+/*
+ * @brief Adds a flag parameter to CLI.
+ */
+void CLI::AddFlag(const std::string& identifier,
+ const std::string& description,
+ const std::string& alias)
+{
+ // Reuse functionality from add
+ Add(identifier, description, alias, false);
+
+ // Insert the proper metadata in gmap.
+ gmap_t& gmap = GetSingleton().globalValues;
+
+ ParamData data;
+ data.desc = description;
+ data.tname = TYPENAME(bool);
+ data.name = std::string(identifier);
+ data.isFlag = true;
+ data.wasPassed = false;
+
+ gmap[data.name] = data;
+}
+
+std::string CLI::AliasReverseLookup(const std::string& value)
+{
+ amap_t& amap = GetSingleton().aliasValues;
+ amap_t::iterator iter;
+ for (iter = amap.begin(); iter != amap.end(); ++iter)
+ if (iter->second == value) // Found our match.
+ return iter->first;
+
+ return ""; // Nothing found.
+}
+
+/**
+ * Parses the parameters for 'help' and 'info'
+ * If found, will print out the appropriate information
+ * and kill the program.
+ */
+void CLI::DefaultMessages()
+{
+ // Default help message
+ if (HasParam("help"))
+ {
+ Log::Info.ignoreInput = false;
+ PrintHelp();
+ exit(0); // The user doesn't want to run the program, he wants help.
+ }
+
+ if (HasParam("info"))
+ {
+ Log::Info.ignoreInput = false;
+ std::string str = GetParam<std::string>("info");
+
+ // The info node should always be there, but the user may not have specified
+ // anything.
+ if (str != "")
+ {
+ PrintHelp(str);
+ exit(0);
+ }
+
+ // Otherwise just print the generalized help.
+ PrintHelp();
+ exit(0);
+ }
+
+ if (HasParam("verbose"))
+ {
+ // Give [INFO ] output.
+ Log::Info.ignoreInput = false;
+ }
+
+ // Notify the user if we are debugging. This is not done in the constructor
+ // because the output streams may not be set up yet. We also don't want this
+ // message twice if the user just asked for help or information.
+ Log::Debug << "Compiled with debugging symbols." << std::endl;
+}
+
+/**
+ * Destroy the CLI object. This resets the pointer to the singleton, so in case
+ * someone tries to access it after destruction, a new one will be made (the
+ * program will not fail).
+ */
+void CLI::Destroy()
+{
+ if (singleton != NULL)
+ {
+ delete singleton;
+ singleton = NULL; // Reset pointer.
+ }
+}
+
+/**
+ * See if the specified flag was found while parsing.
+ *
+ * @param identifier The name of the parameter in question.
+ */
+bool CLI::HasParam(const std::string& key)
+{
+ std::string used_key = key;
+ po::variables_map vmap = GetSingleton().vmap;
+ gmap_t& gmap = GetSingleton().globalValues;
+
+ // Take any possible alias into account.
+ amap_t& amap = GetSingleton().aliasValues;
+ if (amap.count(key))
+ used_key = amap[key];
+
+ // Does the parameter exist at all?
+ int isInGmap = gmap.count(used_key);
+
+ // Check if the parameter is boolean; if it is, we just want to see if it was
+ // passed.
+ if (isInGmap)
+ return gmap[used_key].wasPassed;
+
+ // The parameter was not passed in; return false.
+ return false;
+}
+
+/**
+ * Hyphenate a string or split it onto multiple 80-character lines, with some
+ * amount of padding on each line. This is used for option output.
+ *
+ * @param str String to hyphenate (splits are on ' ').
+ * @param padding Amount of padding on the left for each new line.
+ */
+std::string CLI::HyphenateString(const std::string& str, int padding)
+{
+ size_t margin = 80 - padding;
+ if (str.length() < margin)
+ return str;
+ std::string out("");
+ unsigned int pos = 0;
+ // First try to look as far as possible.
+ while (pos < str.length())
+ {
+ size_t splitpos;
+ // Check that we don't have a newline first.
+ splitpos = str.find('\n', pos);
+ if (splitpos == std::string::npos || splitpos > (pos + margin))
+ {
+ // We did not find a newline.
+ if (str.length() - pos < margin)
+ {
+ splitpos = str.length(); // The rest fits on one line.
+ }
+ else
+ {
+ splitpos = str.rfind(' ', margin + pos); // Find nearest space.
+ if (splitpos <= pos || splitpos == std::string::npos) // Not found.
+ splitpos = pos + margin;
+ }
+ }
+ out += str.substr(pos, (splitpos - pos));
+ if (splitpos < str.length())
+ {
+ out += '\n';
+ out += std::string(padding, ' ');
+ }
+
+ pos = splitpos;
+ if (str[pos] == ' ' || str[pos] == '\n')
+ pos++;
+ }
+ return out;
+}
+
+/**
+ * Grab the description of the specified node.
+ *
+ * @param identifier Name of the node in question.
+ * @return Description of the node in question.
+ */
+std::string CLI::GetDescription(const std::string& identifier)
+{
+ gmap_t& gmap = GetSingleton().globalValues;
+ std::string name = std::string(identifier);
+
+ //Take any possible alias into account
+ amap_t& amap = GetSingleton().aliasValues;
+ if (amap.count(name))
+ name = amap[name];
+
+
+ if (gmap.count(name))
+ return gmap[name].desc;
+ else
+ return "";
+
+}
+
+// Returns the sole instance of this class.
+CLI& CLI::GetSingleton()
+{
+ if (singleton == NULL)
+ singleton = new CLI();
+
+ return *singleton;
+}
+
+/**
+ * Parses the command line for arguments and stores them in the singleton's
+ * variables map, then triggers the post-parse steps (gmap update, default
+ * messages, required-option checking).
+ *
+ * @param argc The number of arguments on the command line.
+ * @param line The array of arguments as strings (conventionally argv).
+ */
+void CLI::ParseCommandLine(int argc, char** line)
+{
+ // Start the overall program timer before any work is done.
+ Timer::Start("total_time");
+
+ po::variables_map& vmap = GetSingleton().vmap;
+ po::options_description& desc = GetSingleton().desc;
+
+ // Parse the command line, place the options & values into vmap
+ try
+ {
+ // Get the basic_parsed_options
+ po::basic_parsed_options<char> bpo(
+ po::parse_command_line(argc, line, desc));
+
+ // Look for any duplicate parameters, removing duplicate flags
+ RemoveDuplicateFlags(bpo);
+
+ // Record the basic_parsed_options
+ po::store(bpo, vmap);
+ }
+ catch (std::exception& ex)
+ {
+ // NOTE(review): Log::Fatal is presumably expected to terminate the
+ // program after printing; confirm, since execution continues otherwise.
+ Log::Fatal << "Caught exception from parsing command line:\t";
+ Log::Fatal << ex.what() << std::endl;
+ }
+
+ // Flush the buffer, make sure changes are propagated to vmap
+ po::notify(vmap);
+ UpdateGmap();
+ DefaultMessages();
+ RequiredOptions();
+}
+
+/*
+ * Removes duplicate flags.
+ *
+ * @param bpo The basic_parsed_options to remove duplicate flags from.
+ */
+void CLI::RemoveDuplicateFlags(po::basic_parsed_options<char>& bpo)
+{
+  // Iterate over all the program_options, looking for duplicate parameters.
+  for (unsigned int i = 0; i < bpo.options.size(); i++)
+  {
+    for (unsigned int j = i + 1; j < bpo.options.size(); j++)
+    {
+      if (bpo.options[i].string_key == bpo.options[j].string_key)
+      {
+        // If a duplicate is found, check to see if either one has a value.
+        if (bpo.options[i].value.size() == 0 &&
+            bpo.options[j].value.size() == 0)
+        {
+          // If neither has a value, consider it a duplicate flag and remove
+          // the duplicate.  It's important to not break out of this loop
+          // because there might be another duplicate later on in the vector.
+          bpo.options.erase(bpo.options.begin() + j);
+          // The element after the erased one has shifted into position j;
+          // decrement j so the next iteration does not skip over it.
+          --j;
+        }
+        else
+        {
+          // If one or both has a value, produce an error and politely
+          // terminate. We pull the name from the original_tokens, rather than
+          // from the string_key, because the string_key is the parameter after
+          // aliases have been expanded.
+          Log::Fatal << "\"" << bpo.options[j].original_tokens[0] << "\""
+              << " is defined multiple times." << std::endl;
+        }
+      }
+    }
+  }
+}
+
+/**
+ * Parses a stream for arguments (in boost::program_options config-file
+ * format) and stores them in the singleton's variables map, then triggers the
+ * post-parse steps.
+ *
+ * @param stream The stream to be parsed.
+ */
+void CLI::ParseStream(std::istream& stream)
+{
+ po::variables_map& vmap = GetSingleton().vmap;
+ po::options_description& desc = GetSingleton().desc;
+
+ // Parse the stream; place options & values into vmap.
+ try
+ {
+ po::store(po::parse_config_file(stream, desc), vmap);
+ }
+ catch (std::exception& ex)
+ {
+ // NOTE(review): Log::Fatal is presumably expected to terminate after
+ // printing; confirm.
+ Log::Fatal << ex.what() << std::endl;
+ }
+
+ // Flush the buffer; make sure changes are propagated to vmap.
+ po::notify(vmap);
+
+ UpdateGmap();
+ DefaultMessages();
+ RequiredOptions();
+
+ // NOTE(review): the total_time timer is started *after* parsing here,
+ // whereas ParseCommandLine() starts it before parsing; confirm this
+ // asymmetry is intentional.
+ Timer::Start("total_time");
+}
+
+/* Prints out the current hierarchy. */
+void CLI::Print()
+{
+  gmap_t& gmap = GetSingleton().globalValues;
+
+  // Walk the global values map and print each parameter with its value.
+  for (gmap_t::iterator it = gmap.begin(); it != gmap.end(); ++it)
+  {
+    const std::string& key = it->first;
+
+    Log::Info << " " << key << ": ";
+
+    // Dispatch on the stored type name; strings, ints, bools, floats, and
+    // doubles are handled explicitly.
+    const ParamData& data = it->second;
+    if (data.tname == TYPENAME(std::string))
+    {
+      const std::string value = GetParam<std::string>(key.c_str());
+      if (value == "")
+        Log::Info << "\"\"";
+      Log::Info << value;
+    }
+    else if (data.tname == TYPENAME(int))
+    {
+      Log::Info << GetParam<int>(key.c_str());
+    }
+    else if (data.tname == TYPENAME(bool))
+    {
+      // Flags print as true/false based on whether they were passed.
+      Log::Info << (HasParam(key.c_str()) ? "true" : "false");
+    }
+    else if (data.tname == TYPENAME(float))
+    {
+      Log::Info << GetParam<float>(key.c_str());
+    }
+    else if (data.tname == TYPENAME(double))
+    {
+      Log::Info << GetParam<double>(key.c_str());
+    }
+    else
+    {
+      // We don't know how to print this, or it's a timeval which is printed
+      // later.
+      Log::Info << "(Unknown data type - " << data.tname << ")";
+    }
+
+    Log::Info << std::endl;
+  }
+  Log::Info << std::endl;
+}
+
+/* Prints the descriptions of the current hierarchy.  If param is nonempty,
+ * only that parameter's help is printed (after alias resolution); otherwise
+ * the program documentation plus the full required/optional parameter listing
+ * is printed.  Exits with status 1 if a nonempty param does not exist. */
+void CLI::PrintHelp(const std::string& param)
+{
+  std::string used_param = param;
+  gmap_t& gmap = GetSingleton().globalValues;
+  amap_t& amap = GetSingleton().aliasValues;
+  gmap_t::iterator iter;
+  // Bind by reference; no need to copy the ProgramDoc object.
+  const ProgramDoc& docs = *GetSingleton().doc;
+
+  // If we pass a single param, alias it if necessary.
+  if (used_param != "" && amap.count(used_param))
+    used_param = amap[used_param];
+
+  // Do we only want to print out one value?
+  if (used_param != "" && gmap.count(used_param))
+  {
+    ParamData data = gmap[used_param];
+    std::string alias = AliasReverseLookup(used_param);
+    alias = alias.length() ? " (-" + alias + ")" : alias;
+
+    // Figure out the name of the type.
+    std::string type = "";
+    if (data.tname == TYPENAME(std::string))
+      type = " [string]";
+    else if (data.tname == TYPENAME(int))
+      type = " [int]";
+    else if (data.tname == TYPENAME(bool))
+      type = ""; // Nothing to pass for a flag.
+    else if (data.tname == TYPENAME(float))
+      type = " [float]";
+    else if (data.tname == TYPENAME(double))
+      type = " [double]";
+
+    // Now, print the descriptions.
+    std::string fullDesc = " --" + used_param + alias + type + " ";
+
+    if (fullDesc.length() <= 32) // It all fits on one line.
+      std::cout << fullDesc << std::string(32 - fullDesc.length(), ' ');
+    else // We need multiple lines.
+      std::cout << fullDesc << std::endl << std::string(32, ' ');
+
+    std::cout << HyphenateString(data.desc, 32) << std::endl;
+    return;
+  }
+  else if (used_param != "")
+  {
+    // User passed a single variable, but it doesn't exist.
+    std::cerr << "Parameter --" << used_param << " does not exist."
+        << std::endl;
+    exit(1); // Nothing left to do.
+  }
+
+  // Print out the descriptions.
+  if (docs.programName != "")
+  {
+    std::cout << docs.programName << std::endl << std::endl;
+    std::cout << " " << HyphenateString(docs.documentation, 2) << std::endl
+        << std::endl;
+  }
+  else
+    std::cout << "[undocumented program]" << std::endl << std::endl;
+
+  // Pass 0 prints the required options; pass 1 prints the optional ones.
+  for (size_t pass = 0; pass < 2; ++pass)
+  {
+    if (pass == 0)
+      std::cout << "Required options:" << std::endl << std::endl;
+    else
+      std::cout << "Options: " << std::endl << std::endl;
+
+    // Print out the descriptions of everything else.
+    for (iter = gmap.begin(); iter != gmap.end(); ++iter)
+    {
+      std::string key = iter->first;
+      ParamData data = iter->second;
+      std::string desc = data.desc;
+      std::string alias = AliasReverseLookup(key);
+      alias = alias.length() ? " (-" + alias + ")" : alias;
+
+      // Is the option required or not?  (The iterator is named rIter so it
+      // does not shadow the outer gmap iterator.)
+      bool required = false;
+      std::list<std::string>::iterator rIter;
+      std::list<std::string>& rOpt = GetSingleton().requiredOptions;
+      for (rIter = rOpt.begin(); rIter != rOpt.end(); ++rIter)
+        if ((*rIter) == key)
+          required = true;
+
+      if ((pass == 0) && !required)
+        continue; // Don't print this one.
+      if ((pass == 1) && required)
+        continue; // Don't print this one.
+
+      if (pass == 1) // Append default value to description.
+      {
+        desc += " Default value ";
+        std::stringstream tmp;
+
+        if (data.tname == TYPENAME(std::string))
+          tmp << "'" << boost::any_cast<std::string>(data.value) << "'.";
+        else if (data.tname == TYPENAME(int))
+          tmp << boost::any_cast<int>(data.value) << '.';
+        else if (data.tname == TYPENAME(bool))
+          desc = data.desc; // No extra output for that.
+        else if (data.tname == TYPENAME(float))
+          tmp << boost::any_cast<float>(data.value) << '.';
+        else if (data.tname == TYPENAME(double))
+          tmp << boost::any_cast<double>(data.value) << '.';
+
+        desc += tmp.str();
+      }
+
+      // Figure out the name of the type.
+      std::string type = "";
+      if (data.tname == TYPENAME(std::string))
+        type = " [string]";
+      else if (data.tname == TYPENAME(int))
+        type = " [int]";
+      else if (data.tname == TYPENAME(bool))
+        type = ""; // Nothing to pass for a flag.
+      else if (data.tname == TYPENAME(float))
+        type = " [float]";
+      else if (data.tname == TYPENAME(double))
+        type = " [double]";
+
+      // Now, print the descriptions.
+      std::string fullDesc = " --" + key + alias + type + " ";
+
+      if (fullDesc.length() <= 32) // It all fits on one line.
+        std::cout << fullDesc << std::string(32 - fullDesc.length(), ' ');
+      else // We need multiple lines.
+        std::cout << fullDesc << std::endl << std::string(32, ' ');
+
+      std::cout << HyphenateString(desc, 32) << std::endl;
+    }
+
+    std::cout << std::endl;
+  }
+
+  // Helpful information at the bottom of the help output, to point the user to
+  // citations and better documentation (if necessary). See ticket #201.
+  std::cout << HyphenateString("For further information, including relevant "
+      "papers, citations, and theory, consult the documentation found at "
+      "http://www.mlpack.org or included with your distribution of MLPACK.", 0)
+      << std::endl;
+}
+
+/**
+ * Registers a ProgramDoc object, which contains documentation about the
+ * program.
+ *
+ * @param doc Pointer to the ProgramDoc object.
+ */
+void CLI::RegisterProgramDoc(ProgramDoc* doc)
+{
+  // Ignore the dummy object created at the beginning of the file; it only
+  // serves as a default in case this is never called with real documentation.
+  if (doc == &emptyProgramDoc)
+    return;
+
+  GetSingleton().doc = doc;
+}
+
+/**
+ * Checks that all parameters specified as required have been specified on the
+ * command line.  If they haven't, prints an error message and kills the
+ * program.
+ */
+void CLI::RequiredOptions()
+{
+  po::variables_map& vmap = GetSingleton().vmap;
+  // Take the list by reference; there is no need to copy it (and this matches
+  // how PrintHelp() accesses it).
+  std::list<std::string>& rOpt = GetSingleton().requiredOptions;
+
+  // Now, warn the user if they missed any required options.
+  std::list<std::string>::iterator iter;
+  for (iter = rOpt.begin(); iter != rOpt.end(); ++iter)
+  {
+    const std::string& str = *iter;
+    if (!vmap.count(str))
+    { // If a required option isn't there...
+      Log::Fatal << "Required option --" << str.c_str() << " is undefined."
+          << std::endl;
+    }
+  }
+}
+
+/**
+ * Parses the values given on the command line, overriding any default values.
+ */
+void CLI::UpdateGmap()
+{
+  gmap_t& gmap = GetSingleton().globalValues;
+  po::variables_map& vmap = GetSingleton().vmap;
+
+  // Walk everything boost::program_options picked up and overwrite the
+  // corresponding defaults in the global values map.
+  for (po::variables_map::iterator it = vmap.begin(); it != vmap.end(); ++it)
+  {
+    const std::string& name = it->first;
+
+    // Preserve any existing metadata (description, type name, and so on).
+    ParamData param;
+    if (gmap.count(name))
+      param = gmap[name];
+
+    param.value = vmap[name].value();
+    param.wasPassed = true;
+    gmap[name] = param;
+  }
+}
+
+// Add the global options available to every mlpack executable: --help (-h),
+// --info, and --verbose (-v).
+PARAM_FLAG("help", "Default help info.", "h");
+PARAM_STRING("info", "Get help on a specific module or option.", "", "");
+PARAM_FLAG("verbose", "Display informational messages and the full list of "
+ "parameters and timers at the end of execution.", "v");
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,786 +0,0 @@
-/**
- * @file cli.hpp
- * @author Matthew Amidon
- *
- * This file implements the CLI subsystem which is intended to replace FX.
- * This can be used more or less regardless of context. In the future,
- * it might be expanded to include file I/O.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_UTIL_CLI_HPP
-#define __MLPACK_CORE_UTIL_CLI_HPP
-
-#include <list>
-#include <iostream>
-#include <map>
-#include <string>
-
-#include <boost/any.hpp>
-#include <boost/program_options.hpp>
-
-#include "timers.hpp"
-#include "cli_deleter.hpp" // To make sure we can delete the singleton.
-
-/**
- * Document an executable. Only one instance of this macro should be
- * present in your program! Therefore, use it in the main.cpp
- * (or corresponding executable) in your program.
- *
- * @see mlpack::CLI, PARAM_FLAG(), PARAM_INT(), PARAM_DOUBLE(), PARAM_STRING(),
- * PARAM_VECTOR(), PARAM_INT_REQ(), PARAM_DOUBLE_REQ(), PARAM_STRING_REQ(),
- * PARAM_VECTOR_REQ().
- *
- * @param NAME Short string representing the name of the program.
- * @param DESC Long string describing what the program does and possibly a
- * simple usage example. Newlines should not be used here; this is taken
- * care of by CLI (however, you can explicitly specify newlines to denote
- * new paragraphs).
- */
-#define PROGRAM_INFO(NAME, DESC) static mlpack::util::ProgramDoc \
- io_programdoc_dummy_object = mlpack::util::ProgramDoc(NAME, DESC);
-
-/**
- * Define a flag parameter.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_FLAG(ID, DESC, ALIAS) \
- PARAM_FLAG_INTERNAL(ID, DESC, ALIAS);
-
-/**
- * Define an integer parameter.
- *
- * The parameter can then be specified on the command line with
- * --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- * @param DEF Default value of the parameter.
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_INT(ID, DESC, ALIAS, DEF) \
- PARAM(int, ID, DESC, ALIAS, DEF, false)
-
-/**
- * Define a floating-point parameter. You should use PARAM_DOUBLE instead.
- *
- * The parameter can then be specified on the command line with
- * --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- * @param DEF Default value of the parameter.
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_FLOAT(ID, DESC, ALIAS, DEF) \
- PARAM(float, ID, DESC, ALIAS, DEF, false)
-
-/**
- * Define a double parameter.
- *
- * The parameter can then be specified on the command line with
- * --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- * @param DEF Default value of the parameter.
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_DOUBLE(ID, DESC, ALIAS, DEF) \
- PARAM(double, ID, DESC, ALIAS, DEF, false)
-
-/**
- * Define a string parameter.
- *
- * The parameter can then be specified on the command line with
- * --ID=value. If ALIAS is equal to DEF_MOD (which is set using the
- * PROGRAM_INFO() macro), the parameter can be specified with just --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- * @param DEF Default value of the parameter.
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_STRING(ID, DESC, ALIAS, DEF) \
- PARAM(std::string, ID, DESC, ALIAS, DEF, false)
-
-/**
- * Define a vector parameter.
- *
- * The parameter can then be specified on the command line with
- * --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- * @param DEF Default value of the parameter.
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_VECTOR(T, ID, DESC, ALIAS) \
- PARAM(std::vector<T>, ID, DESC, ALIAS, std::vector<T>(), false)
-
-// A required flag doesn't make sense and isn't given here.
-
-/**
- * Define a required integer parameter.
- *
- * The parameter must then be specified on the command line with
- * --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_INT_REQ(ID, DESC, ALIAS) PARAM(int, ID, DESC, ALIAS, 0, true)
-
-/**
- * Define a required floating-point parameter. You should probably use a double
- * instead.
- *
- * The parameter must then be specified on the command line with
- * --ID=value. If ALIAS is equal to DEF_MOD (which is set using the
- * PROGRAM_INFO() macro), the parameter can be specified with just --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_FLOAT_REQ(ID, DESC, ALIAS) PARAM(float, ID, DESC, ALIAS, 0.0f, \
- true)
-
-/**
- * Define a required double parameter.
- *
- * The parameter must then be specified on the command line with
- * --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_DOUBLE_REQ(ID, DESC, ALIAS) PARAM(double, ID, DESC, ALIAS, \
- 0.0f, true)
-
-/**
- * Define a required string parameter.
- *
- * The parameter must then be specified on the command line with
- * --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_STRING_REQ(ID, DESC, ALIAS) PARAM(std::string, ID, DESC, \
- ALIAS, "", true);
-
-/**
- * Define a required vector parameter.
- *
- * The parameter must then be specified on the command line with
- * --ID=value.
- *
- * @param ID Name of the parameter.
- * @param DESC Quick description of the parameter (1-2 sentences).
- * @param ALIAS An alias for the parameter (one letter).
- *
- * @see mlpack::CLI, PROGRAM_INFO()
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-#define PARAM_VECTOR_REQ(T, ID, DESC, ALIAS) PARAM(std::vector<T>, ID, DESC, \
- ALIAS, std::vector<T>(), true);
-
-/**
- * @cond
- * Don't document internal macros.
- */
-
-// These are ugly, but necessary utility functions we must use to generate a
-// unique identifier inside of the PARAM() module.
-#define JOIN(x, y) JOIN_AGAIN(x, y)
-#define JOIN_AGAIN(x, y) x ## y
-/** @endcond */
-
-/**
- * Define an input parameter. Don't use this function; use the other ones above
- * that call it. Note that we are using the __LINE__ macro for naming these
- * actual parameters when __COUNTER__ does not exist, which is a bit of an ugly
- * hack... but this is the preprocessor, after all. We don't have much choice
- * other than ugliness.
- *
- * @param T Type of the parameter.
- * @param ID Name of the parameter.
- * @param DESC Description of the parameter (1-2 sentences).
- * @param ALIAS Alias for this parameter (one letter).
- * @param DEF Default value of the parameter.
- * @param REQ Whether or not parameter is required (boolean value).
- */
-#ifdef __COUNTER__
- #define PARAM(T, ID, DESC, ALIAS, DEF, REQ) static mlpack::util::Option<T> \
- JOIN(io_option_dummy_object_, __COUNTER__) \
- (false, DEF, ID, DESC, ALIAS, REQ);
-
- /** @cond Don't document internal macros. */
- #define PARAM_FLAG_INTERNAL(ID, DESC, ALIAS) static \
- mlpack::util::Option<bool> JOIN(__io_option_flag_object_, __COUNTER__) \
- (ID, DESC, ALIAS);
- /** @endcond */
-
-#else
- // We have to do some really bizarre stuff since __COUNTER__ isn't defined. I
- // don't think we can absolutely guarantee success, but it should be "good
- // enough". We use the __LINE__ macro and the type of the parameter to try
- // and get a good guess at something unique.
- #define PARAM(T, ID, DESC, ALIAS, DEF, REQ) static mlpack::util::Option<T> \
- JOIN(JOIN(io_option_dummy_object_, __LINE__), opt) (false, DEF, ID, \
- DESC, ALIAS, REQ);
-
- /** @cond Don't document internal macros. */
- #define PARAM_FLAG_INTERNAL(ID, DESC, ALIAS) static \
- mlpack::util::Option<bool> JOIN(__io_option_flag_object_, __LINE__) \
- (ID, DESC, ALIAS);
- /** @endcond */
-
-#endif
-
-/**
- * The TYPENAME macro is used internally to convert a type into a string.
- */
-#define TYPENAME(x) (std::string(typeid(x).name()))
-
-namespace po = boost::program_options;
-
-namespace mlpack {
-
-namespace util {
-
-// Externally defined in option.hpp, this class holds information about the
-// program being run.
-class ProgramDoc;
-
-}; // namespace util
-
-/**
- * Aids in the extensibility of CLI by focusing potential
- * changes into one structure.
- */
-struct ParamData
-{
- //! Name of this parameter.
- std::string name;
- //! Description of this parameter, if any.
- std::string desc;
- //! Type information of this parameter.
- std::string tname;
- //! The actual value of this parameter.
- boost::any value;
- //! True if this parameter was passed in via command line or file.
- bool wasPassed;
- //! True if the wasPassed value should not be ignored
- bool isFlag;
-};
-
-/**
- * @brief Parses the command line for parameters and holds user-specified
- * parameters.
- *
- * The CLI class is a subsystem by which parameters for machine learning methods
- * can be specified and accessed. In conjunction with the macros PARAM_DOUBLE,
- * PARAM_INT, PARAM_STRING, PARAM_FLAG, and others, this class aims to make user
- * configurability of MLPACK methods very easy. There are only three methods in
- * CLI that a user should need: CLI::ParseCommandLine(), CLI::GetParam(), and
- * CLI::HasParam() (in addition to the PARAM_*() macros).
- *
- * @section addparam Adding parameters to a program
- *
- * @code
- * $ ./executable --bar=5
- * @endcode
- *
- * @note The = is optional; a space can also be used.
- *
- * A parameter is specified by using one of the following macros (this is not a
- * complete list; see core/io/cli.hpp):
- *
- * - PARAM_FLAG(ID, DESC, ALIAS)
- * - PARAM_DOUBLE(ID, DESC, ALIAS, DEF)
- * - PARAM_INT(ID, DESC, ALIAS, DEF)
- * - PARAM_STRING(ID, DESC, ALIAS, DEF)
- *
- * @param ID Name of the parameter.
- * @param DESC Short description of the parameter (one/two sentences).
- * @param ALIAS An alias for the parameter.
- * @param DEF Default value of the parameter.
- *
- * The flag (boolean) type automatically defaults to false; it is specified
- * merely as a flag on the command line (no '=true' is required).
- *
- * Here is an example of a few parameters being defined; this is for the AllkNN
- * executable (methods/neighbor_search/allknn_main.cpp):
- *
- * @code
- * PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- * "r");
- * PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
- * PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
- * PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "k");
- * PARAM_STRING("query_file", "File containing query points (optional).", "q",
- * "");
- * PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
- * PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.",
- * "N");
- * PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed "
- * "to dual-tree search.", "s");
- * @endcode
- *
- * More documentation is available on the PARAM_*() macros in the documentation
- * for core/io/cli.hpp.
- *
- * @section programinfo Documenting the program itself
- *
- * In addition to allowing documentation for each individual parameter and
- * module, the PROGRAM_INFO() macro provides support for documenting the program
- * itself. There should only be one instance of the PROGRAM_INFO() macro.
- * Below is an example:
- *
- * @code
- * PROGRAM_INFO("Maximum Variance Unfolding", "This program performs maximum "
- * "variance unfolding on the given dataset, writing a lower-dimensional "
- * "unfolded dataset to the given output file.");
- * @endcode
- *
- * This description should be verbose, and explain to a non-expert user what the
- * program does and how to use it. If relevant, paper citations should be
- * included.
- *
- * @section parsecli Parsing the command line with CLI
- *
- * To have CLI parse the command line at the beginning of code execution, only a
- * call to ParseCommandLine() is necessary:
- *
- * @code
- * int main(int argc, char** argv)
- * {
- * CLI::ParseCommandLine(argc, argv);
- *
- * ...
- * }
- * @endcode
- *
- * CLI provides --help and --info options which give nicely formatted
- * documentation of each option; the documentation is generated from the DESC
- * arguments in the PARAM_*() macros.
- *
- * @section getparam Getting parameters with CLI
- *
- * When the parameters have been defined, the next important thing is how to
- * access them. For this, the HasParam() and GetParam() methods are
- * used. For instance, to see if the user passed the flag (boolean) "naive":
- *
- * @code
- * if (CLI::HasParam("naive"))
- * {
- * Log::Info << "Naive has been passed!" << std::endl;
- * }
- * @endcode
- *
- * To get the value of a parameter, such as a string, use GetParam:
- *
- * @code
- * const std::string filename = CLI::GetParam<std::string>("filename");
- * @endcode
- *
- * @note
- * Options should only be defined in files which define `main()` (that is, main
- * executables). If options are defined elsewhere, they may be spuriously
- * included into other executables and confuse users. Similarly, if your
- * executable has options which you did not define, it is probably because the
- * option is defined somewhere else and included in your executable.
- *
- * @bug
- * The __COUNTER__ variable is used in most cases to guarantee a unique global
- * identifier for options declared using the PARAM_*() macros. However, not all
- * compilers have this support--most notably, gcc < 4.3. In that case, the
- * __LINE__ macro is used as an attempt to get a unique global identifier, but
- * collisions are still possible, and they produce bizarre error messages. See
- * http://mlpack.org/trac/ticket/74 for more information.
- */
-class CLI
-{
- public:
- /**
- * Adds a parameter to the hierarchy; use the PARAM_*() macros instead of this
- * (i.e. PARAM_INT()). Uses char* and not std::string since the vast majority
- * of use cases will be literal strings.
- *
- * @param identifier The name of the parameter.
- * @param description Short string description of the parameter.
- * @param alias An alias for the parameter, defaults to "" which is no alias.
- * ("").
- * @param required Indicates if parameter must be set on command line.
- */
- static void Add(const std::string& path,
- const std::string& description,
- const std::string& alias = "",
- bool required = false);
-
- /**
- * Adds a parameter to the hierarchy; use the PARAM_*() macros instead of this
- * (i.e. PARAM_INT()). Uses char* and not std::string since the vast majority
- * of use cases will be literal strings. If the argument requires a
- * parameter, you must specify a type.
- *
- * @param identifier The name of the parameter.
- * @param description Short string description of the parameter.
- * @param alias An alias for the parameter, defaults to "" which is no alias.
- * @param required Indicates if parameter must be set on command line.
- */
- template<class T>
- static void Add(const std::string& identifier,
- const std::string& description,
- const std::string& alias = "",
- bool required = false);
-
- /**
- * Adds a flag parameter to the hierarchy; use PARAM_FLAG() instead of this.
- *
- * @param identifier The name of the paramater.
- * @param description Short string description of the parameter.
- * @param alias An alias for the parameter, defaults to "" which is no alias.
- */
- static void AddFlag(const std::string& identifier,
- const std::string& description,
- const std::string& alias = "");
-
- /**
- * Parses the parameters for 'help' and 'info'.
- * If found, will print out the appropriate information and kill the program.
- */
- static void DefaultMessages();
-
- /**
- * Destroy the CLI object. This resets the pointer to the singleton, so in
- * case someone tries to access it after destruction, a new one will be made
- * (the program will not fail).
- */
- static void Destroy();
-
- /**
- * Grab the value of type T found while parsing. You can set the value using
- * this reference safely.
- *
- * @param identifier The name of the parameter in question.
- */
- template<typename T>
- static T& GetParam(const std::string& identifier);
-
- /**
- * Get the description of the specified node.
- *
- * @param identifier Name of the node in question.
- * @return Description of the node in question.
- */
- static std::string GetDescription(const std::string& identifier);
-
- /**
- * Retrieve the singleton.
- *
- * Not exposed to the outside, so as to spare users some ungainly
- * x.GetSingleton().foo() syntax.
- *
- * In this case, the singleton is used to store data for the static methods,
- * as there is no point in defining static methods only to have users call
- * private instance methods
- *
- * @return The singleton instance for use in the static methods.
- */
- static CLI& GetSingleton();
-
- /**
- * See if the specified flag was found while parsing.
- *
- * @param identifier The name of the parameter in question.
- */
- static bool HasParam(const std::string& identifier);
-
- /**
- * Hyphenate a string or split it onto multiple 80-character lines, with some
- * amount of padding on each line. This is ued for option output.
- *
- * @param str String to hyphenate (splits are on ' ').
- * @param padding Amount of padding on the left for each new line.
- */
- static std::string HyphenateString(const std::string& str, int padding);
-
- /**
- * Parses the commandline for arguments.
- *
- * @param argc The number of arguments on the commandline.
- * @param argv The array of arguments as strings.
- */
- static void ParseCommandLine(int argc, char** argv);
-
- /**
- * Removes duplicate flags.
- *
- * @param bpo The basic_program_options to remove duplicate flags from.
- */
- static void RemoveDuplicateFlags(po::basic_parsed_options<char>& bpo);
-
- /**
- * Parses a stream for arguments.
- *
- * @param stream The stream to be parsed.
- */
- static void ParseStream(std::istream& stream);
-
- /**
- * Print out the current hierarchy.
- */
- static void Print();
-
- /**
- * Print out the help info of the hierarchy.
- */
- static void PrintHelp(const std::string& param = "");
-
- /**
- * Registers a ProgramDoc object, which contains documentation about the
- * program. If this method has been called before (that is, if two
- * ProgramDocs are instantiated in the program), a fatal error will occur.
- *
- * @param doc Pointer to the ProgramDoc object.
- */
- static void RegisterProgramDoc(util::ProgramDoc* doc);
-
- /**
- * Destructor.
- */
- ~CLI();
-
- private:
- //! The documentation and names of options.
- po::options_description desc;
-
- //! Values of the options given by user.
- po::variables_map vmap;
-
- //! Pathnames of required options.
- std::list<std::string> requiredOptions;
-
- //! Map of global values.
- typedef std::map<std::string, ParamData> gmap_t;
- gmap_t globalValues;
-
- //! Map for aliases, from alias to actual name.
- typedef std::map<std::string, std::string> amap_t;
- amap_t aliasValues;
-
- //! The singleton itself.
- static CLI* singleton;
-
- //! True, if CLI was used to parse command line options.
- bool didParse;
-
- //! Holds the timer objects.
- Timers timer;
-
- //! So that Timer::Start() and Timer::Stop() can access the timer variable.
- friend class Timer;
-
- public:
- //! Pointer to the ProgramDoc object.
- util::ProgramDoc *doc;
-
- private:
- /**
- * Maps a given alias to a given parameter.
- *
- * @param alias The name of the alias to be mapped.
- * @param original The name of the parameter to be mapped.
- */
- static void AddAlias(const std::string& alias, const std::string& original);
-
- /**
- * Returns an alias, if given the name of the original.
- *
- * @param value The value in a key:value pair where the key
- * is an alias.
- * @return The alias associated with value.
- */
- static std::string AliasReverseLookup(const std::string& value);
-
-#ifdef _WIN32
- /**
- * Converts a FILETIME structure to an equivalent timeval structure.
- * Only necessary on windows platforms.
- * @param tv Valid timeval structure.
- */
- void FileTimeToTimeVal(timeval* tv);
-#endif
-
- /**
- * Checks that all required parameters have been specified on the command
- * line. If any have not been specified, an error message is printed and the
- * program is terminated.
- */
- static void RequiredOptions();
-
- /**
- * Cleans up input pathnames, rendering strings such as /foo/bar
- * and foo/bar/ equivalent inputs.
- *
- * @param str Input string.
- * @return Sanitized string.
- */
- static std::string SanitizeString(const std::string& str);
-
- /**
- * Parses the values given on the command line, overriding any default values.
- */
- static void UpdateGmap();
-
- /**
- * Make the constructor private, to preclude unauthorized instances.
- */
- CLI();
-
- /**
- * Initialize desc with a particular name.
- *
- * @param optionsName Name of the module, as far as boost is concerned.
- */
- CLI(const std::string& optionsName);
-
- //! Private copy constructor; we don't want copies floating around.
- CLI(const CLI& other);
-};
-
-}; // namespace mlpack
-
-// Include the actual definitions of templated methods
-#include "cli_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,786 @@
+/**
+ * @file cli.hpp
+ * @author Matthew Amidon
+ *
+ * This file implements the CLI subsystem which is intended to replace FX.
+ * This can be used more or less regardless of context. In the future,
+ * it might be expanded to include file I/O.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_UTIL_CLI_HPP
+#define __MLPACK_CORE_UTIL_CLI_HPP
+
+#include <list>
+#include <iostream>
+#include <map>
+#include <string>
+
+#include <boost/any.hpp>
+#include <boost/program_options.hpp>
+
+#include "timers.hpp"
+#include "cli_deleter.hpp" // To make sure we can delete the singleton.
+
+/**
+ * Document an executable. Only one instance of this macro should be
+ * present in your program! Therefore, use it in the main.cpp
+ * (or corresponding executable) in your program.
+ *
+ * @see mlpack::CLI, PARAM_FLAG(), PARAM_INT(), PARAM_DOUBLE(), PARAM_STRING(),
+ * PARAM_VECTOR(), PARAM_INT_REQ(), PARAM_DOUBLE_REQ(), PARAM_STRING_REQ(),
+ * PARAM_VECTOR_REQ().
+ *
+ * @param NAME Short string representing the name of the program.
+ * @param DESC Long string describing what the program does and possibly a
+ * simple usage example. Newlines should not be used here; this is taken
+ * care of by CLI (however, you can explicitly specify newlines to denote
+ * new paragraphs).
+ */
+#define PROGRAM_INFO(NAME, DESC) static mlpack::util::ProgramDoc \
+ io_programdoc_dummy_object = mlpack::util::ProgramDoc(NAME, DESC);
+
+/**
+ * Define a flag parameter.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_FLAG(ID, DESC, ALIAS) \
+ PARAM_FLAG_INTERNAL(ID, DESC, ALIAS);
+
+/**
+ * Define an integer parameter.
+ *
+ * The parameter can then be specified on the command line with
+ * --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ * @param DEF Default value of the parameter.
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_INT(ID, DESC, ALIAS, DEF) \
+ PARAM(int, ID, DESC, ALIAS, DEF, false)
+
+/**
+ * Define a floating-point parameter. You should use PARAM_DOUBLE instead.
+ *
+ * The parameter can then be specified on the command line with
+ * --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ * @param DEF Default value of the parameter.
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_FLOAT(ID, DESC, ALIAS, DEF) \
+ PARAM(float, ID, DESC, ALIAS, DEF, false)
+
+/**
+ * Define a double parameter.
+ *
+ * The parameter can then be specified on the command line with
+ * --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ * @param DEF Default value of the parameter.
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_DOUBLE(ID, DESC, ALIAS, DEF) \
+ PARAM(double, ID, DESC, ALIAS, DEF, false)
+
+/**
+ * Define a string parameter.
+ *
+ * The parameter can then be specified on the command line with
+ * --ID=value. If ALIAS is equal to DEF_MOD (which is set using the
+ * PROGRAM_INFO() macro), the parameter can be specified with just --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ * @param DEF Default value of the parameter.
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_STRING(ID, DESC, ALIAS, DEF) \
+ PARAM(std::string, ID, DESC, ALIAS, DEF, false)
+
+/**
+ * Define a vector parameter.
+ *
+ * The parameter can then be specified on the command line with
+ * --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ * @param DEF Default value of the parameter.
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_VECTOR(T, ID, DESC, ALIAS) \
+ PARAM(std::vector<T>, ID, DESC, ALIAS, std::vector<T>(), false)
+
+// A required flag doesn't make sense and isn't given here.
+
+/**
+ * Define a required integer parameter.
+ *
+ * The parameter must then be specified on the command line with
+ * --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_INT_REQ(ID, DESC, ALIAS) PARAM(int, ID, DESC, ALIAS, 0, true)
+
+/**
+ * Define a required floating-point parameter. You should probably use a double
+ * instead.
+ *
+ * The parameter must then be specified on the command line with
+ * --ID=value. If ALIAS is equal to DEF_MOD (which is set using the
+ * PROGRAM_INFO() macro), the parameter can be specified with just --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_FLOAT_REQ(ID, DESC, ALIAS) PARAM(float, ID, DESC, ALIAS, 0.0f, \
+ true)
+
+/**
+ * Define a required double parameter.
+ *
+ * The parameter must then be specified on the command line with
+ * --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_DOUBLE_REQ(ID, DESC, ALIAS) PARAM(double, ID, DESC, ALIAS, \
+ 0.0f, true)
+
+/**
+ * Define a required string parameter.
+ *
+ * The parameter must then be specified on the command line with
+ * --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_STRING_REQ(ID, DESC, ALIAS) PARAM(std::string, ID, DESC, \
+ ALIAS, "", true);
+
+/**
+ * Define a required vector parameter.
+ *
+ * The parameter must then be specified on the command line with
+ * --ID=value.
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Quick description of the parameter (1-2 sentences).
+ * @param ALIAS An alias for the parameter (one letter).
+ *
+ * @see mlpack::CLI, PROGRAM_INFO()
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+#define PARAM_VECTOR_REQ(T, ID, DESC, ALIAS) PARAM(std::vector<T>, ID, DESC, \
+ ALIAS, std::vector<T>(), true);
+
+/**
+ * @cond
+ * Don't document internal macros.
+ */
+
+// These are ugly, but necessary utility functions we must use to generate a
+// unique identifier inside of the PARAM() module.
+#define JOIN(x, y) JOIN_AGAIN(x, y)
+#define JOIN_AGAIN(x, y) x ## y
+/** @endcond */
+
+/**
+ * Define an input parameter. Don't use this function; use the other ones above
+ * that call it. Note that we are using the __LINE__ macro for naming these
+ * actual parameters when __COUNTER__ does not exist, which is a bit of an ugly
+ * hack... but this is the preprocessor, after all. We don't have much choice
+ * other than ugliness.
+ *
+ * @param T Type of the parameter.
+ * @param ID Name of the parameter.
+ * @param DESC Description of the parameter (1-2 sentences).
+ * @param ALIAS Alias for this parameter (one letter).
+ * @param DEF Default value of the parameter.
+ * @param REQ Whether or not parameter is required (boolean value).
+ */
+#ifdef __COUNTER__
+ #define PARAM(T, ID, DESC, ALIAS, DEF, REQ) static mlpack::util::Option<T> \
+ JOIN(io_option_dummy_object_, __COUNTER__) \
+ (false, DEF, ID, DESC, ALIAS, REQ);
+
+ /** @cond Don't document internal macros. */
+ #define PARAM_FLAG_INTERNAL(ID, DESC, ALIAS) static \
+ mlpack::util::Option<bool> JOIN(__io_option_flag_object_, __COUNTER__) \
+ (ID, DESC, ALIAS);
+ /** @endcond */
+
+#else
+ // We have to do some really bizarre stuff since __COUNTER__ isn't defined. I
+ // don't think we can absolutely guarantee success, but it should be "good
+ // enough". We use the __LINE__ macro and the type of the parameter to try
+ // and get a good guess at something unique.
+ #define PARAM(T, ID, DESC, ALIAS, DEF, REQ) static mlpack::util::Option<T> \
+ JOIN(JOIN(io_option_dummy_object_, __LINE__), opt) (false, DEF, ID, \
+ DESC, ALIAS, REQ);
+
+ /** @cond Don't document internal macros. */
+ #define PARAM_FLAG_INTERNAL(ID, DESC, ALIAS) static \
+ mlpack::util::Option<bool> JOIN(__io_option_flag_object_, __LINE__) \
+ (ID, DESC, ALIAS);
+ /** @endcond */
+
+#endif
+
+/**
+ * The TYPENAME macro is used internally to convert a type into a string.
+ */
+#define TYPENAME(x) (std::string(typeid(x).name()))
+
+namespace po = boost::program_options;
+
+namespace mlpack {
+
+namespace util {
+
+// Externally defined in option.hpp, this class holds information about the
+// program being run.
+class ProgramDoc;
+
+}; // namespace util
+
+/**
+ * Aids in the extensibility of CLI by focusing potential
+ * changes into one structure.
+ */
+struct ParamData
+{
+ //! Name of this parameter.
+ std::string name;
+ //! Description of this parameter, if any.
+ std::string desc;
+ //! Type information of this parameter.
+ std::string tname;
+ //! The actual value of this parameter.
+ boost::any value;
+ //! True if this parameter was passed in via command line or file.
+ bool wasPassed;
+ //! True if the wasPassed value should not be ignored
+ bool isFlag;
+};
+
+/**
+ * @brief Parses the command line for parameters and holds user-specified
+ * parameters.
+ *
+ * The CLI class is a subsystem by which parameters for machine learning methods
+ * can be specified and accessed. In conjunction with the macros PARAM_DOUBLE,
+ * PARAM_INT, PARAM_STRING, PARAM_FLAG, and others, this class aims to make user
+ * configurability of MLPACK methods very easy. There are only three methods in
+ * CLI that a user should need: CLI::ParseCommandLine(), CLI::GetParam(), and
+ * CLI::HasParam() (in addition to the PARAM_*() macros).
+ *
+ * @section addparam Adding parameters to a program
+ *
+ * @code
+ * $ ./executable --bar=5
+ * @endcode
+ *
+ * @note The = is optional; a space can also be used.
+ *
+ * A parameter is specified by using one of the following macros (this is not a
+ * complete list; see core/io/cli.hpp):
+ *
+ * - PARAM_FLAG(ID, DESC, ALIAS)
+ * - PARAM_DOUBLE(ID, DESC, ALIAS, DEF)
+ * - PARAM_INT(ID, DESC, ALIAS, DEF)
+ * - PARAM_STRING(ID, DESC, ALIAS, DEF)
+ *
+ * @param ID Name of the parameter.
+ * @param DESC Short description of the parameter (one/two sentences).
+ * @param ALIAS An alias for the parameter.
+ * @param DEF Default value of the parameter.
+ *
+ * The flag (boolean) type automatically defaults to false; it is specified
+ * merely as a flag on the command line (no '=true' is required).
+ *
+ * Here is an example of a few parameters being defined; this is for the AllkNN
+ * executable (methods/neighbor_search/allknn_main.cpp):
+ *
+ * @code
+ * PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ * "r");
+ * PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
+ * PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
+ * PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "k");
+ * PARAM_STRING("query_file", "File containing query points (optional).", "q",
+ * "");
+ * PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+ * PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.",
+ * "N");
+ * PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed "
+ * "to dual-tree search.", "s");
+ * @endcode
+ *
+ * More documentation is available on the PARAM_*() macros in the documentation
+ * for core/io/cli.hpp.
+ *
+ * @section programinfo Documenting the program itself
+ *
+ * In addition to allowing documentation for each individual parameter and
+ * module, the PROGRAM_INFO() macro provides support for documenting the program
+ * itself. There should only be one instance of the PROGRAM_INFO() macro.
+ * Below is an example:
+ *
+ * @code
+ * PROGRAM_INFO("Maximum Variance Unfolding", "This program performs maximum "
+ * "variance unfolding on the given dataset, writing a lower-dimensional "
+ * "unfolded dataset to the given output file.");
+ * @endcode
+ *
+ * This description should be verbose, and explain to a non-expert user what the
+ * program does and how to use it. If relevant, paper citations should be
+ * included.
+ *
+ * @section parsecli Parsing the command line with CLI
+ *
+ * To have CLI parse the command line at the beginning of code execution, only a
+ * call to ParseCommandLine() is necessary:
+ *
+ * @code
+ * int main(int argc, char** argv)
+ * {
+ * CLI::ParseCommandLine(argc, argv);
+ *
+ * ...
+ * }
+ * @endcode
+ *
+ * CLI provides --help and --info options which give nicely formatted
+ * documentation of each option; the documentation is generated from the DESC
+ * arguments in the PARAM_*() macros.
+ *
+ * @section getparam Getting parameters with CLI
+ *
+ * When the parameters have been defined, the next important thing is how to
+ * access them. For this, the HasParam() and GetParam() methods are
+ * used. For instance, to see if the user passed the flag (boolean) "naive":
+ *
+ * @code
+ * if (CLI::HasParam("naive"))
+ * {
+ * Log::Info << "Naive has been passed!" << std::endl;
+ * }
+ * @endcode
+ *
+ * To get the value of a parameter, such as a string, use GetParam:
+ *
+ * @code
+ * const std::string filename = CLI::GetParam<std::string>("filename");
+ * @endcode
+ *
+ * @note
+ * Options should only be defined in files which define `main()` (that is, main
+ * executables). If options are defined elsewhere, they may be spuriously
+ * included into other executables and confuse users. Similarly, if your
+ * executable has options which you did not define, it is probably because the
+ * option is defined somewhere else and included in your executable.
+ *
+ * @bug
+ * The __COUNTER__ variable is used in most cases to guarantee a unique global
+ * identifier for options declared using the PARAM_*() macros. However, not all
+ * compilers have this support--most notably, gcc < 4.3. In that case, the
+ * __LINE__ macro is used as an attempt to get a unique global identifier, but
+ * collisions are still possible, and they produce bizarre error messages. See
+ * http://mlpack.org/trac/ticket/74 for more information.
+ */
+class CLI
+{
+ public:
+ /**
+ * Adds a parameter to the hierarchy; use the PARAM_*() macros instead of this
+ * (i.e. PARAM_INT()). Uses char* and not std::string since the vast majority
+ * of use cases will be literal strings.
+ *
+ * @param identifier The name of the parameter.
+ * @param description Short string description of the parameter.
+ * @param alias An alias for the parameter, defaults to "" which is no alias.
+ * ("").
+ * @param required Indicates if parameter must be set on command line.
+ */
+ static void Add(const std::string& path,
+ const std::string& description,
+ const std::string& alias = "",
+ bool required = false);
+
+ /**
+ * Adds a parameter to the hierarchy; use the PARAM_*() macros instead of this
+ * (i.e. PARAM_INT()). Uses char* and not std::string since the vast majority
+ * of use cases will be literal strings. If the argument requires a
+ * parameter, you must specify a type.
+ *
+ * @param identifier The name of the parameter.
+ * @param description Short string description of the parameter.
+ * @param alias An alias for the parameter, defaults to "" which is no alias.
+ * @param required Indicates if parameter must be set on command line.
+ */
+ template<class T>
+ static void Add(const std::string& identifier,
+ const std::string& description,
+ const std::string& alias = "",
+ bool required = false);
+
+ /**
+ * Adds a flag parameter to the hierarchy; use PARAM_FLAG() instead of this.
+ *
+   * @param identifier The name of the parameter.
+ * @param description Short string description of the parameter.
+ * @param alias An alias for the parameter, defaults to "" which is no alias.
+ */
+ static void AddFlag(const std::string& identifier,
+ const std::string& description,
+ const std::string& alias = "");
+
+ /**
+ * Parses the parameters for 'help' and 'info'.
+ * If found, will print out the appropriate information and kill the program.
+ */
+ static void DefaultMessages();
+
+ /**
+ * Destroy the CLI object. This resets the pointer to the singleton, so in
+ * case someone tries to access it after destruction, a new one will be made
+ * (the program will not fail).
+ */
+ static void Destroy();
+
+ /**
+ * Grab the value of type T found while parsing. You can set the value using
+ * this reference safely.
+ *
+ * @param identifier The name of the parameter in question.
+ */
+ template<typename T>
+ static T& GetParam(const std::string& identifier);
+
+ /**
+ * Get the description of the specified node.
+ *
+ * @param identifier Name of the node in question.
+ * @return Description of the node in question.
+ */
+ static std::string GetDescription(const std::string& identifier);
+
+ /**
+ * Retrieve the singleton.
+ *
+ * Not exposed to the outside, so as to spare users some ungainly
+ * x.GetSingleton().foo() syntax.
+ *
+ * In this case, the singleton is used to store data for the static methods,
+ * as there is no point in defining static methods only to have users call
+ * private instance methods
+ *
+ * @return The singleton instance for use in the static methods.
+ */
+ static CLI& GetSingleton();
+
+ /**
+ * See if the specified flag was found while parsing.
+ *
+ * @param identifier The name of the parameter in question.
+ */
+ static bool HasParam(const std::string& identifier);
+
+ /**
+ * Hyphenate a string or split it onto multiple 80-character lines, with some
+   * amount of padding on each line.  This is used for option output.
+ *
+ * @param str String to hyphenate (splits are on ' ').
+ * @param padding Amount of padding on the left for each new line.
+ */
+ static std::string HyphenateString(const std::string& str, int padding);
+
+ /**
+ * Parses the commandline for arguments.
+ *
+ * @param argc The number of arguments on the commandline.
+ * @param argv The array of arguments as strings.
+ */
+ static void ParseCommandLine(int argc, char** argv);
+
+ /**
+ * Removes duplicate flags.
+ *
+ * @param bpo The basic_program_options to remove duplicate flags from.
+ */
+ static void RemoveDuplicateFlags(po::basic_parsed_options<char>& bpo);
+
+ /**
+ * Parses a stream for arguments.
+ *
+ * @param stream The stream to be parsed.
+ */
+ static void ParseStream(std::istream& stream);
+
+ /**
+ * Print out the current hierarchy.
+ */
+ static void Print();
+
+ /**
+ * Print out the help info of the hierarchy.
+ */
+ static void PrintHelp(const std::string& param = "");
+
+ /**
+ * Registers a ProgramDoc object, which contains documentation about the
+ * program. If this method has been called before (that is, if two
+ * ProgramDocs are instantiated in the program), a fatal error will occur.
+ *
+ * @param doc Pointer to the ProgramDoc object.
+ */
+ static void RegisterProgramDoc(util::ProgramDoc* doc);
+
+ /**
+ * Destructor.
+ */
+ ~CLI();
+
+ private:
+ //! The documentation and names of options.
+ po::options_description desc;
+
+ //! Values of the options given by user.
+ po::variables_map vmap;
+
+ //! Pathnames of required options.
+ std::list<std::string> requiredOptions;
+
+ //! Map of global values.
+ typedef std::map<std::string, ParamData> gmap_t;
+ gmap_t globalValues;
+
+ //! Map for aliases, from alias to actual name.
+ typedef std::map<std::string, std::string> amap_t;
+ amap_t aliasValues;
+
+ //! The singleton itself.
+ static CLI* singleton;
+
+ //! True, if CLI was used to parse command line options.
+ bool didParse;
+
+ //! Holds the timer objects.
+ Timers timer;
+
+ //! So that Timer::Start() and Timer::Stop() can access the timer variable.
+ friend class Timer;
+
+ public:
+ //! Pointer to the ProgramDoc object.
+ util::ProgramDoc *doc;
+
+ private:
+ /**
+ * Maps a given alias to a given parameter.
+ *
+ * @param alias The name of the alias to be mapped.
+ * @param original The name of the parameter to be mapped.
+ */
+ static void AddAlias(const std::string& alias, const std::string& original);
+
+ /**
+ * Returns an alias, if given the name of the original.
+ *
+ * @param value The value in a key:value pair where the key
+ * is an alias.
+ * @return The alias associated with value.
+ */
+ static std::string AliasReverseLookup(const std::string& value);
+
+#ifdef _WIN32
+ /**
+ * Converts a FILETIME structure to an equivalent timeval structure.
+ * Only necessary on windows platforms.
+ * @param tv Valid timeval structure.
+ */
+ void FileTimeToTimeVal(timeval* tv);
+#endif
+
+ /**
+ * Checks that all required parameters have been specified on the command
+ * line. If any have not been specified, an error message is printed and the
+ * program is terminated.
+ */
+ static void RequiredOptions();
+
+ /**
+ * Cleans up input pathnames, rendering strings such as /foo/bar
+ * and foo/bar/ equivalent inputs.
+ *
+ * @param str Input string.
+ * @return Sanitized string.
+ */
+ static std::string SanitizeString(const std::string& str);
+
+ /**
+ * Parses the values given on the command line, overriding any default values.
+ */
+ static void UpdateGmap();
+
+ /**
+ * Make the constructor private, to preclude unauthorized instances.
+ */
+ CLI();
+
+ /**
+ * Initialize desc with a particular name.
+ *
+ * @param optionsName Name of the module, as far as boost is concerned.
+ */
+ CLI(const std::string& optionsName);
+
+ //! Private copy constructor; we don't want copies floating around.
+ CLI(const CLI& other);
+};
+
+}; // namespace mlpack
+
+// Include the actual definitions of templated methods
+#include "cli_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli_deleter.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,47 +0,0 @@
-/**
- * @file cli_deleter.cpp
- * @author Ryan Curtin
- *
- * Extremely simple class whose only job is to delete the existing CLI object at
- * the end of execution. This is meant to allow the user to avoid typing
- * 'CLI::Destroy()' at the end of their program. The file also defines a static
- * CLIDeleter class, which will be initialized at the beginning of the program
- * and deleted at the end. The destructor destroys the CLI singleton.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "cli_deleter.hpp"
-#include "cli.hpp"
-
-using namespace mlpack;
-using namespace mlpack::util;
-
-/***
- * Empty constructor that does nothing.
- */
-CLIDeleter::CLIDeleter()
-{
- /* Nothing to do. */
-}
-
-/***
- * This destructor deletes the CLI singleton.
- */
-CLIDeleter::~CLIDeleter()
-{
- // Delete the singleton!
- CLI::Destroy();
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli_deleter.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,47 @@
+/**
+ * @file cli_deleter.cpp
+ * @author Ryan Curtin
+ *
+ * Extremely simple class whose only job is to delete the existing CLI object at
+ * the end of execution. This is meant to allow the user to avoid typing
+ * 'CLI::Destroy()' at the end of their program. The file also defines a static
+ * CLIDeleter class, which will be initialized at the beginning of the program
+ * and deleted at the end. The destructor destroys the CLI singleton.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "cli_deleter.hpp"
+#include "cli.hpp"
+
+using namespace mlpack;
+using namespace mlpack::util;
+
+/***
+ * Empty constructor that does nothing.
+ */
+CLIDeleter::CLIDeleter()
+{
+ /* Nothing to do. */
+}
+
+/***
+ * This destructor deletes the CLI singleton.
+ */
+CLIDeleter::~CLIDeleter()
+{
+ // Delete the singleton!
+ CLI::Destroy();
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli_deleter.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,48 +0,0 @@
-/**
- * @file cli_deleter.hpp
- * @author Ryan Curtin
- *
- * Definition of the CLIDeleter() class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_UTIL_CLI_DELETER_HPP
-#define __MLPACK_CORE_UTIL_CLI_DELETER_HPP
-
-namespace mlpack {
-namespace util {
-
-/**
- * Extremely simple class whose only job is to delete the existing CLI object at
- * the end of execution. This is meant to allow the user to avoid typing
- * 'CLI::Destroy()' at the end of their program. The file also defines a static
- * CLIDeleter class, which will be initialized at the beginning of the program
- * and deleted at the end. The destructor destroys the CLI singleton.
- */
-class CLIDeleter
-{
- public:
- CLIDeleter();
- ~CLIDeleter();
-};
-
-//! Declare the deleter.
-static CLIDeleter cliDeleter;
-
-}; // namespace io
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli_deleter.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_deleter.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,48 @@
+/**
+ * @file cli_deleter.hpp
+ * @author Ryan Curtin
+ *
+ * Definition of the CLIDeleter() class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_UTIL_CLI_DELETER_HPP
+#define __MLPACK_CORE_UTIL_CLI_DELETER_HPP
+
+namespace mlpack {
+namespace util {
+
+/**
+ * Extremely simple class whose only job is to delete the existing CLI object at
+ * the end of execution. This is meant to allow the user to avoid typing
+ * 'CLI::Destroy()' at the end of their program. The file also defines a static
+ * CLIDeleter class, which will be initialized at the beginning of the program
+ * and deleted at the end. The destructor destroys the CLI singleton.
+ */
+class CLIDeleter
+{
+ public:
+ CLIDeleter();
+ ~CLIDeleter();
+};
+
+//! Declare the deleter.
+static CLIDeleter cliDeleter;
+
+}; // namespace util
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,127 +0,0 @@
-/**
- * @file cli_impl.hpp
- * @author Matthew Amidon
- *
- * Implementation of templated functions of the CLI class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_UTIL_CLI_IMPL_HPP
-#define __MLPACK_CORE_UTIL_CLI_IMPL_HPP
-
-// In case it has not already been included.
-#include "cli.hpp"
-
-// Include option.hpp here because it requires CLI but is also templated.
-#include "option.hpp"
-
-namespace mlpack {
-
-/**
- * @brief Adds a parameter to CLI, making it accessibile via GetParam &
- * CheckValue.
- *
- * @tparam T The type of the parameter.
- * @param identifier The name of the parameter, eg foo in bar/foo.
- * @param description A string description of the parameter.
- * @param parent The name of the parent of the parameter,
- * eg bar/foo in bar/foo/buzz.
- * @param required If required, the program will refuse to run
- * unless the parameter is specified.
- */
-template<typename T>
-void CLI::Add(const std::string& path,
- const std::string& description,
- const std::string& alias,
- bool required)
-{
-
- po::options_description& desc = CLI::GetSingleton().desc;
- // Must make use of boost syntax here.
- std::string progOptId = alias.length() ? path + "," + alias : path;
-
- // Add the alias, if necessary
- AddAlias(alias, path);
-
- // Add the option to boost program_options.
- desc.add_options()
- (progOptId.c_str(), po::value<T>(), description.c_str());
-
- // Make sure the appropriate metadata is inserted into gmap.
- gmap_t& gmap = GetSingleton().globalValues;
-
- ParamData data;
- T tmp = T();
-
- data.desc = description;
- data.name = path;
- data.tname = TYPENAME(T);
- data.value = boost::any(tmp);
- data.wasPassed = false;
-
- gmap[path] = data;
-
- // If the option is required, add it to the required options list.
- if (required)
- GetSingleton().requiredOptions.push_front(path);
-}
-
-
-/**
- * @brief Returns the value of the specified parameter.
- * If the parameter is unspecified, an undefined but
- * more or less valid value is returned.
- *
- * @tparam T The type of the parameter.
- * @param identifier The full pathname of the parameter.
- *
- * @return The value of the parameter. Use CLI::CheckValue to determine if it's
- * valid.
- */
-template<typename T>
-T& CLI::GetParam(const std::string& identifier)
-{
- // Used to ensure we have a valid value.
- T tmp = T();
-
- // Used to index into the globalValues map.
- std::string key = std::string(identifier);
- gmap_t& gmap = GetSingleton().globalValues;
-
- // Now check if we have an alias.
- amap_t& amap = GetSingleton().aliasValues;
- if (amap.count(key))
- key = amap[key];
-
- // What if we don't actually have any value?
- if (!gmap.count(key))
- {
- gmap[key] = ParamData();
- gmap[key].value = boost::any(tmp);
- *boost::any_cast<T>(&gmap[key].value) = tmp;
- }
-
- // What if we have meta-data, but no data?
- boost::any val = gmap[key].value;
- if (val.empty())
- gmap[key].value = boost::any(tmp);
-
- return *boost::any_cast<T>(&gmap[key].value);
-}
-
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/cli_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/cli_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,127 @@
+/**
+ * @file cli_impl.hpp
+ * @author Matthew Amidon
+ *
+ * Implementation of templated functions of the CLI class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_UTIL_CLI_IMPL_HPP
+#define __MLPACK_CORE_UTIL_CLI_IMPL_HPP
+
+// In case it has not already been included.
+#include "cli.hpp"
+
+// Include option.hpp here because it requires CLI but is also templated.
+#include "option.hpp"
+
+namespace mlpack {
+
+/**
+ * @brief Adds a parameter to CLI, making it accessible via GetParam &
+ * CheckValue.
+ *
+ * @tparam T The type of the parameter.
+ * @param path The name of the parameter, eg foo in bar/foo.
+ * @param description A string description of the parameter.
+ * @param alias Short (one-character) alias for the parameter,
+ * or the empty string for no alias.
+ * @param required If required, the program will refuse to run
+ * unless the parameter is specified.
+ */
+template<typename T>
+void CLI::Add(const std::string& path,
+ const std::string& description,
+ const std::string& alias,
+ bool required)
+{
+
+ po::options_description& desc = CLI::GetSingleton().desc;
+ // Must make use of boost syntax here.
+ std::string progOptId = alias.length() ? path + "," + alias : path;
+
+ // Add the alias, if necessary
+ AddAlias(alias, path);
+
+ // Add the option to boost program_options.
+ desc.add_options()
+ (progOptId.c_str(), po::value<T>(), description.c_str());
+
+ // Make sure the appropriate metadata is inserted into gmap.
+ gmap_t& gmap = GetSingleton().globalValues;
+
+ ParamData data;
+ T tmp = T();
+
+ data.desc = description;
+ data.name = path;
+ data.tname = TYPENAME(T);
+ data.value = boost::any(tmp);
+ data.wasPassed = false;
+
+ gmap[path] = data;
+
+ // If the option is required, add it to the required options list.
+ if (required)
+ GetSingleton().requiredOptions.push_front(path);
+}
+
+
+/**
+ * @brief Returns the value of the specified parameter.
+ * If the parameter is unspecified, an undefined but
+ * more or less valid value is returned.
+ *
+ * @tparam T The type of the parameter.
+ * @param identifier The full pathname of the parameter.
+ *
+ * @return The value of the parameter. Use CLI::CheckValue to determine if it's
+ * valid.
+ */
+template<typename T>
+T& CLI::GetParam(const std::string& identifier)
+{
+ // Used to ensure we have a valid value.
+ T tmp = T();
+
+ // Used to index into the globalValues map.
+ std::string key = std::string(identifier);
+ gmap_t& gmap = GetSingleton().globalValues;
+
+ // Now check if we have an alias.
+ amap_t& amap = GetSingleton().aliasValues;
+ if (amap.count(key))
+ key = amap[key];
+
+ // What if we don't actually have any value?
+ if (!gmap.count(key))
+ {
+ gmap[key] = ParamData();
+ gmap[key].value = boost::any(tmp);
+ *boost::any_cast<T>(&gmap[key].value) = tmp;
+ }
+
+ // What if we have meta-data, but no data?
+ boost::any val = gmap[key].value;
+ if (val.empty())
+ gmap[key].value = boost::any(tmp);
+
+ return *boost::any_cast<T>(&gmap[key].value);
+}
+
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/log.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,146 +0,0 @@
-/**
- * @file log.cpp
- * @author Matthew Amidon
- *
- * Implementation of the Log class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef _WIN32
- #include <cxxabi.h>
- #include <execinfo.h>
-#endif
-
-#include "log.hpp"
-
-// Color code escape sequences -- but not on Windows.
-#ifndef _WIN32
- #define BASH_RED "\033[0;31m"
- #define BASH_GREEN "\033[0;32m"
- #define BASH_YELLOW "\033[0;33m"
- #define BASH_CYAN "\033[0;36m"
- #define BASH_CLEAR "\033[0m"
-#else
- #define BASH_RED ""
- #define BASH_GREEN ""
- #define BASH_YELLOW ""
- #define BASH_CYAN ""
- #define BASH_CLEAR ""
-#endif
-
-using namespace mlpack;
-using namespace mlpack::util;
-
-// Only output debugging output if in debug mode.
-#ifdef DEBUG
-PrefixedOutStream Log::Debug = PrefixedOutStream(std::cout,
- BASH_CYAN "[DEBUG] " BASH_CLEAR);
-#else
-NullOutStream Log::Debug = NullOutStream();
-#endif
-
-PrefixedOutStream Log::Info = PrefixedOutStream(std::cout,
- BASH_GREEN "[INFO ] " BASH_CLEAR, true /* unless --verbose */, false);
-PrefixedOutStream Log::Warn = PrefixedOutStream(std::cout,
- BASH_YELLOW "[WARN ] " BASH_CLEAR, false, false);
-PrefixedOutStream Log::Fatal = PrefixedOutStream(std::cerr,
- BASH_RED "[FATAL] " BASH_CLEAR, false, true /* fatal */);
-
-std::ostream& Log::cout = std::cout;
-
-// Only do anything for Assert() if in debugging mode.
-#ifdef DEBUG
-void Log::Assert(bool condition, const std::string& message)
-{
- if (!condition)
- {
-#ifndef _WIN32
- void* array[25];
- size_t size = backtrace (array, sizeof(array)/sizeof(void*));
- char** messages = backtrace_symbols(array, size);
-
- // skip first stack frame (points here)
- for (size_t i = 1; i < size && messages != NULL; ++i)
- {
- char *mangledName = 0, *offsetBegin = 0, *offsetEnd = 0;
-
- // find parantheses and +address offset surrounding mangled name
- for (char *p = messages[i]; *p; ++p)
- {
- if (*p == '(')
- {
- mangledName = p;
- }
- else if (*p == '+')
- {
- offsetBegin = p;
- }
- else if (*p == ')')
- {
- offsetEnd = p;
- break;
- }
- }
-
- // if the line could be processed, attempt to demangle the symbol
- if (mangledName && offsetBegin && offsetEnd &&
- mangledName < offsetBegin)
- {
- *mangledName++ = '\0';
- *offsetBegin++ = '\0';
- *offsetEnd++ = '\0';
-
- int status;
- char* realName = abi::__cxa_demangle(mangledName, 0, 0, &status);
-
- // if demangling is successful, output the demangled function name
- if (status == 0)
- {
- Log::Debug << "[bt]: (" << i << ") " << messages[i] << " : "
- << realName << "+" << offsetBegin << offsetEnd
- << std::endl;
-
- }
- // otherwise, output the mangled function name
- else
- {
- Log::Debug << "[bt]: (" << i << ") " << messages[i] << " : "
- << mangledName << "+" << offsetBegin << offsetEnd
- << std::endl;
- }
- free(realName);
- }
- // otherwise, print the whole line
- else
- {
- Log::Debug << "[bt]: (" << i << ") " << messages[i] << std::endl;
- }
- }
-#endif
- Log::Debug << message << std::endl;
-
-#ifndef _WIN32
- free(messages);
-#endif
-
- //backtrace_symbols_fd (array, size, 2);
- exit(1);
- }
-}
-#else
-void Log::Assert(bool /* condition */, const std::string& /* message */)
-{ }
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/log.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,146 @@
+/**
+ * @file log.cpp
+ * @author Matthew Amidon
+ *
+ * Implementation of the Log class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef _WIN32
+ #include <cxxabi.h>
+ #include <execinfo.h>
+#endif
+
+#include "log.hpp"
+
+// Color code escape sequences -- but not on Windows.
+#ifndef _WIN32
+ #define BASH_RED "\033[0;31m"
+ #define BASH_GREEN "\033[0;32m"
+ #define BASH_YELLOW "\033[0;33m"
+ #define BASH_CYAN "\033[0;36m"
+ #define BASH_CLEAR "\033[0m"
+#else
+ #define BASH_RED ""
+ #define BASH_GREEN ""
+ #define BASH_YELLOW ""
+ #define BASH_CYAN ""
+ #define BASH_CLEAR ""
+#endif
+
+using namespace mlpack;
+using namespace mlpack::util;
+
+// Only output debugging output if in debug mode.
+#ifdef DEBUG
+PrefixedOutStream Log::Debug = PrefixedOutStream(std::cout,
+ BASH_CYAN "[DEBUG] " BASH_CLEAR);
+#else
+NullOutStream Log::Debug = NullOutStream();
+#endif
+
+PrefixedOutStream Log::Info = PrefixedOutStream(std::cout,
+ BASH_GREEN "[INFO ] " BASH_CLEAR, true /* unless --verbose */, false);
+PrefixedOutStream Log::Warn = PrefixedOutStream(std::cout,
+ BASH_YELLOW "[WARN ] " BASH_CLEAR, false, false);
+PrefixedOutStream Log::Fatal = PrefixedOutStream(std::cerr,
+ BASH_RED "[FATAL] " BASH_CLEAR, false, true /* fatal */);
+
+std::ostream& Log::cout = std::cout;
+
+// Only do anything for Assert() if in debugging mode.
+#ifdef DEBUG
+void Log::Assert(bool condition, const std::string& message)
+{
+ if (!condition)
+ {
+#ifndef _WIN32
+ void* array[25];
+ size_t size = backtrace (array, sizeof(array)/sizeof(void*));
+ char** messages = backtrace_symbols(array, size);
+
+ // skip first stack frame (points here)
+ for (size_t i = 1; i < size && messages != NULL; ++i)
+ {
+ char *mangledName = 0, *offsetBegin = 0, *offsetEnd = 0;
+
+ // find parentheses and +address offset surrounding mangled name
+ for (char *p = messages[i]; *p; ++p)
+ {
+ if (*p == '(')
+ {
+ mangledName = p;
+ }
+ else if (*p == '+')
+ {
+ offsetBegin = p;
+ }
+ else if (*p == ')')
+ {
+ offsetEnd = p;
+ break;
+ }
+ }
+
+ // if the line could be processed, attempt to demangle the symbol
+ if (mangledName && offsetBegin && offsetEnd &&
+ mangledName < offsetBegin)
+ {
+ *mangledName++ = '\0';
+ *offsetBegin++ = '\0';
+ *offsetEnd++ = '\0';
+
+ int status;
+ char* realName = abi::__cxa_demangle(mangledName, 0, 0, &status);
+
+ // if demangling is successful, output the demangled function name
+ if (status == 0)
+ {
+ Log::Debug << "[bt]: (" << i << ") " << messages[i] << " : "
+ << realName << "+" << offsetBegin << offsetEnd
+ << std::endl;
+
+ }
+ // otherwise, output the mangled function name
+ else
+ {
+ Log::Debug << "[bt]: (" << i << ") " << messages[i] << " : "
+ << mangledName << "+" << offsetBegin << offsetEnd
+ << std::endl;
+ }
+ free(realName);
+ }
+ // otherwise, print the whole line
+ else
+ {
+ Log::Debug << "[bt]: (" << i << ") " << messages[i] << std::endl;
+ }
+ }
+#endif
+ Log::Debug << message << std::endl;
+
+#ifndef _WIN32
+ free(messages);
+#endif
+
+ //backtrace_symbols_fd (array, size, 2);
+ exit(1);
+ }
+}
+#else
+void Log::Assert(bool /* condition */, const std::string& /* message */)
+{ }
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/log.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,103 +0,0 @@
-/**
- * @file log.hpp
- * @author Matthew Amidon
- *
- * Definition of the Log class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_IO_LOG_HPP
-#define __MLPACK_CORE_IO_LOG_HPP
-
-#include <string>
-
-#include "prefixedoutstream.hpp"
-#include "nulloutstream.hpp"
-
-namespace mlpack {
-
-/**
- * Provides a convenient way to give formatted output.
- *
- * The Log class has four members which can be used in the same way ostreams can
- * be used:
- *
- * - Log::Debug
- * - Log::Info
- * - Log::Warn
- * - Log::Fatal
- *
- * Each of these will prefix a tag to the output (for easy filtering), and the
- * fatal output will terminate the program when a newline is encountered. An
- * example is given below.
- *
- * @code
- * Log::Info << "Checking a condition." << std::endl;
- * if (!someCondition())
- * Log::Warn << "someCondition() is not satisfied!" << std::endl;
- * Log::Info << "Checking an important condition." << std::endl;
- * if (!someImportantCondition())
- * {
- * Log::Fatal << "someImportantCondition() is not satisfied! Terminating.";
- * Log::Fatal << std::endl;
- * }
- * @endcode
- *
- * Any messages sent to Log::Debug will not be shown when compiling in non-debug
- * mode. Messages to Log::Info will only be shown when the --verbose flag is
- * given to the program (or rather, the CLI class).
- *
- * @see PrefixedOutStream, NullOutStream, CLI
- */
-class Log
-{
- public:
- /**
- * Checks if the specified condition is true.
- * If not, halts program execution and prints a custom error message.
- * Does nothing in non-debug mode.
- */
- static void Assert(bool condition,
- const std::string& message = "Assert Failed.");
-
-
- // We only use PrefixedOutStream if the program is compiled with debug
- // symbols.
-#ifdef DEBUG
- //! Prints debug output with the appropriate tag: [DEBUG].
- static util::PrefixedOutStream Debug;
-#else
- //! Dumps debug output into the bit nether regions.
- static util::NullOutStream Debug;
-#endif
-
- //! Prints informational messages if --verbose is specified, prefixed with
- //! [INFO ].
- static util::PrefixedOutStream Info;
-
- //! Prints warning messages prefixed with [WARN ].
- static util::PrefixedOutStream Warn;
-
- //! Prints fatal messages prefixed with [FATAL], then terminates the program.
- static util::PrefixedOutStream Fatal;
-
- //! Reference to cout, if necessary.
- static std::ostream& cout;
-};
-
-}; //namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/log.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/log.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,103 @@
+/**
+ * @file log.hpp
+ * @author Matthew Amidon
+ *
+ * Definition of the Log class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_IO_LOG_HPP
+#define __MLPACK_CORE_IO_LOG_HPP
+
+#include <string>
+
+#include "prefixedoutstream.hpp"
+#include "nulloutstream.hpp"
+
+namespace mlpack {
+
+/**
+ * Provides a convenient way to give formatted output.
+ *
+ * The Log class has four members which can be used in the same way ostreams can
+ * be used:
+ *
+ * - Log::Debug
+ * - Log::Info
+ * - Log::Warn
+ * - Log::Fatal
+ *
+ * Each of these will prefix a tag to the output (for easy filtering), and the
+ * fatal output will terminate the program when a newline is encountered. An
+ * example is given below.
+ *
+ * @code
+ * Log::Info << "Checking a condition." << std::endl;
+ * if (!someCondition())
+ * Log::Warn << "someCondition() is not satisfied!" << std::endl;
+ * Log::Info << "Checking an important condition." << std::endl;
+ * if (!someImportantCondition())
+ * {
+ * Log::Fatal << "someImportantCondition() is not satisfied! Terminating.";
+ * Log::Fatal << std::endl;
+ * }
+ * @endcode
+ *
+ * Any messages sent to Log::Debug will not be shown when compiling in non-debug
+ * mode. Messages to Log::Info will only be shown when the --verbose flag is
+ * given to the program (or rather, the CLI class).
+ *
+ * @see PrefixedOutStream, NullOutStream, CLI
+ */
+class Log
+{
+ public:
+ /**
+ * Checks if the specified condition is true.
+ * If not, halts program execution and prints a custom error message.
+ * Does nothing in non-debug mode.
+ */
+ static void Assert(bool condition,
+ const std::string& message = "Assert Failed.");
+
+
+ // We only use PrefixedOutStream if the program is compiled with debug
+ // symbols.
+#ifdef DEBUG
+ //! Prints debug output with the appropriate tag: [DEBUG].
+ static util::PrefixedOutStream Debug;
+#else
+ //! Dumps debug output into the bit nether regions.
+ static util::NullOutStream Debug;
+#endif
+
+ //! Prints informational messages if --verbose is specified, prefixed with
+ //! [INFO ].
+ static util::PrefixedOutStream Info;
+
+ //! Prints warning messages prefixed with [WARN ].
+ static util::PrefixedOutStream Warn;
+
+ //! Prints fatal messages prefixed with [FATAL], then terminates the program.
+ static util::PrefixedOutStream Fatal;
+
+ //! Reference to cout, if necessary.
+ static std::ostream& cout;
+};
+
+}; //namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/nulloutstream.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/nulloutstream.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/nulloutstream.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,103 +0,0 @@
-/**
- * @file nulloutstream.hpp
- * @author Ryan Curtin
- * @author Matthew Amidon
- *
- * Definition of the NullOutStream class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_UTIL_NULLOUTSTREAM_HPP
-#define __MLPACK_CORE_UTIL_NULLOUTSTREAM_HPP
-
-#include <iostream>
-#include <streambuf>
-#include <string>
-
-namespace mlpack {
-namespace util {
-
-/**
- * Used for Log::Debug when not compiled with debugging symbols. This class
- * does nothing and should be optimized out entirely by the compiler.
- */
-class NullOutStream
-{
- public:
- /**
- * Does nothing.
- */
- NullOutStream() { }
-
- /**
- * Does nothing.
- */
- NullOutStream(const NullOutStream& /* other */) { }
-
- /*
- We use (void) paramName in order to avoid the warning generated by
- -Wextra. For some currently unknown reason, simply deleting the
- parameter name (aka, outperator<<(bool) {...}) causes a compilation
- error (with -Werror off) for only this class.
- */
-
- //! Does nothing.
- NullOutStream& operator<<(bool val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(short val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(unsigned short val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(int val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(unsigned int val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(long val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(unsigned long val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(float val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(double val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(long double val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(void* val) { (void) val; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(const char* str) { (void) str; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(std::string& str) { (void) str; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(std::streambuf* sb) { (void) sb; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(std::ostream& (*pf) (std::ostream&))
- { (void) pf; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(std::ios& (*pf) (std::ios&)) { (void) pf; return *this; }
- //! Does nothing.
- NullOutStream& operator<<(std::ios_base& (*pf) (std::ios_base&))
- { (void) pf; return *this; }
-
- //! Does nothing.
- template<typename T>
- NullOutStream& operator<<(T& s)
- { (void) s; return *this; }
-};
-
-} // namespace util
-} // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/nulloutstream.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/nulloutstream.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/nulloutstream.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/nulloutstream.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,103 @@
+/**
+ * @file nulloutstream.hpp
+ * @author Ryan Curtin
+ * @author Matthew Amidon
+ *
+ * Definition of the NullOutStream class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_UTIL_NULLOUTSTREAM_HPP
+#define __MLPACK_CORE_UTIL_NULLOUTSTREAM_HPP
+
+#include <iostream>
+#include <streambuf>
+#include <string>
+
+namespace mlpack {
+namespace util {
+
+/**
+ * Used for Log::Debug when not compiled with debugging symbols. This class
+ * does nothing and should be optimized out entirely by the compiler.
+ */
+class NullOutStream
+{
+ public:
+ /**
+ * Does nothing.
+ */
+ NullOutStream() { }
+
+ /**
+ * Does nothing.
+ */
+ NullOutStream(const NullOutStream& /* other */) { }
+
+ /*
+ We use (void) paramName in order to avoid the warning generated by
+ -Wextra. For some currently unknown reason, simply deleting the
+ parameter name (aka, outperator<<(bool) {...}) causes a compilation
+ error (with -Werror off) for only this class.
+ */
+
+ //! Does nothing.
+ NullOutStream& operator<<(bool val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(short val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(unsigned short val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(int val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(unsigned int val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(long val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(unsigned long val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(float val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(double val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(long double val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(void* val) { (void) val; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(const char* str) { (void) str; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(std::string& str) { (void) str; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(std::streambuf* sb) { (void) sb; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(std::ostream& (*pf) (std::ostream&))
+ { (void) pf; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(std::ios& (*pf) (std::ios&)) { (void) pf; return *this; }
+ //! Does nothing.
+ NullOutStream& operator<<(std::ios_base& (*pf) (std::ios_base&))
+ { (void) pf; return *this; }
+
+ //! Does nothing.
+ template<typename T>
+ NullOutStream& operator<<(T& s)
+ { (void) s; return *this; }
+};
+
+} // namespace util
+} // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/option.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,49 +0,0 @@
-/**
- * @file option.cpp
- * @author Ryan Curtin
- *
- * Implementation of the ProgramDoc class. The class registers itself with CLI
- * when constructed.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "cli.hpp"
-#include "option.hpp"
-
-#include <string>
-
-using namespace mlpack;
-using namespace mlpack::util;
-using namespace std;
-
-/**
- * Construct a ProgramDoc object. When constructed, it will register itself
- * with CLI. A fatal error will be thrown if more than one is constructed.
- *
- * @param programName Short string representing the name of the program.
- * @param documentation Long string containing documentation on how to use the
- * program and what it is. No newline characters are necessary; this is
- * taken care of by CLI later.
- * @param defaultModule Name of the default module.
- */
-ProgramDoc::ProgramDoc(const std::string& programName,
- const std::string& documentation) :
- programName(programName),
- documentation(documentation)
-{
- // Register this with CLI.
- CLI::RegisterProgramDoc(this);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/option.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,49 @@
+/**
+ * @file option.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the ProgramDoc class. The class registers itself with CLI
+ * when constructed.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "cli.hpp"
+#include "option.hpp"
+
+#include <string>
+
+using namespace mlpack;
+using namespace mlpack::util;
+using namespace std;
+
+/**
+ * Construct a ProgramDoc object. When constructed, it will register itself
+ * with CLI. A fatal error will be thrown if more than one is constructed.
+ *
+ * @param programName Short string representing the name of the program.
+ * @param documentation Long string containing documentation on how to use the
+ * program and what it is. No newline characters are necessary; this is
+ * taken care of by CLI later.
+ * @param defaultModule Name of the default module.
+ */
+ProgramDoc::ProgramDoc(const std::string& programName,
+ const std::string& documentation) :
+ programName(programName),
+ documentation(documentation)
+{
+ // Register this with CLI.
+ CLI::RegisterProgramDoc(this);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/option.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,117 +0,0 @@
-/**
- * @file option.hpp
- * @author Matthew Amidon
- *
- * Definition of the Option class, which is used to define parameters which are
- * used by CLI. The ProgramDoc class also resides here.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_IO_OPTION_HPP
-#define __MLPACK_CORE_IO_OPTION_HPP
-
-#include <string>
-
-#include "cli.hpp"
-
-namespace mlpack {
-namespace util {
-
-/**
- * A static object whose constructor registers a parameter with the CLI class.
- * This should not be used outside of CLI itself, and you should use the
- * PARAM_FLAG(), PARAM_DOUBLE(), PARAM_INT(), PARAM_STRING(), or other similar
- * macros to declare these objects instead of declaring them directly.
- *
- * @see core/io/cli.hpp, mlpack::CLI
- */
-template<typename N>
-class Option
-{
- public:
- /**
- * Construct an Option object. When constructed, it will register
- * itself with CLI.
- *
- * @param ignoreTemplate Whether or not the template type matters for this
- * option. Essentially differs options with no value (flags) from those
- * that do, and thus require a type.
- * @param defaultValue Default value this parameter will be initialized to.
- * @param identifier The name of the option (no dashes in front; for --help,
- * we would pass "help").
- * @param description A short string describing the option.
- * @param parent Full pathname of the parent module that "owns" this option.
- * The default is the root node (an empty string).
- * @param required Whether or not the option is required at runtime.
- */
- Option(bool ignoreTemplate,
- N defaultValue,
- const std::string& identifier,
- const std::string& description,
- const std::string& parent = std::string(""),
- bool required = false);
-
- /**
- * Constructs an Option object. When constructed, it will register a flag
- * with CLI.
- *
- * @param identifier The name of the option (no dashes in front); for --help
- * we would pass "help".
- * @param description A short string describing the option.
- * @param parent Full pathname of the parent module that "owns" this option.
- * The default is the root node (an empty string).
- */
- Option(const std::string& identifier,
- const std::string& description,
- const std::string& parent = std::string(""));
-};
-
-/**
- * A static object whose constructor registers program documentation with the
- * CLI class. This should not be used outside of CLI itself, and you should use
- * the PROGRAM_INFO() macro to declare these objects. Only one ProgramDoc
- * object should ever exist.
- *
- * @see core/io/cli.hpp, mlpack::CLI
- */
-class ProgramDoc
-{
- public:
- /**
- * Construct a ProgramDoc object. When constructed, it will register itself
- * with CLI.
- *
- * @param programName Short string representing the name of the program.
- * @param documentation Long string containing documentation on how to use the
- * program and what it is. No newline characters are necessary; this is
- * taken care of by CLI later.
- */
- ProgramDoc(const std::string& programName,
- const std::string& documentation);
-
- //! The name of the program.
- std::string programName;
- //! Documentation for what the program does.
- std::string documentation;
-};
-
-}; // namespace util
-}; // namespace mlpack
-
-// For implementations of templated functions
-#include "option_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/option.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,117 @@
+/**
+ * @file option.hpp
+ * @author Matthew Amidon
+ *
+ * Definition of the Option class, which is used to define parameters which are
+ * used by CLI. The ProgramDoc class also resides here.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_IO_OPTION_HPP
+#define __MLPACK_CORE_IO_OPTION_HPP
+
+#include <string>
+
+#include "cli.hpp"
+
+namespace mlpack {
+namespace util {
+
+/**
+ * A static object whose constructor registers a parameter with the CLI class.
+ * This should not be used outside of CLI itself, and you should use the
+ * PARAM_FLAG(), PARAM_DOUBLE(), PARAM_INT(), PARAM_STRING(), or other similar
+ * macros to declare these objects instead of declaring them directly.
+ *
+ * @see core/io/cli.hpp, mlpack::CLI
+ */
+template<typename N>
+class Option
+{
+ public:
+ /**
+ * Construct an Option object. When constructed, it will register
+ * itself with CLI.
+ *
+ * @param ignoreTemplate Whether or not the template type matters for this
+ * option. Essentially differs options with no value (flags) from those
+ * that do, and thus require a type.
+ * @param defaultValue Default value this parameter will be initialized to.
+ * @param identifier The name of the option (no dashes in front; for --help,
+ * we would pass "help").
+ * @param description A short string describing the option.
+ * @param parent Full pathname of the parent module that "owns" this option.
+ * The default is the root node (an empty string).
+ * @param required Whether or not the option is required at runtime.
+ */
+ Option(bool ignoreTemplate,
+ N defaultValue,
+ const std::string& identifier,
+ const std::string& description,
+ const std::string& parent = std::string(""),
+ bool required = false);
+
+ /**
+ * Constructs an Option object. When constructed, it will register a flag
+ * with CLI.
+ *
+ * @param identifier The name of the option (no dashes in front); for --help
+ * we would pass "help".
+ * @param description A short string describing the option.
+ * @param parent Full pathname of the parent module that "owns" this option.
+ * The default is the root node (an empty string).
+ */
+ Option(const std::string& identifier,
+ const std::string& description,
+ const std::string& parent = std::string(""));
+};
+
+/**
+ * A static object whose constructor registers program documentation with the
+ * CLI class. This should not be used outside of CLI itself, and you should use
+ * the PROGRAM_INFO() macro to declare these objects. Only one ProgramDoc
+ * object should ever exist.
+ *
+ * @see core/io/cli.hpp, mlpack::CLI
+ */
+class ProgramDoc
+{
+ public:
+ /**
+ * Construct a ProgramDoc object. When constructed, it will register itself
+ * with CLI.
+ *
+ * @param programName Short string representing the name of the program.
+ * @param documentation Long string containing documentation on how to use the
+ * program and what it is. No newline characters are necessary; this is
+ * taken care of by CLI later.
+ */
+ ProgramDoc(const std::string& programName,
+ const std::string& documentation);
+
+ //! The name of the program.
+ std::string programName;
+ //! Documentation for what the program does.
+ std::string documentation;
+};
+
+}; // namespace util
+}; // namespace mlpack
+
+// For implementations of templated functions
+#include "option_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/option_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,68 +0,0 @@
-/**
- * @file option_impl.hpp
- * @author Matthew Amidon
- *
- * Implementation of template functions for the Option class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_IO_OPTION_IMPL_HPP
-#define __MLPACK_CORE_IO_OPTION_IMPL_HPP
-
-// Just in case it has not been included.
-#include "option.hpp"
-
-namespace mlpack {
-namespace util {
-
-/**
- * Registers a parameter with CLI.
- */
-template<typename N>
-Option<N>::Option(bool ignoreTemplate,
- N defaultValue,
- const std::string& identifier,
- const std::string& description,
- const std::string& alias,
- bool required)
-{
- if (ignoreTemplate)
- {
- CLI::Add(identifier, description, alias, required);
- }
- else
- {
- CLI::Add<N>(identifier, description, alias, required);
- CLI::GetParam<N>(identifier) = defaultValue;
- }
-}
-
-
-/**
- * Registers a flag parameter with CLI.
- */
-template<typename N>
-Option<N>::Option(const std::string& identifier,
- const std::string& description,
- const std::string& alias)
-{
- CLI::AddFlag(identifier, description, alias);
-}
-
-}; // namespace util
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/option_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/option_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,68 @@
+/**
+ * @file option_impl.hpp
+ * @author Matthew Amidon
+ *
+ * Implementation of template functions for the Option class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_IO_OPTION_IMPL_HPP
+#define __MLPACK_CORE_IO_OPTION_IMPL_HPP
+
+// Just in case it has not been included.
+#include "option.hpp"
+
+namespace mlpack {
+namespace util {
+
+/**
+ * Registers a parameter with CLI.
+ */
+template<typename N>
+Option<N>::Option(bool ignoreTemplate,
+ N defaultValue,
+ const std::string& identifier,
+ const std::string& description,
+ const std::string& alias,
+ bool required)
+{
+ if (ignoreTemplate)
+ {
+ CLI::Add(identifier, description, alias, required);
+ }
+ else
+ {
+ CLI::Add<N>(identifier, description, alias, required);
+ CLI::GetParam<N>(identifier) = defaultValue;
+ }
+}
+
+
+/**
+ * Registers a flag parameter with CLI.
+ */
+template<typename N>
+Option<N>::Option(const std::string& identifier,
+ const std::string& description,
+ const std::string& alias)
+{
+ CLI::AddFlag(identifier, description, alias);
+}
+
+}; // namespace util
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/prefixedoutstream.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,142 +0,0 @@
-/**
- * @file prefixedoutstream.cpp
- * @author Ryan Curtin
- * @author Matthew Amidon
- *
- * Implementation of PrefixedOutStream methods.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <string>
-#include <iostream>
-#include <streambuf>
-#include <string.h>
-#include <stdlib.h>
-
-#include "prefixedoutstream.hpp"
-
-using namespace mlpack::util;
-
-/**
- * These are all necessary because gcc's template mechanism does not seem smart
- * enough to figure out what I want to pass into operator<< without these. That
- * may not be the actual case, but it works when these is here.
- */
-
-PrefixedOutStream& PrefixedOutStream::operator<<(bool val)
-{
- BaseLogic<bool>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(short val)
-{
- BaseLogic<short>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(unsigned short val)
-{
- BaseLogic<unsigned short>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(int val)
-{
- BaseLogic<int>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(unsigned int val)
-{
- BaseLogic<unsigned int>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(long val)
-{
- BaseLogic<long>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(unsigned long val)
-{
- BaseLogic<unsigned long>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(float val)
-{
- BaseLogic<float>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(double val)
-{
- BaseLogic<double>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(long double val)
-{
- BaseLogic<long double>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(void* val)
-{
- BaseLogic<void*>(val);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(const char* str)
-{
- BaseLogic<const char*>(str);
- return *this;
-}
-
-
-PrefixedOutStream& PrefixedOutStream::operator<<(std::string& str)
-{
- BaseLogic<std::string>(str);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(std::streambuf* sb)
-{
- BaseLogic<std::streambuf*>(sb);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(
- std::ostream& (*pf)(std::ostream&))
-{
- BaseLogic<std::ostream& (*)(std::ostream&)>(pf);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(std::ios& (*pf)(std::ios&))
-{
- BaseLogic<std::ios& (*)(std::ios&)>(pf);
- return *this;
-}
-
-PrefixedOutStream& PrefixedOutStream::operator<<(
- std::ios_base& (*pf) (std::ios_base&))
-{
- BaseLogic<std::ios_base& (*)(std::ios_base&)>(pf);
- return *this;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/prefixedoutstream.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,142 @@
+/**
+ * @file prefixedoutstream.cpp
+ * @author Ryan Curtin
+ * @author Matthew Amidon
+ *
+ * Implementation of PrefixedOutStream methods.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <string>
+#include <iostream>
+#include <streambuf>
+#include <string.h>
+#include <stdlib.h>
+
+#include "prefixedoutstream.hpp"
+
+using namespace mlpack::util;
+
+/**
+ * These are all necessary because gcc's template mechanism does not seem smart
+ * enough to figure out what I want to pass into operator<< without these. That
+ * may not be the actual case, but it works when these is here.
+ */
+
+PrefixedOutStream& PrefixedOutStream::operator<<(bool val)
+{
+ BaseLogic<bool>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(short val)
+{
+ BaseLogic<short>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(unsigned short val)
+{
+ BaseLogic<unsigned short>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(int val)
+{
+ BaseLogic<int>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(unsigned int val)
+{
+ BaseLogic<unsigned int>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(long val)
+{
+ BaseLogic<long>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(unsigned long val)
+{
+ BaseLogic<unsigned long>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(float val)
+{
+ BaseLogic<float>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(double val)
+{
+ BaseLogic<double>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(long double val)
+{
+ BaseLogic<long double>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(void* val)
+{
+ BaseLogic<void*>(val);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(const char* str)
+{
+ BaseLogic<const char*>(str);
+ return *this;
+}
+
+
+PrefixedOutStream& PrefixedOutStream::operator<<(std::string& str)
+{
+ BaseLogic<std::string>(str);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(std::streambuf* sb)
+{
+ BaseLogic<std::streambuf*>(sb);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(
+ std::ostream& (*pf)(std::ostream&))
+{
+ BaseLogic<std::ostream& (*)(std::ostream&)>(pf);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(std::ios& (*pf)(std::ios&))
+{
+ BaseLogic<std::ios& (*)(std::ios&)>(pf);
+ return *this;
+}
+
+PrefixedOutStream& PrefixedOutStream::operator<<(
+ std::ios_base& (*pf) (std::ios_base&))
+{
+ BaseLogic<std::ios_base& (*)(std::ios_base&)>(pf);
+ return *this;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/prefixedoutstream.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,196 +0,0 @@
-/**
- * @file prefixedoutstream.hpp
- * @author Ryan Curtin
- * @author Matthew Amidon
- *
- * Declaration of the PrefixedOutStream class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_IO_PREFIXEDOUTSTREAM_HPP
-#define __MLPACK_CORE_IO_PREFIXEDOUTSTREAM_HPP
-
-#include <iostream>
-#include <iomanip>
-#include <string>
-#include <streambuf>
-
-#include <boost/lexical_cast.hpp>
-#include <boost/utility/enable_if.hpp>
-#include <boost/type_traits.hpp>
-
-#include <mlpack/core/util/sfinae_utility.hpp>
-#include <mlpack/core/util/string_util.hpp>
-
-namespace mlpack {
-namespace util {
-
-/**
- * Allows us to output to an ostream with a prefix at the beginning of each
- * line, in the same way we would output to cout or cerr. The prefix is
- * specified in the constructor (as well as the destination ostream). A newline
- * must be passed to the stream, and then the prefix will be prepended to the
- * next line. For example,
- *
- * @code
- * PrefixedOutStream outstr(std::cout, "[TEST] ");
- * outstr << "Hello world I like " << 7.5;
- * outstr << "...Continue" << std::endl;
- * outstr << "After the CR\n" << std::endl;
- * @endcode
- *
- * would give, on std::cout,
- *
- * @code
- * [TEST] Hello world I like 7.5...Continue
- * [TEST] After the CR
- * [TEST]
- * @endcode
- *
- * These objects are used for the mlpack::Log levels (DEBUG, INFO, WARN, and
- * FATAL).
- */
-class PrefixedOutStream
-{
- public:
- /**
- * Set up the PrefixedOutStream.
- *
- * @param destination ostream which receives output from this object.
- * @param prefix The prefix to prepend to each line.
- */
- PrefixedOutStream(std::ostream& destination,
- const char* prefix,
- bool ignoreInput = false,
- bool fatal = false) :
- destination(destination),
- ignoreInput(ignoreInput),
- prefix(prefix),
- // We want the first call to operator<< to prefix the prefix so we set
- // carriageReturned to true.
- carriageReturned(true),
- fatal(fatal)
- { /* nothing to do */ }
-
- //! Write a bool to the stream.
- PrefixedOutStream& operator<<(bool val);
- //! Write a short to the stream.
- PrefixedOutStream& operator<<(short val);
- //! Write an unsigned short to the stream.
- PrefixedOutStream& operator<<(unsigned short val);
- //! Write an int to the stream.
- PrefixedOutStream& operator<<(int val);
- //! Write an unsigned int to the stream.
- PrefixedOutStream& operator<<(unsigned int val);
- //! Write a long to the stream.
- PrefixedOutStream& operator<<(long val);
- //! Write an unsigned long to the stream.
- PrefixedOutStream& operator<<(unsigned long val);
- //! Write a float to the stream.
- PrefixedOutStream& operator<<(float val);
- //! Write a double to the stream.
- PrefixedOutStream& operator<<(double val);
- //! Write a long double to the stream.
- PrefixedOutStream& operator<<(long double val);
- //! Write a void pointer to the stream.
- PrefixedOutStream& operator<<(void* val);
- //! Write a character array to the stream.
- PrefixedOutStream& operator<<(const char* str);
- //! Write a string to the stream.
- PrefixedOutStream& operator<<(std::string& str);
- //! Write a streambuf to the stream.
- PrefixedOutStream& operator<<(std::streambuf* sb);
- //! Write an ostream manipulator function to the stream.
- PrefixedOutStream& operator<<(std::ostream& (*pf)(std::ostream&));
- //! Write an ios manipulator function to the stream.
- PrefixedOutStream& operator<<(std::ios& (*pf)(std::ios&));
- //! Write an ios_base manipulator function to the stream.
- PrefixedOutStream& operator<<(std::ios_base& (*pf)(std::ios_base&));
-
- //! Write anything else to the stream.
- template<typename T>
- PrefixedOutStream& operator<<(const T& s);
-
- //! The output stream that all data is to be sent too; example: std::cout.
- std::ostream& destination;
-
- //! Discards input, prints nothing if true.
- bool ignoreInput;
-
- private:
- HAS_MEM_FUNC(ToString, HasToString)
-
- //! This handles forwarding all primitive types transparently
- template<typename T>
- void CallBaseLogic(const T& s,
- typename boost::disable_if<
- boost::is_class<T>
- >::type* = 0);
-
- //! Forward all objects that do not implement a ToString() method
- template<typename T>
- void CallBaseLogic(const T& s,
- typename boost::enable_if<
- boost::is_class<T>
- >::type* = 0,
- typename boost::disable_if<
- HasToString<T, std::string(T::*)() const>
- >::type* = 0);
-
- //! Call ToString() on all objects that implement ToString() before forwarding
- template<typename T>
- void CallBaseLogic(const T& s,
- typename boost::enable_if<
- boost::is_class<T>
- >::type* = 0,
- typename boost::enable_if<
- HasToString<T, std::string(T::*)() const>
- >::type* = 0);
-
- /**
- * @brief Conducts the base logic required in all the operator << overloads.
- * Mostly just a good idea to reduce copy-pasta.
- *
- * @tparam T The type of the data to output.
- * @param val The The data to be output.
- */
- template<typename T>
- void BaseLogic(const T& val);
-
- /**
- * Output the prefix, but only if we need to and if we are allowed to.
- */
- inline void PrefixIfNeeded();
-
- //! Contains the prefix we must prepend to each line.
- std::string prefix;
-
- //! If true, the previous call to operator<< encountered a CR, and a prefix
- //! will be necessary.
- bool carriageReturned;
-
- //! If true, the application will terminate with an error code when a CR is
- //! encountered.
- bool fatal;
-};
-
-}; // namespace util
-}; // namespace mlpack
-
-// Template definitions.
-#include "prefixedoutstream_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/prefixedoutstream.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,196 @@
+/**
+ * @file prefixedoutstream.hpp
+ * @author Ryan Curtin
+ * @author Matthew Amidon
+ *
+ * Declaration of the PrefixedOutStream class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_IO_PREFIXEDOUTSTREAM_HPP
+#define __MLPACK_CORE_IO_PREFIXEDOUTSTREAM_HPP
+
+#include <iostream>
+#include <iomanip>
+#include <string>
+#include <streambuf>
+
+#include <boost/lexical_cast.hpp>
+#include <boost/utility/enable_if.hpp>
+#include <boost/type_traits.hpp>
+
+#include <mlpack/core/util/sfinae_utility.hpp>
+#include <mlpack/core/util/string_util.hpp>
+
+namespace mlpack {
+namespace util {
+
+/**
+ * Allows us to output to an ostream with a prefix at the beginning of each
+ * line, in the same way we would output to cout or cerr. The prefix is
+ * specified in the constructor (as well as the destination ostream). A newline
+ * must be passed to the stream, and then the prefix will be prepended to the
+ * next line. For example,
+ *
+ * @code
+ * PrefixedOutStream outstr(std::cout, "[TEST] ");
+ * outstr << "Hello world I like " << 7.5;
+ * outstr << "...Continue" << std::endl;
+ * outstr << "After the CR\n" << std::endl;
+ * @endcode
+ *
+ * would give, on std::cout,
+ *
+ * @code
+ * [TEST] Hello world I like 7.5...Continue
+ * [TEST] After the CR
+ * [TEST]
+ * @endcode
+ *
+ * These objects are used for the mlpack::Log levels (DEBUG, INFO, WARN, and
+ * FATAL).
+ */
+class PrefixedOutStream
+{
+ public:
+ /**
+ * Set up the PrefixedOutStream.
+ *
+ * @param destination ostream which receives output from this object.
+ * @param prefix The prefix to prepend to each line.
+ */
+ PrefixedOutStream(std::ostream& destination,
+ const char* prefix,
+ bool ignoreInput = false,
+ bool fatal = false) :
+ destination(destination),
+ ignoreInput(ignoreInput),
+ prefix(prefix),
+ // We want the first call to operator<< to prefix the prefix so we set
+ // carriageReturned to true.
+ carriageReturned(true),
+ fatal(fatal)
+ { /* nothing to do */ }
+
+ //! Write a bool to the stream.
+ PrefixedOutStream& operator<<(bool val);
+ //! Write a short to the stream.
+ PrefixedOutStream& operator<<(short val);
+ //! Write an unsigned short to the stream.
+ PrefixedOutStream& operator<<(unsigned short val);
+ //! Write an int to the stream.
+ PrefixedOutStream& operator<<(int val);
+ //! Write an unsigned int to the stream.
+ PrefixedOutStream& operator<<(unsigned int val);
+ //! Write a long to the stream.
+ PrefixedOutStream& operator<<(long val);
+ //! Write an unsigned long to the stream.
+ PrefixedOutStream& operator<<(unsigned long val);
+ //! Write a float to the stream.
+ PrefixedOutStream& operator<<(float val);
+ //! Write a double to the stream.
+ PrefixedOutStream& operator<<(double val);
+ //! Write a long double to the stream.
+ PrefixedOutStream& operator<<(long double val);
+ //! Write a void pointer to the stream.
+ PrefixedOutStream& operator<<(void* val);
+ //! Write a character array to the stream.
+ PrefixedOutStream& operator<<(const char* str);
+ //! Write a string to the stream.
+ PrefixedOutStream& operator<<(std::string& str);
+ //! Write a streambuf to the stream.
+ PrefixedOutStream& operator<<(std::streambuf* sb);
+ //! Write an ostream manipulator function to the stream.
+ PrefixedOutStream& operator<<(std::ostream& (*pf)(std::ostream&));
+ //! Write an ios manipulator function to the stream.
+ PrefixedOutStream& operator<<(std::ios& (*pf)(std::ios&));
+ //! Write an ios_base manipulator function to the stream.
+ PrefixedOutStream& operator<<(std::ios_base& (*pf)(std::ios_base&));
+
+ //! Write anything else to the stream.
+ template<typename T>
+ PrefixedOutStream& operator<<(const T& s);
+
+ //! The output stream that all data is to be sent too; example: std::cout.
+ std::ostream& destination;
+
+ //! Discards input, prints nothing if true.
+ bool ignoreInput;
+
+ private:
+ HAS_MEM_FUNC(ToString, HasToString)
+
+ //! This handles forwarding all primitive types transparently
+ template<typename T>
+ void CallBaseLogic(const T& s,
+ typename boost::disable_if<
+ boost::is_class<T>
+ >::type* = 0);
+
+ //! Forward all objects that do not implement a ToString() method
+ template<typename T>
+ void CallBaseLogic(const T& s,
+ typename boost::enable_if<
+ boost::is_class<T>
+ >::type* = 0,
+ typename boost::disable_if<
+ HasToString<T, std::string(T::*)() const>
+ >::type* = 0);
+
+ //! Call ToString() on all objects that implement ToString() before forwarding
+ template<typename T>
+ void CallBaseLogic(const T& s,
+ typename boost::enable_if<
+ boost::is_class<T>
+ >::type* = 0,
+ typename boost::enable_if<
+ HasToString<T, std::string(T::*)() const>
+ >::type* = 0);
+
+ /**
+ * @brief Conducts the base logic required in all the operator << overloads.
+ * Mostly just a good idea to reduce copy-pasta.
+ *
+ * @tparam T The type of the data to output.
+ * @param val The The data to be output.
+ */
+ template<typename T>
+ void BaseLogic(const T& val);
+
+ /**
+ * Output the prefix, but only if we need to and if we are allowed to.
+ */
+ inline void PrefixIfNeeded();
+
+ //! Contains the prefix we must prepend to each line.
+ std::string prefix;
+
+ //! If true, the previous call to operator<< encountered a CR, and a prefix
+ //! will be necessary.
+ bool carriageReturned;
+
+ //! If true, the application will terminate with an error code when a CR is
+ //! encountered.
+ bool fatal;
+};
+
+}; // namespace util
+}; // namespace mlpack
+
+// Template definitions.
+#include "prefixedoutstream_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/prefixedoutstream_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,167 +0,0 @@
-/**
- * @file prefixedoutstream.hpp
- * @author Ryan Curtin
- * @author Matthew Amidon
- *
- * Implementation of templated PrefixedOutStream member functions.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_IO_PREFIXEDOUTSTREAM_IMPL_HPP
-#define __MLPACK_CORE_IO_PREFIXEDOUTSTREAM_IMPL_HPP
-
-// Just in case it hasn't been included.
-#include "prefixedoutstream.hpp"
-#include <iostream>
-
-namespace mlpack {
-namespace util {
-
-template<typename T>
-PrefixedOutStream& PrefixedOutStream::operator<<(const T& s)
-{
- CallBaseLogic<T>(s);
- return *this;
-}
-
-//! This handles forwarding all primitive types transparently
-template<typename T>
-void PrefixedOutStream::CallBaseLogic(const T& s,
- typename boost::disable_if<
- boost::is_class<T>
- >::type*)
-{
- BaseLogic<T>(s);
-}
-
-// Forward all objects that do not implement a ToString() method transparently
-template<typename T>
-void PrefixedOutStream::CallBaseLogic(const T& s,
- typename boost::enable_if<
- boost::is_class<T>
- >::type*,
- typename boost::disable_if<
- HasToString<T, std::string(T::*)() const>
- >::type*)
-{
- BaseLogic<T>(s);
-}
-
-// Call ToString() on all objects that implement ToString() before forwarding
-template<typename T>
-void PrefixedOutStream::CallBaseLogic(const T& s,
- typename boost::enable_if<
- boost::is_class<T>
- >::type*,
- typename boost::enable_if<
- HasToString<T, std::string(T::*)() const>
- >::type*)
-{
- std::string result = s.ToString();
- BaseLogic<std::string>(result);
-}
-
-template<typename T>
-void PrefixedOutStream::BaseLogic(const T& val)
-{
- // We will use this to track whether or not we need to terminate at the end of
- // this call (only for streams which terminate after a newline).
- bool newlined = false;
- std::string line;
-
- // If we need to, output the prefix.
- PrefixIfNeeded();
-
- std::ostringstream convert;
- convert << val;
-
- if(convert.fail())
- {
- PrefixIfNeeded();
- if (!ignoreInput)
- {
- destination << "Failed lexical_cast<std::string>(T) for output; output"
- " not shown." << std::endl;
- newlined = true;
- }
- }
- else
- {
- line = convert.str();
-
- // If the length of the casted thing was 0, it may have been a stream
- // manipulator, so send it directly to the stream and don't ask questions.
- if (line.length() == 0)
- {
- // The prefix cannot be necessary at this point.
- if (!ignoreInput) // Only if the user wants it.
- destination << val;
-
- return;
- }
-
- // Now, we need to check for newlines in this line. If we find one, output
- // up until the newline, then output the newline and the prefix and continue
- // looking.
- size_t nl;
- size_t pos = 0;
- while ((nl = line.find('\n', pos)) != std::string::npos)
- {
- PrefixIfNeeded();
-
- // Only output if the user wants it.
- if (!ignoreInput)
- {
- destination << line.substr(pos, nl - pos);
- destination << std::endl;
- newlined = true;
- }
-
- carriageReturned = true; // Regardless of whether or not we display it.
-
- pos = nl + 1;
- }
-
- if (pos != line.length()) // We need to display the rest.
- {
- PrefixIfNeeded();
- if (!ignoreInput)
- destination << line.substr(pos);
- }
- }
-
- // If we displayed a newline and we need to terminate afterwards, do that.
- if (fatal && newlined)
- exit(1);
-}
-
-// This is an inline function (that is why it is here and not in .cc).
-void PrefixedOutStream::PrefixIfNeeded()
-{
- // If we need to, output a prefix.
- if (carriageReturned)
- {
- if (!ignoreInput) // But only if we are allowed to.
- destination << prefix;
-
- carriageReturned = false; // Denote that the prefix has been displayed.
- }
-}
-
-}; // namespace util
-}; // namespace mlpack
-
-#endif // MLPACK_CLI_PREFIXEDOUTSTREAM_IMPL_H
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/prefixedoutstream_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/prefixedoutstream_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,167 @@
+/**
+ * @file prefixedoutstream.hpp
+ * @author Ryan Curtin
+ * @author Matthew Amidon
+ *
+ * Implementation of templated PrefixedOutStream member functions.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_IO_PREFIXEDOUTSTREAM_IMPL_HPP
+#define __MLPACK_CORE_IO_PREFIXEDOUTSTREAM_IMPL_HPP
+
+// Just in case it hasn't been included.
+#include "prefixedoutstream.hpp"
+#include <iostream>
+
+namespace mlpack {
+namespace util {
+
+template<typename T>
+PrefixedOutStream& PrefixedOutStream::operator<<(const T& s)
+{
+ CallBaseLogic<T>(s);
+ return *this;
+}
+
+//! This handles forwarding all primitive types transparently
+template<typename T>
+void PrefixedOutStream::CallBaseLogic(const T& s,
+ typename boost::disable_if<
+ boost::is_class<T>
+ >::type*)
+{
+ BaseLogic<T>(s);
+}
+
+// Forward all objects that do not implement a ToString() method transparently
+template<typename T>
+void PrefixedOutStream::CallBaseLogic(const T& s,
+ typename boost::enable_if<
+ boost::is_class<T>
+ >::type*,
+ typename boost::disable_if<
+ HasToString<T, std::string(T::*)() const>
+ >::type*)
+{
+ BaseLogic<T>(s);
+}
+
+// Call ToString() on all objects that implement ToString() before forwarding
+template<typename T>
+void PrefixedOutStream::CallBaseLogic(const T& s,
+ typename boost::enable_if<
+ boost::is_class<T>
+ >::type*,
+ typename boost::enable_if<
+ HasToString<T, std::string(T::*)() const>
+ >::type*)
+{
+ std::string result = s.ToString();
+ BaseLogic<std::string>(result);
+}
+
+template<typename T>
+void PrefixedOutStream::BaseLogic(const T& val)
+{
+ // We will use this to track whether or not we need to terminate at the end of
+ // this call (only for streams which terminate after a newline).
+ bool newlined = false;
+ std::string line;
+
+ // If we need to, output the prefix.
+ PrefixIfNeeded();
+
+ std::ostringstream convert;
+ convert << val;
+
+ if(convert.fail())
+ {
+ PrefixIfNeeded();
+ if (!ignoreInput)
+ {
+ destination << "Failed lexical_cast<std::string>(T) for output; output"
+ " not shown." << std::endl;
+ newlined = true;
+ }
+ }
+ else
+ {
+ line = convert.str();
+
+ // If the length of the casted thing was 0, it may have been a stream
+ // manipulator, so send it directly to the stream and don't ask questions.
+ if (line.length() == 0)
+ {
+ // The prefix cannot be necessary at this point.
+ if (!ignoreInput) // Only if the user wants it.
+ destination << val;
+
+ return;
+ }
+
+ // Now, we need to check for newlines in this line. If we find one, output
+ // up until the newline, then output the newline and the prefix and continue
+ // looking.
+ size_t nl;
+ size_t pos = 0;
+ while ((nl = line.find('\n', pos)) != std::string::npos)
+ {
+ PrefixIfNeeded();
+
+ // Only output if the user wants it.
+ if (!ignoreInput)
+ {
+ destination << line.substr(pos, nl - pos);
+ destination << std::endl;
+ newlined = true;
+ }
+
+ carriageReturned = true; // Regardless of whether or not we display it.
+
+ pos = nl + 1;
+ }
+
+ if (pos != line.length()) // We need to display the rest.
+ {
+ PrefixIfNeeded();
+ if (!ignoreInput)
+ destination << line.substr(pos);
+ }
+ }
+
+ // If we displayed a newline and we need to terminate afterwards, do that.
+ if (fatal && newlined)
+ exit(1);
+}
+
+// This is an inline function (that is why it is here and not in .cc).
+void PrefixedOutStream::PrefixIfNeeded()
+{
+ // If we need to, output a prefix.
+ if (carriageReturned)
+ {
+ if (!ignoreInput) // But only if we are allowed to.
+ destination << prefix;
+
+ carriageReturned = false; // Denote that the prefix has been displayed.
+ }
+}
+
+}; // namespace util
+}; // namespace mlpack
+
+#endif // MLPACK_CLI_PREFIXEDOUTSTREAM_IMPL_H
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/save_restore_utility.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,221 +0,0 @@
-/**
- * @file save_restore_utility.cpp
- * @author Neil Slagle
- *
- * The SaveRestoreUtility provides helper functions in saving and
- * restoring models. The current output file type is XML.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "save_restore_utility.hpp"
-
-using namespace mlpack;
-using namespace mlpack::util;
-
-bool SaveRestoreUtility::ReadFile(const std::string& filename)
-{
- xmlDocPtr xmlDocTree = NULL;
- if (NULL == (xmlDocTree = xmlReadFile(filename.c_str(), NULL, 0)))
- {
- Log::Fatal << "Could not load XML file '" << filename << "'!" << std::endl;
- }
-
- xmlNodePtr root = xmlDocGetRootElement(xmlDocTree);
- parameters.clear();
-
- RecurseOnNodes(root->children);
- xmlFreeDoc(xmlDocTree);
- return true;
-}
-
-void SaveRestoreUtility::RecurseOnNodes(xmlNode* n)
-{
- xmlNodePtr current = NULL;
- for (current = n; current; current = current->next)
- {
- if (current->type == XML_ELEMENT_NODE)
- {
- xmlChar* content = xmlNodeGetContent(current);
- parameters[(const char*) current->name] = (const char*) content;
- xmlFree(content);
- }
- RecurseOnNodes(current->children);
- }
-}
-
-bool SaveRestoreUtility::WriteFile(const std::string& filename)
-{
- bool success = false;
- xmlDocPtr xmlDocTree = xmlNewDoc(BAD_CAST "1.0");
- xmlNodePtr root = xmlNewNode(NULL, BAD_CAST "root");
-
- xmlDocSetRootElement(xmlDocTree, root);
-
- for (std::map<std::string, std::string>::iterator it = parameters.begin();
- it != parameters.end();
- ++it)
- {
- xmlNewChild(root, NULL, BAD_CAST(*it).first.c_str(),
- BAD_CAST(*it).second.c_str());
- /* TODO: perhaps we'll add more later?
- * xmlNewProp(child, BAD_CAST "attr", BAD_CAST "add more addibutes?"); */
- }
-
- // Actually save the file.
- success =
- (xmlSaveFormatFileEnc(filename.c_str(), xmlDocTree, "UTF-8", 1) != -1);
- xmlFreeDoc(xmlDocTree);
- return success;
-}
-
-arma::mat& SaveRestoreUtility::LoadParameter(arma::mat& matrix,
- const std::string& name)
-{
- std::map<std::string, std::string>::iterator it = parameters.find(name);
- if (it != parameters.end())
- {
- std::string value = (*it).second;
- boost::char_separator<char> sep ("\n");
- boost::tokenizer<boost::char_separator<char> > tok (value, sep);
- std::list<std::list<double> > rows;
- for (boost::tokenizer<boost::char_separator<char> >::iterator
- tokIt = tok.begin();
- tokIt != tok.end();
- ++tokIt)
- {
- std::string row = *tokIt;
- boost::char_separator<char> sepComma (",");
- boost::tokenizer<boost::char_separator<char> >
- tokInner (row, sepComma);
- std::list<double> rowList;
- for (boost::tokenizer<boost::char_separator<char> >::iterator
- tokInnerIt = tokInner.begin();
- tokInnerIt != tokInner.end();
- ++tokInnerIt)
- {
- double element;
- std::istringstream iss (*tokInnerIt);
- iss >> element;
- rowList.push_back(element);
- }
- rows.push_back(rowList);
- }
- matrix.zeros(rows.size(), (*(rows.begin())).size());
- size_t rowCounter = 0;
- size_t columnCounter = 0;
- for (std::list<std::list<double> >::iterator rowIt = rows.begin();
- rowIt != rows.end();
- ++rowIt)
- {
- std::list<double> row = *rowIt;
- columnCounter = 0;
- for (std::list<double>::iterator elementIt = row.begin();
- elementIt != row.end();
- ++elementIt)
- {
- matrix(rowCounter, columnCounter) = *elementIt;
- columnCounter++;
- }
- rowCounter++;
- }
- return matrix;
- }
- else
- {
- Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
- }
- return matrix;
-}
-
-std::string SaveRestoreUtility::LoadParameter(std::string& str,
- const std::string& name)
-{
- std::map<std::string, std::string>::iterator it = parameters.find(name);
- if (it != parameters.end())
- {
- return str = (*it).second;
- }
- else
- {
- Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
- }
- return "";
-}
-
-char SaveRestoreUtility::LoadParameter(char c, const std::string& name)
-{
- std::map<std::string, std::string>::iterator it = parameters.find(name);
- if (it != parameters.end())
- {
- int temp;
- std::string value = (*it).second;
- std::istringstream input (value);
- input >> temp;
- return c = (char) temp;
- }
- else
- {
- Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
- }
- return 0;
-}
-
-void SaveRestoreUtility::SaveParameter(const char c, const std::string& name)
-{
- int temp = (int) c;
- std::ostringstream output;
- output << temp;
- parameters[name] = output.str();
-}
-
-void SaveRestoreUtility::SaveParameter(const arma::mat& mat,
- const std::string& name)
-{
- std::ostringstream output;
- size_t columns = mat.n_cols;
- size_t rows = mat.n_rows;
- for (size_t r = 0; r < rows; ++r)
- {
- for (size_t c = 0; c < columns - 1; ++c)
- {
- output << mat(r,c) << ",";
- }
- output << mat(r,columns - 1) << std::endl;
- }
- parameters[name] = output.str();
-}
-
-// Special template specializations for vectors.
-
-namespace mlpack {
-namespace util {
-
-template<>
-arma::vec& SaveRestoreUtility::LoadParameter(arma::vec& t,
- const std::string& name)
-{
- return (arma::vec&) LoadParameter((arma::mat&) t, name);
-}
-
-template<>
-void SaveRestoreUtility::SaveParameter(const arma::vec& t,
- const std::string& name)
-{
- SaveParameter((const arma::mat&) t, name);
-}
-
-}; // namespace util
-}; // namespace mlpack
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/save_restore_utility.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,221 @@
+/**
+ * @file save_restore_utility.cpp
+ * @author Neil Slagle
+ *
+ * The SaveRestoreUtility provides helper functions in saving and
+ * restoring models. The current output file type is XML.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "save_restore_utility.hpp"
+
+using namespace mlpack;
+using namespace mlpack::util;
+
+bool SaveRestoreUtility::ReadFile(const std::string& filename)
+{
+ xmlDocPtr xmlDocTree = NULL;
+ if (NULL == (xmlDocTree = xmlReadFile(filename.c_str(), NULL, 0)))
+ {
+ Log::Fatal << "Could not load XML file '" << filename << "'!" << std::endl;
+ }
+
+ xmlNodePtr root = xmlDocGetRootElement(xmlDocTree);
+ parameters.clear();
+
+ RecurseOnNodes(root->children);
+ xmlFreeDoc(xmlDocTree);
+ return true;
+}
+
+void SaveRestoreUtility::RecurseOnNodes(xmlNode* n)
+{
+ xmlNodePtr current = NULL;
+ for (current = n; current; current = current->next)
+ {
+ if (current->type == XML_ELEMENT_NODE)
+ {
+ xmlChar* content = xmlNodeGetContent(current);
+ parameters[(const char*) current->name] = (const char*) content;
+ xmlFree(content);
+ }
+ RecurseOnNodes(current->children);
+ }
+}
+
+bool SaveRestoreUtility::WriteFile(const std::string& filename)
+{
+ bool success = false;
+ xmlDocPtr xmlDocTree = xmlNewDoc(BAD_CAST "1.0");
+ xmlNodePtr root = xmlNewNode(NULL, BAD_CAST "root");
+
+ xmlDocSetRootElement(xmlDocTree, root);
+
+ for (std::map<std::string, std::string>::iterator it = parameters.begin();
+ it != parameters.end();
+ ++it)
+ {
+ xmlNewChild(root, NULL, BAD_CAST(*it).first.c_str(),
+ BAD_CAST(*it).second.c_str());
+ /* TODO: perhaps we'll add more later?
+ * xmlNewProp(child, BAD_CAST "attr", BAD_CAST "add more addibutes?"); */
+ }
+
+ // Actually save the file.
+ success =
+ (xmlSaveFormatFileEnc(filename.c_str(), xmlDocTree, "UTF-8", 1) != -1);
+ xmlFreeDoc(xmlDocTree);
+ return success;
+}
+
+arma::mat& SaveRestoreUtility::LoadParameter(arma::mat& matrix,
+ const std::string& name)
+{
+ std::map<std::string, std::string>::iterator it = parameters.find(name);
+ if (it != parameters.end())
+ {
+ std::string value = (*it).second;
+ boost::char_separator<char> sep ("\n");
+ boost::tokenizer<boost::char_separator<char> > tok (value, sep);
+ std::list<std::list<double> > rows;
+ for (boost::tokenizer<boost::char_separator<char> >::iterator
+ tokIt = tok.begin();
+ tokIt != tok.end();
+ ++tokIt)
+ {
+ std::string row = *tokIt;
+ boost::char_separator<char> sepComma (",");
+ boost::tokenizer<boost::char_separator<char> >
+ tokInner (row, sepComma);
+ std::list<double> rowList;
+ for (boost::tokenizer<boost::char_separator<char> >::iterator
+ tokInnerIt = tokInner.begin();
+ tokInnerIt != tokInner.end();
+ ++tokInnerIt)
+ {
+ double element;
+ std::istringstream iss (*tokInnerIt);
+ iss >> element;
+ rowList.push_back(element);
+ }
+ rows.push_back(rowList);
+ }
+ matrix.zeros(rows.size(), (*(rows.begin())).size());
+ size_t rowCounter = 0;
+ size_t columnCounter = 0;
+ for (std::list<std::list<double> >::iterator rowIt = rows.begin();
+ rowIt != rows.end();
+ ++rowIt)
+ {
+ std::list<double> row = *rowIt;
+ columnCounter = 0;
+ for (std::list<double>::iterator elementIt = row.begin();
+ elementIt != row.end();
+ ++elementIt)
+ {
+ matrix(rowCounter, columnCounter) = *elementIt;
+ columnCounter++;
+ }
+ rowCounter++;
+ }
+ return matrix;
+ }
+ else
+ {
+ Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
+ }
+ return matrix;
+}
+
+std::string SaveRestoreUtility::LoadParameter(std::string& str,
+ const std::string& name)
+{
+ std::map<std::string, std::string>::iterator it = parameters.find(name);
+ if (it != parameters.end())
+ {
+ return str = (*it).second;
+ }
+ else
+ {
+ Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
+ }
+ return "";
+}
+
+char SaveRestoreUtility::LoadParameter(char c, const std::string& name)
+{
+ std::map<std::string, std::string>::iterator it = parameters.find(name);
+ if (it != parameters.end())
+ {
+ int temp;
+ std::string value = (*it).second;
+ std::istringstream input (value);
+ input >> temp;
+ return c = (char) temp;
+ }
+ else
+ {
+ Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
+ }
+ return 0;
+}
+
+void SaveRestoreUtility::SaveParameter(const char c, const std::string& name)
+{
+ int temp = (int) c;
+ std::ostringstream output;
+ output << temp;
+ parameters[name] = output.str();
+}
+
+void SaveRestoreUtility::SaveParameter(const arma::mat& mat,
+ const std::string& name)
+{
+ std::ostringstream output;
+ size_t columns = mat.n_cols;
+ size_t rows = mat.n_rows;
+ for (size_t r = 0; r < rows; ++r)
+ {
+ for (size_t c = 0; c < columns - 1; ++c)
+ {
+ output << mat(r,c) << ",";
+ }
+ output << mat(r,columns - 1) << std::endl;
+ }
+ parameters[name] = output.str();
+}
+
+// Special template specializations for vectors.
+
+namespace mlpack {
+namespace util {
+
+template<>
+arma::vec& SaveRestoreUtility::LoadParameter(arma::vec& t,
+ const std::string& name)
+{
+ return (arma::vec&) LoadParameter((arma::mat&) t, name);
+}
+
+template<>
+void SaveRestoreUtility::SaveParameter(const arma::vec& t,
+ const std::string& name)
+{
+ SaveParameter((const arma::mat&) t, name);
+}
+
+}; // namespace util
+}; // namespace mlpack
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/save_restore_utility.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,137 +0,0 @@
-/**
- * @file save_restore_utility.hpp
- * @author Neil Slagle
- *
- * The SaveRestoreUtility provides helper functions in saving and
- * restoring models. The current output file type is XML.
- *
- * @experimental
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_UTIL_SAVE_RESTORE_MODEL_HPP
-#define __MLPACK_CORE_UTIL_SAVE_RESTORE_MODEL_HPP
-
-#include <list>
-#include <map>
-#include <sstream>
-#include <string>
-
-#include <libxml/parser.h>
-#include <libxml/tree.h>
-
-#include <boost/tokenizer.hpp>
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace util {
-
-class SaveRestoreUtility
-{
- private:
- /**
- * parameters contains a list of names and parameters in string form.
- */
- std::map<std::string, std::string> parameters;
-
- /**
- * RecurseOnNodes performs a depth first search of the XML tree.
- */
- void RecurseOnNodes(xmlNode* n);
-
- public:
- SaveRestoreUtility() {}
- ~SaveRestoreUtility() { parameters.clear(); }
-
- /**
- * ReadFile reads an XML tree from a file.
- */
- bool ReadFile(const std::string& filename);
-
- /**
- * WriteFile writes the XML tree to a file.
- */
- bool WriteFile(const std::string& filename);
-
- /**
- * LoadParameter loads a parameter from the parameters map.
- */
- template<typename T>
- T& LoadParameter(T& t, const std::string& name);
-
- /**
- * LoadParameter loads a parameter from the parameters map.
- */
- template<typename T>
- std::vector<T>& LoadParameter(std::vector<T>& v, const std::string& name);
-
- /**
- * LoadParameter loads a character from the parameters map.
- */
- char LoadParameter(char c, const std::string& name);
-
- /**
- * LoadParameter loads a string from the parameters map.
- */
- std::string LoadParameter(std::string& str, const std::string& name);
-
- /**
- * LoadParameter loads an arma::mat from the parameters map.
- */
- arma::mat& LoadParameter(arma::mat& matrix, const std::string& name);
-
- /**
- * SaveParameter saves a parameter to the parameters map.
- */
- template<typename T>
- void SaveParameter(const T& t, const std::string& name);
-
-
-
- /**
- * SaveParameter saves a parameter to the parameters map.
- */
- template<typename T>
- void SaveParameter(const std::vector<T>& v, const std::string& name);
-
- /**
- * SaveParameter saves a character to the parameters map.
- */
- void SaveParameter(const char c, const std::string& name);
-
- /**
- * SaveParameter saves an arma::mat to the parameters map.
- */
- void SaveParameter(const arma::mat& mat, const std::string& name);
-};
-
-//! Specialization for arma::vec.
-template<>
-arma::vec& SaveRestoreUtility::LoadParameter(arma::vec& t,
- const std::string& name);
-
-//! Specialization for arma::vec.
-template<>
-void SaveRestoreUtility::SaveParameter(const arma::vec& t,
- const std::string& name);
-
-}; /* namespace util */
-}; /* namespace mlpack */
-
-// Include implementation.
-#include "save_restore_utility_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/save_restore_utility.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,137 @@
+/**
+ * @file save_restore_utility.hpp
+ * @author Neil Slagle
+ *
+ * The SaveRestoreUtility provides helper functions in saving and
+ * restoring models. The current output file type is XML.
+ *
+ * @experimental
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_UTIL_SAVE_RESTORE_MODEL_HPP
+#define __MLPACK_CORE_UTIL_SAVE_RESTORE_MODEL_HPP
+
+#include <list>
+#include <map>
+#include <sstream>
+#include <string>
+
+#include <libxml/parser.h>
+#include <libxml/tree.h>
+
+#include <boost/tokenizer.hpp>
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace util {
+
+class SaveRestoreUtility
+{
+ private:
+ /**
+ * parameters contains a list of names and parameters in string form.
+ */
+ std::map<std::string, std::string> parameters;
+
+ /**
+ * RecurseOnNodes performs a depth first search of the XML tree.
+ */
+ void RecurseOnNodes(xmlNode* n);
+
+ public:
+ SaveRestoreUtility() {}
+ ~SaveRestoreUtility() { parameters.clear(); }
+
+ /**
+ * ReadFile reads an XML tree from a file.
+ */
+ bool ReadFile(const std::string& filename);
+
+ /**
+ * WriteFile writes the XML tree to a file.
+ */
+ bool WriteFile(const std::string& filename);
+
+ /**
+ * LoadParameter loads a parameter from the parameters map.
+ */
+ template<typename T>
+ T& LoadParameter(T& t, const std::string& name);
+
+ /**
+ * LoadParameter loads a parameter from the parameters map.
+ */
+ template<typename T>
+ std::vector<T>& LoadParameter(std::vector<T>& v, const std::string& name);
+
+ /**
+ * LoadParameter loads a character from the parameters map.
+ */
+ char LoadParameter(char c, const std::string& name);
+
+ /**
+ * LoadParameter loads a string from the parameters map.
+ */
+ std::string LoadParameter(std::string& str, const std::string& name);
+
+ /**
+ * LoadParameter loads an arma::mat from the parameters map.
+ */
+ arma::mat& LoadParameter(arma::mat& matrix, const std::string& name);
+
+ /**
+ * SaveParameter saves a parameter to the parameters map.
+ */
+ template<typename T>
+ void SaveParameter(const T& t, const std::string& name);
+
+
+
+ /**
+ * SaveParameter saves a parameter to the parameters map.
+ */
+ template<typename T>
+ void SaveParameter(const std::vector<T>& v, const std::string& name);
+
+ /**
+ * SaveParameter saves a character to the parameters map.
+ */
+ void SaveParameter(const char c, const std::string& name);
+
+ /**
+ * SaveParameter saves an arma::mat to the parameters map.
+ */
+ void SaveParameter(const arma::mat& mat, const std::string& name);
+};
+
+//! Specialization for arma::vec.
+template<>
+arma::vec& SaveRestoreUtility::LoadParameter(arma::vec& t,
+ const std::string& name);
+
+//! Specialization for arma::vec.
+template<>
+void SaveRestoreUtility::SaveParameter(const arma::vec& t,
+ const std::string& name);
+
+}; /* namespace util */
+}; /* namespace mlpack */
+
+// Include implementation.
+#include "save_restore_utility_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/save_restore_utility_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,105 +0,0 @@
-/**
- * @file save_restore_utility_impl.hpp
- * @author Neil Slagle
- *
- * The SaveRestoreUtility provides helper functions in saving and
- * restoring models. The current output file type is XML.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_UTIL_SAVE_RESTORE_UTILITY_IMPL_HPP
-#define __MLPACK_CORE_UTIL_SAVE_RESTORE_UTILITY_IMPL_HPP
-
-// In case it hasn't been included already.
-#include "save_restore_utility.hpp"
-
-namespace mlpack {
-namespace util {
-
-template<typename T>
-T& SaveRestoreUtility::LoadParameter(T& t, const std::string& name)
-{
- std::map<std::string, std::string>::iterator it = parameters.find(name);
- if (it != parameters.end())
- {
- std::string value = (*it).second;
- std::istringstream input (value);
- input >> t;
- return t;
- }
- else
- {
- Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
- }
- return t;
-}
-
-template<typename T>
-std::vector<T>& SaveRestoreUtility::LoadParameter(std::vector<T>& v,
- const std::string& name)
-{
- std::map<std::string, std::string>::iterator it = parameters.find(name);
- if (it != parameters.end())
- {
- v.clear();
- std::string value = (*it).second;
- boost::char_separator<char> sep (",");
- boost::tokenizer<boost::char_separator<char> > tok (value, sep);
- std::list<std::list<double> > rows;
- for (boost::tokenizer<boost::char_separator<char> >::iterator
- tokIt = tok.begin();
- tokIt != tok.end();
- ++tokIt)
- {
- T t;
- std::istringstream iss (*tokIt);
- iss >> t;
- v.push_back(t);
- }
- }
- else
- {
- Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
- }
- return v;
-}
-
-template<typename T>
-void SaveRestoreUtility::SaveParameter(const T& t, const std::string& name)
-{
- std::ostringstream output;
- output << t;
- parameters[name] = output.str();
-}
-
-template<typename T>
-void SaveRestoreUtility::SaveParameter(const std::vector<T>& t,
- const std::string& name)
-{
- std::ostringstream output;
- for (size_t index = 0; index < t.size(); ++index)
- {
- output << t[index] << ",";
- }
- std::string vectorAsStr = output.str();
- vectorAsStr.erase(vectorAsStr.length() - 1);
- parameters[name] = vectorAsStr;
-}
-
-}; // namespace util
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/save_restore_utility_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/save_restore_utility_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,105 @@
+/**
+ * @file save_restore_utility_impl.hpp
+ * @author Neil Slagle
+ *
+ * The SaveRestoreUtility provides helper functions in saving and
+ * restoring models. The current output file type is XML.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_UTIL_SAVE_RESTORE_UTILITY_IMPL_HPP
+#define __MLPACK_CORE_UTIL_SAVE_RESTORE_UTILITY_IMPL_HPP
+
+// In case it hasn't been included already.
+#include "save_restore_utility.hpp"
+
+namespace mlpack {
+namespace util {
+
+template<typename T>
+T& SaveRestoreUtility::LoadParameter(T& t, const std::string& name)
+{
+ std::map<std::string, std::string>::iterator it = parameters.find(name);
+ if (it != parameters.end())
+ {
+ std::string value = (*it).second;
+ std::istringstream input (value);
+ input >> t;
+ return t;
+ }
+ else
+ {
+ Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
+ }
+ return t;
+}
+
+template<typename T>
+std::vector<T>& SaveRestoreUtility::LoadParameter(std::vector<T>& v,
+ const std::string& name)
+{
+ std::map<std::string, std::string>::iterator it = parameters.find(name);
+ if (it != parameters.end())
+ {
+ v.clear();
+ std::string value = (*it).second;
+ boost::char_separator<char> sep (",");
+ boost::tokenizer<boost::char_separator<char> > tok (value, sep);
+ std::list<std::list<double> > rows;
+ for (boost::tokenizer<boost::char_separator<char> >::iterator
+ tokIt = tok.begin();
+ tokIt != tok.end();
+ ++tokIt)
+ {
+ T t;
+ std::istringstream iss (*tokIt);
+ iss >> t;
+ v.push_back(t);
+ }
+ }
+ else
+ {
+ Log::Fatal << "LoadParameter(): node '" << name << "' not found.\n";
+ }
+ return v;
+}
+
+template<typename T>
+void SaveRestoreUtility::SaveParameter(const T& t, const std::string& name)
+{
+ std::ostringstream output;
+ output << t;
+ parameters[name] = output.str();
+}
+
+template<typename T>
+void SaveRestoreUtility::SaveParameter(const std::vector<T>& t,
+ const std::string& name)
+{
+ std::ostringstream output;
+ for (size_t index = 0; index < t.size(); ++index)
+ {
+ output << t[index] << ",";
+ }
+ std::string vectorAsStr = output.str();
+ vectorAsStr.erase(vectorAsStr.length() - 1);
+ parameters[name] = vectorAsStr;
+}
+
+}; // namespace util
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/sfinae_utility.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/sfinae_utility.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/sfinae_utility.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,61 +0,0 @@
-/**
- * @file sfinae_utility.hpp
- * @author Trironk Kiatkungwanglai
- *
- * This file contains macro utilities for the SFINAE Paradigm. These utilities
- * determine if classes passed in as template parameters contain members at
- * compile time, which is useful for changing functionality depending on what
- * operations an object is capable of performing.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_SFINAE_UTILITY
-#define __MLPACK_CORE_SFINAE_UTILITY
-
-#include <boost/utility/enable_if.hpp>
-#include <boost/type_traits.hpp>
-
-/*
- * Constructs a template supporting the SFINAE pattern.
- *
- * This macro generates a template struct that is useful for enabling/disabling
- * a method if the template class passed in contains a member function matching
- * a given signature with a specified name.
- *
- * The generated struct should be used in conjunction with boost::disable_if and
- * boost::enable_if. Here is an example usage:
- *
- * For general references, see:
- * http://stackoverflow.com/a/264088/391618
- *
- * For an MLPACK specific use case, see /mlpack/core/util/prefixedoutstream.hpp
- * and /mlpack/core/util/prefixedoutstream_impl.hpp
- *
- * @param NAME the name of the struct to construct. For example: HasToString
- * @param FUNC the name of the function to check for. For example: ToString
- */
-#define HAS_MEM_FUNC(FUNC, NAME) \
-template<typename T, typename sig> \
-struct NAME { \
- typedef char yes[1]; \
- typedef char no [2]; \
- template<typename U, U> struct type_check; \
- template<typename _1> static yes &chk(type_check<sig, &_1::FUNC> *); \
- template<typename > static no &chk(...); \
- static bool const value = sizeof(chk<T>(0)) == sizeof(yes); \
-};
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/sfinae_utility.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/sfinae_utility.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/sfinae_utility.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/sfinae_utility.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,61 @@
+/**
+ * @file sfinae_utility.hpp
+ * @author Trironk Kiatkungwanglai
+ *
+ * This file contains macro utilities for the SFINAE Paradigm. These utilities
+ * determine if classes passed in as template parameters contain members at
+ * compile time, which is useful for changing functionality depending on what
+ * operations an object is capable of performing.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_SFINAE_UTILITY
+#define __MLPACK_CORE_SFINAE_UTILITY
+
+#include <boost/utility/enable_if.hpp>
+#include <boost/type_traits.hpp>
+
+/*
+ * Constructs a template supporting the SFINAE pattern.
+ *
+ * This macro generates a template struct that is useful for enabling/disabling
+ * a method if the template class passed in contains a member function matching
+ * a given signature with a specified name.
+ *
+ * The generated struct should be used in conjunction with boost::disable_if and
+ * boost::enable_if. Here is an example usage:
+ *
+ * For general references, see:
+ * http://stackoverflow.com/a/264088/391618
+ *
+ * For an MLPACK specific use case, see /mlpack/core/util/prefixedoutstream.hpp
+ * and /mlpack/core/util/prefixedoutstream_impl.hpp
+ *
+ * @param NAME the name of the struct to construct. For example: HasToString
+ * @param FUNC the name of the function to check for. For example: ToString
+ */
+#define HAS_MEM_FUNC(FUNC, NAME) \
+template<typename T, typename sig> \
+struct NAME { \
+ typedef char yes[1]; \
+ typedef char no [2]; \
+ template<typename U, U> struct type_check; \
+ template<typename _1> static yes &chk(type_check<sig, &_1::FUNC> *); \
+ template<typename > static no &chk(...); \
+ static bool const value = sizeof(chk<T>(0)) == sizeof(yes); \
+};
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/string_util.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,45 +0,0 @@
-/**
- * @file string_util.cpp
- *
- * Defines methods useful for formatting output.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "string_util.hpp"
-
-using namespace mlpack;
-using namespace mlpack::util;
-
-//! A utility function that replaces all all newlines with a number of spaces
-//! depending on the indentation level.
-std::string mlpack::util::Indent(std::string input)
-{
- // Tab the first line.
- input.insert(0, 1, ' ');
- input.insert(0, 1, ' ');
-
- // Get the character sequence to replace all newline characters.
- std::string tabbedNewline("\n ");
-
- // Replace all newline characters with the precomputed character sequence.
- size_t start_pos = 0;
- while((start_pos = input.find("\n", start_pos)) != std::string::npos) {
- input.replace(start_pos, 1, tabbedNewline);
- start_pos += tabbedNewline.length();
- }
-
- return input;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/string_util.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,45 @@
+/**
+ * @file string_util.cpp
+ *
+ * Defines methods useful for formatting output.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "string_util.hpp"
+
+using namespace mlpack;
+using namespace mlpack::util;
+
+//! A utility function that replaces all all newlines with a number of spaces
+//! depending on the indentation level.
+std::string mlpack::util::Indent(std::string input)
+{
+ // Tab the first line.
+ input.insert(0, 1, ' ');
+ input.insert(0, 1, ' ');
+
+ // Get the character sequence to replace all newline characters.
+ std::string tabbedNewline("\n ");
+
+ // Replace all newline characters with the precomputed character sequence.
+ size_t start_pos = 0;
+ while((start_pos = input.find("\n", start_pos)) != std::string::npos) {
+ input.replace(start_pos, 1, tabbedNewline);
+ start_pos += tabbedNewline.length();
+ }
+
+ return input;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/string_util.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,36 +0,0 @@
-/**
- * @file string_util.hpp
- *
- * Declares methods that are useful for writing formatting output.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_STRING_UTIL_HPP
-#define __MLPACK_CORE_STRING_UTIL_HPP
-
-#include <string>
-
-namespace mlpack {
-namespace util {
-
-//! A utility function that replaces all all newlines with a number of spaces
-//! depending on the indentation level.
-std::string Indent(std::string input);
-
-}; // namespace util
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/string_util.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/string_util.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,36 @@
+/**
+ * @file string_util.hpp
+ *
+ * Declares methods that are useful for writing formatting output.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_STRING_UTIL_HPP
+#define __MLPACK_CORE_STRING_UTIL_HPP
+
+#include <string>
+
+namespace mlpack {
+namespace util {
+
+//! A utility function that replaces all all newlines with a number of spaces
+//! depending on the indentation level.
+std::string Indent(std::string input);
+
+}; // namespace util
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/timers.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,193 +0,0 @@
-/**
- * @file timers.cpp
- * @author Matthew Amidon
- *
- * Implementation of timers.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "timers.hpp"
-#include "cli.hpp"
-#include "log.hpp"
-
-#include <map>
-#include <string>
-
-using namespace mlpack;
-
-// On Windows machines, we need to define timersub.
-#ifdef _WIN32
-inline void timersub(const timeval* tvp, const timeval* uvp, timeval* vvp)
-{
- vvp->tv_sec = tvp->tv_sec - uvp->tv_sec;
- vvp->tv_usec = tvp->tv_usec - uvp->tv_usec;
- if (vvp->tv_usec < 0)
- {
- --vvp->tv_sec;
- vvp->tv_usec += 1000000;
- }
-}
-#endif
-
-/**
- * Start the given timer.
- */
-void Timer::Start(const std::string& name)
-{
- CLI::GetSingleton().timer.StartTimer(name);
-}
-
-/**
- * Stop the given timer.
- */
-void Timer::Stop(const std::string& name)
-{
- CLI::GetSingleton().timer.StopTimer(name);
-}
-
-/**
- * Get the given timer.
- */
-timeval Timer::Get(const std::string& name)
-{
- return CLI::GetSingleton().timer.GetTimer(name);
-}
-
-std::map<std::string, timeval>& Timers::GetAllTimers()
-{
- return timers;
-}
-
-timeval Timers::GetTimer(const std::string& timerName)
-{
- std::string name(timerName);
- return timers[name];
-}
-
-void Timers::PrintTimer(const std::string& timerName)
-{
- timeval& t = timers[timerName];
- Log::Info << t.tv_sec << "." << std::setw(6) << std::setfill('0')
- << t.tv_usec << "s";
-
- // Also output convenient day/hr/min/sec.
- int days = t.tv_sec / 86400; // Integer division rounds down.
- int hours = (t.tv_sec % 86400) / 3600;
- int minutes = (t.tv_sec % 3600) / 60;
- int seconds = (t.tv_sec % 60);
- // No output if it didn't even take a minute.
- if (!(days == 0 && hours == 0 && minutes == 0))
- {
- bool output = false; // Denotes if we have output anything yet.
- Log::Info << " (";
-
- // Only output units if they have nonzero values (yes, a bit tedious).
- if (days > 0)
- {
- Log::Info << days << " days";
- output = true;
- }
-
- if (hours > 0)
- {
- if (output)
- Log::Info << ", ";
- Log::Info << hours << " hrs";
- output = true;
- }
-
- if (minutes > 0)
- {
- if (output)
- Log::Info << ", ";
- Log::Info << minutes << " mins";
- output = true;
- }
-
- if (seconds > 0)
- {
- if (output)
- Log::Info << ", ";
- Log::Info << seconds << "." << std::setw(1) << (t.tv_usec / 100000) <<
- "secs";
- output = true;
- }
-
- Log::Info << ")";
- }
-
- Log::Info << std::endl;
-}
-
-void Timers::StartTimer(const std::string& timerName)
-{
- timeval tmp;
-
- tmp.tv_sec = 0;
- tmp.tv_usec = 0;
-
-#ifndef _WIN32
- gettimeofday(&tmp, NULL);
-#else
- FileTimeToTimeVal(&tmp);
-#endif
-
- // Check to see if the timer already exists. If it does, we'll subtract the
- // old value.
- if (timers.count(timerName) == 1)
- {
- timeval tmpDelta;
-
- timersub(&tmp, &timers[timerName], &tmpDelta);
-
- tmp = tmpDelta;
- }
-
- timers[timerName] = tmp;
-}
-
-#ifdef _WIN32
-void Timers::FileTimeToTimeVal(timeval* tv)
-{
- FILETIME ftime;
- uint64_t ptime = 0;
- // Acquire the file time.
- GetSystemTimeAsFileTime(&ftime);
- // Now convert FILETIME to timeval.
- ptime |= ftime.dwHighDateTime;
- ptime = ptime << 32;
- ptime |= ftime.dwLowDateTime;
- ptime /= 10;
- ptime -= DELTA_EPOCH_IN_MICROSECS;
-
- tv->tv_sec = (long) (ptime / 1000000UL);
- tv->tv_usec = (long) (ptime % 1000000UL);
-}
-#endif // _WIN32
-
-void Timers::StopTimer(const std::string& timerName)
-{
- timeval delta, b, a = timers[timerName];
-
-#ifndef _WIN32
- gettimeofday(&b, NULL);
-#else
- FileTimeToTimeVal(&b);
-#endif
- // Calculate the delta time.
- timersub(&b, &a, &delta);
- timers[timerName] = delta;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/timers.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,193 @@
+/**
+ * @file timers.cpp
+ * @author Matthew Amidon
+ *
+ * Implementation of timers.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "timers.hpp"
+#include "cli.hpp"
+#include "log.hpp"
+
+#include <map>
+#include <string>
+
+using namespace mlpack;
+
+// On Windows machines, we need to define timersub.
+#ifdef _WIN32
+inline void timersub(const timeval* tvp, const timeval* uvp, timeval* vvp)
+{
+ vvp->tv_sec = tvp->tv_sec - uvp->tv_sec;
+ vvp->tv_usec = tvp->tv_usec - uvp->tv_usec;
+ if (vvp->tv_usec < 0)
+ {
+ --vvp->tv_sec;
+ vvp->tv_usec += 1000000;
+ }
+}
+#endif
+
+/**
+ * Start the given timer.
+ */
+void Timer::Start(const std::string& name)
+{
+ CLI::GetSingleton().timer.StartTimer(name);
+}
+
+/**
+ * Stop the given timer.
+ */
+void Timer::Stop(const std::string& name)
+{
+ CLI::GetSingleton().timer.StopTimer(name);
+}
+
+/**
+ * Get the given timer.
+ */
+timeval Timer::Get(const std::string& name)
+{
+ return CLI::GetSingleton().timer.GetTimer(name);
+}
+
+std::map<std::string, timeval>& Timers::GetAllTimers()
+{
+ return timers;
+}
+
+timeval Timers::GetTimer(const std::string& timerName)
+{
+ std::string name(timerName);
+ return timers[name];
+}
+
+void Timers::PrintTimer(const std::string& timerName)
+{
+ timeval& t = timers[timerName];
+ Log::Info << t.tv_sec << "." << std::setw(6) << std::setfill('0')
+ << t.tv_usec << "s";
+
+ // Also output convenient day/hr/min/sec.
+ int days = t.tv_sec / 86400; // Integer division rounds down.
+ int hours = (t.tv_sec % 86400) / 3600;
+ int minutes = (t.tv_sec % 3600) / 60;
+ int seconds = (t.tv_sec % 60);
+ // No output if it didn't even take a minute.
+ if (!(days == 0 && hours == 0 && minutes == 0))
+ {
+ bool output = false; // Denotes if we have output anything yet.
+ Log::Info << " (";
+
+ // Only output units if they have nonzero values (yes, a bit tedious).
+ if (days > 0)
+ {
+ Log::Info << days << " days";
+ output = true;
+ }
+
+ if (hours > 0)
+ {
+ if (output)
+ Log::Info << ", ";
+ Log::Info << hours << " hrs";
+ output = true;
+ }
+
+ if (minutes > 0)
+ {
+ if (output)
+ Log::Info << ", ";
+ Log::Info << minutes << " mins";
+ output = true;
+ }
+
+ if (seconds > 0)
+ {
+ if (output)
+ Log::Info << ", ";
+ Log::Info << seconds << "." << std::setw(1) << (t.tv_usec / 100000) <<
+ "secs";
+ output = true;
+ }
+
+ Log::Info << ")";
+ }
+
+ Log::Info << std::endl;
+}
+
+void Timers::StartTimer(const std::string& timerName)
+{
+ timeval tmp;
+
+ tmp.tv_sec = 0;
+ tmp.tv_usec = 0;
+
+#ifndef _WIN32
+ gettimeofday(&tmp, NULL);
+#else
+ FileTimeToTimeVal(&tmp);
+#endif
+
+ // Check to see if the timer already exists. If it does, we'll subtract the
+ // old value.
+ if (timers.count(timerName) == 1)
+ {
+ timeval tmpDelta;
+
+ timersub(&tmp, &timers[timerName], &tmpDelta);
+
+ tmp = tmpDelta;
+ }
+
+ timers[timerName] = tmp;
+}
+
+#ifdef _WIN32
+void Timers::FileTimeToTimeVal(timeval* tv)
+{
+ FILETIME ftime;
+ uint64_t ptime = 0;
+ // Acquire the file time.
+ GetSystemTimeAsFileTime(&ftime);
+ // Now convert FILETIME to timeval.
+ ptime |= ftime.dwHighDateTime;
+ ptime = ptime << 32;
+ ptime |= ftime.dwLowDateTime;
+ ptime /= 10;
+ ptime -= DELTA_EPOCH_IN_MICROSECS;
+
+ tv->tv_sec = (long) (ptime / 1000000UL);
+ tv->tv_usec = (long) (ptime % 1000000UL);
+}
+#endif // _WIN32
+
+void Timers::StopTimer(const std::string& timerName)
+{
+ timeval delta, b, a = timers[timerName];
+
+#ifndef _WIN32
+ gettimeofday(&b, NULL);
+#else
+ FileTimeToTimeVal(&b);
+#endif
+ // Calculate the delta time.
+ timersub(&b, &a, &delta);
+ timers[timerName] = delta;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core/util/timers.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,132 +0,0 @@
-/**
- * @file timers.hpp
- * @author Matthew Amidon
- *
- * Timers for MLPACK.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_UTILITIES_TIMERS_HPP
-#define __MLPACK_CORE_UTILITIES_TIMERS_HPP
-
-#include <map>
-#include <string>
-
-#ifndef _WIN32
- #include <sys/time.h> //linux
-#else
- #include <winsock.h> //timeval on windows
- #include <windows.h> //GetSystemTimeAsFileTime on windows
-//gettimeofday has no equivalent will need to write extra code for that.
- #if defined(_MSC_VER) || defined(_MSC_EXTENSIONS)
- #define DELTA_EPOCH_IN_MICROSECS 11644473600000000Ui64
- #else
- #define DELTA_EPOCH_IN_MICROSECS 11644473600000000ULL
- #endif
-#endif //_WIN32
-
-namespace mlpack {
-
-/**
- * The timer class provides a way for MLPACK methods to be timed. The three
- * methods contained in this class allow a named timer to be started and
- * stopped, and its value to be obtained.
- */
-class Timer
-{
- public:
- /**
- * Start the given timer. If a timer is started, then stopped, then
- * re-started, then re-stopped, the final value of the timer is the length of
- * both runs -- that is, MLPACK timers are additive for each time they are
- * run, and do not reset.
- *
- * @note Undefined behavior will occur if a timer is started twice.
- *
- * @param name Name of timer to be started.
- */
- static void Start(const std::string& name);
-
- /**
- * Stop the given timer.
- *
- * @note Undefined behavior will occur if a timer is started twice.
- *
- * @param name Name of timer to be stopped.
- */
- static void Stop(const std::string& name);
-
- /**
- * Get the value of the given timer.
- *
- * @param name Name of timer to return value of.
- */
- static timeval Get(const std::string& name);
-};
-
-class Timers
-{
- public:
- //! Nothing to do for the constructor.
- Timers() { }
-
- /**
- * Returns a copy of all the timers used via this interface.
- */
- std::map<std::string, timeval>& GetAllTimers();
-
- /**
- * Returns a copy of the timer specified.
- *
- * @param timerName The name of the timer in question.
- */
- timeval GetTimer(const std::string& timerName);
-
- /**
- * Prints the specified timer. If it took longer than a minute to complete
- * the timer will be displayed in days, hours, and minutes as well.
- *
- * @param timerName The name of the timer in question.
- */
- void PrintTimer(const std::string& timerName);
-
- /**
- * Initializes a timer, available like a normal value specified on
- * the command line. Timers are of type timeval. If a timer is started, then
- * stopped, then re-started, then stopped, the final timer value will be the
- * length of both runs of the timer.
- *
- * @param timerName The name of the timer in question.
- */
- void StartTimer(const std::string& timerName);
-
- /**
- * Halts the timer, and replaces it's value with
- * the delta time from it's start
- *
- * @param timerName The name of the timer in question.
- */
- void StopTimer(const std::string& timerName);
-
- private:
- std::map<std::string, timeval> timers;
-
- void FileTimeToTimeVal(timeval* tv);
-};
-
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_UTILITIES_TIMERS_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core/util/timers.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core/util/timers.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,132 @@
+/**
+ * @file timers.hpp
+ * @author Matthew Amidon
+ *
+ * Timers for MLPACK.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_UTILITIES_TIMERS_HPP
+#define __MLPACK_CORE_UTILITIES_TIMERS_HPP
+
+#include <map>
+#include <string>
+
+#ifndef _WIN32
+ #include <sys/time.h> //linux
+#else
+ #include <winsock.h> //timeval on windows
+ #include <windows.h> //GetSystemTimeAsFileTime on windows
+//gettimeofday has no equivalent will need to write extra code for that.
+ #if defined(_MSC_VER) || defined(_MSC_EXTENSIONS)
+ #define DELTA_EPOCH_IN_MICROSECS 11644473600000000Ui64
+ #else
+ #define DELTA_EPOCH_IN_MICROSECS 11644473600000000ULL
+ #endif
+#endif //_WIN32
+
+namespace mlpack {
+
+/**
+ * The timer class provides a way for MLPACK methods to be timed. The three
+ * methods contained in this class allow a named timer to be started and
+ * stopped, and its value to be obtained.
+ */
+class Timer
+{
+ public:
+ /**
+ * Start the given timer. If a timer is started, then stopped, then
+ * re-started, then re-stopped, the final value of the timer is the length of
+ * both runs -- that is, MLPACK timers are additive for each time they are
+ * run, and do not reset.
+ *
+ * @note Undefined behavior will occur if a timer is started twice.
+ *
+ * @param name Name of timer to be started.
+ */
+ static void Start(const std::string& name);
+
+ /**
+ * Stop the given timer.
+ *
+ * @note Undefined behavior will occur if a timer is started twice.
+ *
+ * @param name Name of timer to be stopped.
+ */
+ static void Stop(const std::string& name);
+
+ /**
+ * Get the value of the given timer.
+ *
+ * @param name Name of timer to return value of.
+ */
+ static timeval Get(const std::string& name);
+};
+
+class Timers
+{
+ public:
+ //! Nothing to do for the constructor.
+ Timers() { }
+
+ /**
+ * Returns a copy of all the timers used via this interface.
+ */
+ std::map<std::string, timeval>& GetAllTimers();
+
+ /**
+ * Returns a copy of the timer specified.
+ *
+ * @param timerName The name of the timer in question.
+ */
+ timeval GetTimer(const std::string& timerName);
+
+ /**
+ * Prints the specified timer. If it took longer than a minute to complete
+ * the timer will be displayed in days, hours, and minutes as well.
+ *
+ * @param timerName The name of the timer in question.
+ */
+ void PrintTimer(const std::string& timerName);
+
+ /**
+ * Initializes a timer, available like a normal value specified on
+ * the command line. Timers are of type timeval. If a timer is started, then
+ * stopped, then re-started, then stopped, the final timer value will be the
+ * length of both runs of the timer.
+ *
+ * @param timerName The name of the timer in question.
+ */
+ void StartTimer(const std::string& timerName);
+
+ /**
+ * Halts the timer, and replaces it's value with
+ * the delta time from it's start
+ *
+ * @param timerName The name of the timer in question.
+ */
+ void StopTimer(const std::string& timerName);
+
+ private:
+ std::map<std::string, timeval> timers;
+
+ void FileTimeToTimeVal(timeval* tv);
+};
+
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_UTILITIES_TIMERS_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/core.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/core.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,207 +0,0 @@
-/***
- * @file core.hpp
- *
- * Include all of the base components required to write MLPACK methods, and the
- * main MLPACK Doxygen documentation.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_CORE_HPP
-#define __MLPACK_CORE_HPP
-
-/**
- * @mainpage MLPACK Documentation
- *
- * @section intro_sec Introduction
- *
- * MLPACK is an intuitive, fast, scalable C++ machine learning library, meant to
- * be a machine learning analog to LAPACK. It aims to implement a wide array of
- * machine learning methods and function as a "swiss army knife" for machine
- * learning researchers. The MLPACK development website can be found at
- * http://mlpack.org.
- *
- * MLPACK uses the Armadillo C++ matrix library (http://arma.sourceforge.net)
- * for general matrix, vector, and linear algebra support. MLPACK also uses the
- * program_options, math_c99, and unit_test_framework components of the Boost
- * library; in addition, LibXml2 is used.
- *
- * @section howto How To Use This Documentation
- *
- * This documentation is API documentation similar to Javadoc. It isn't
- * necessarily a tutorial, but it does provide detailed documentation on every
- * namespace, method, and class.
- *
- * Each MLPACK namespace generally refers to one machine learning method, so
- * browsing the list of namespaces provides some insight as to the breadth of
- * the methods contained in the library.
- *
- * To generate this documentation in your own local copy of MLPACK, you can
- * simply use Doxygen, from the root directory (@c /mlpack/trunk/ ):
- *
- * @code
- * $ doxygen
- * @endcode
- *
- * @section executables Executables
- *
- * MLPACK provides several executables so that MLPACK methods can be used
- * without any need for knowledge of C++. These executables are all
- * self-documented, and that documentation can be accessed by running the
- * executables with the '-h' or '--help' flag.
- *
- * A full list of executables is given below:
- *
- * allkfn, allknn, emst, gmm, hmm_train, hmm_loglik, hmm_viterbi, hmm_generate,
- * kernel_pca, kmeans, lars, linear_regression, local_coordinate_coding, mvu,
- * nbc, nca, pca, radical, sparse_coding
- *
- * @section tutorial Tutorials
- *
- * A few short tutorials on how to use MLPACK are given below.
- *
- * - @ref build
- * - @ref matrices
- * - @ref iodoc
- * - @ref timer
- * - @ref sample
- *
- * Tutorials on specific methods are also available.
- *
- * - @ref nstutorial
- * - @ref lrtutorial
- * - @ref rstutorial
- * - @ref dettutorial
- * - @ref emst_tutorial
- * - @ref kmtutorial
- * - @ref fmkstutorial
- *
- * @section methods Methods in MLPACK
- *
- * The following methods are included in MLPACK:
- *
- * - Density Estimation Trees - mlpack::det::DTree
- * - Euclidean Minimum Spanning Trees - mlpack::emst::DualTreeBoruvka
- * - Gaussian Mixture Models (GMMs) - mlpack::gmm::GMM
- * - Hidden Markov Models (HMMs) - mlpack::hmm::HMM
- * - Kernel PCA - mlpack::kpca::KernelPCA
- * - K-Means Clustering - mlpack::kmeans::KMeans
- * - Least-Angle Regression (LARS/LASSO) - mlpack::regression::LARS
- * - Local Coordinate Coding - mlpack::lcc::LocalCoordinateCoding
- * - Locality-Sensitive Hashing - mlpack::neighbor::LSHSearch
- * - Naive Bayes Classifier - mlpack::naive_bayes::NaiveBayesClassifier
- * - Neighborhood Components Analysis (NCA) - mlpack::nca::NCA
- * - Principal Components Analysis (PCA) - mlpack::pca::PCA
- * - RADICAL (ICA) - mlpack::radical::Radical
- * - Simple Least-Squares Linear Regression -
- * mlpack::regression::LinearRegression
- * - Sparse Coding - mlpack::sparse_coding::SparseCoding
- * - Tree-based neighbor search (AllkNN, AllkFN) -
- * mlpack::neighbor::NeighborSearch
- * - Tree-based range search - mlpack::range::RangeSearch
- *
- * @section remarks Final Remarks
- *
- * This software was written in the FASTLab (http://www.fast-lab.org), which is
- * in the School of Computational Science and Engineering at the Georgia
- * Institute of Technology.
- *
- * MLPACK contributors include:
- *
- * - Ryan Curtin <gth671b at mail.gatech.edu>
- * - James Cline <james.cline at gatech.edu>
- * - Neil Slagle <nslagle3 at gatech.edu>
- * - Matthew Amidon <mamidon at gatech.edu>
- * - Vlad Grantcharov <vlad321 at gatech.edu>
- * - Ajinkya Kale <kaleajinkya at gmail.com>
- * - Bill March <march at gatech.edu>
- * - Dongryeol Lee <dongryel at cc.gatech.edu>
- * - Nishant Mehta <niche at cc.gatech.edu>
- * - Parikshit Ram <p.ram at gatech.edu>
- * - Rajendran Mohan <rmohan88 at gatech.edu>
- * - Trironk Kiatkungwanglai <trironk at gmail.com>
- * - Patrick Mason <patrick.s.mason at gmail.com>
- * - Chip Mappus <cmappus at gatech.edu>
- * - Hua Ouyang <houyang at gatech.edu>
- * - Long Quoc Tran <tqlong at gmail.com>
- * - Noah Kauffman <notoriousnoah at gmail.com>
- * - Guillermo Colon <gcolon7 at mail.gatech.edu>
- * - Wei Guan <wguan at cc.gatech.edu>
- * - Ryan Riegel <rriegel at cc.gatech.edu>
- * - Nikolaos Vasiloglou <nvasil at ieee.org>
- * - Garry Boyer <garryb at gmail.com>
- * - Andreas Löf <andreas.lof at cs.waikato.ac.nz>
- */
-
-// First, standard includes.
-#include <stdlib.h>
-#include <stdio.h>
-#include <string.h>
-#include <ctype.h>
-#include <limits.h>
-#include <float.h>
-#include <stdint.h>
-#include <iostream>
-
-// Defining _USE_MATH_DEFINES should set M_PI.
-#define _USE_MATH_DEFINES
-#include <math.h>
-
-// For tgamma().
-#include <boost/math/special_functions/gamma.hpp>
-
-// But if it's not defined, we'll do it.
-#ifndef M_PI
- #define M_PI 3.141592653589793238462643383279
-#endif
-
-// Clean up unfortunate Windows preprocessor definitions.
-// Use std::min and std::max!
-#ifdef _WIN32
- #ifdef min
- #undef min
- #endif
-
- #ifdef max
- #undef max
- #endif
-#endif
-
-// Give ourselves a nice way to force functions to be inline if we need.
-#define force_inline
-#if defined(__GNUG__) && !defined(DEBUG)
- #undef force_inline
- #define force_inline __attribute__((always_inline))
-#elif defined(_MSC_VER)
- #undef force_inline && !defined(DEBUG)
- #define force_inline __forceinline
-#endif
-
-// Now MLPACK-specific includes.
-#include <mlpack/core/arma_extend/arma_extend.hpp> // Includes Armadillo.
-#include <mlpack/core/util/log.hpp>
-#include <mlpack/core/util/cli.hpp>
-#include <mlpack/core/data/load.hpp>
-#include <mlpack/core/data/save.hpp>
-#include <mlpack/core/math/clamp.hpp>
-#include <mlpack/core/math/random.hpp>
-#include <mlpack/core/math/lin_alg.hpp>
-#include <mlpack/core/math/range.hpp>
-#include <mlpack/core/math/round.hpp>
-#include <mlpack/core/util/save_restore_utility.hpp>
-#include <mlpack/core/dists/discrete_distribution.hpp>
-#include <mlpack/core/dists/gaussian_distribution.hpp>
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/core.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/core.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/core.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/core.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,207 @@
+/***
+ * @file core.hpp
+ *
+ * Include all of the base components required to write MLPACK methods, and the
+ * main MLPACK Doxygen documentation.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_CORE_HPP
+#define __MLPACK_CORE_HPP
+
+/**
+ * @mainpage MLPACK Documentation
+ *
+ * @section intro_sec Introduction
+ *
+ * MLPACK is an intuitive, fast, scalable C++ machine learning library, meant to
+ * be a machine learning analog to LAPACK. It aims to implement a wide array of
+ * machine learning methods and function as a "swiss army knife" for machine
+ * learning researchers. The MLPACK development website can be found at
+ * http://mlpack.org.
+ *
+ * MLPACK uses the Armadillo C++ matrix library (http://arma.sourceforge.net)
+ * for general matrix, vector, and linear algebra support. MLPACK also uses the
+ * program_options, math_c99, and unit_test_framework components of the Boost
+ * library; in addition, LibXml2 is used.
+ *
+ * @section howto How To Use This Documentation
+ *
+ * This documentation is API documentation similar to Javadoc. It isn't
+ * necessarily a tutorial, but it does provide detailed documentation on every
+ * namespace, method, and class.
+ *
+ * Each MLPACK namespace generally refers to one machine learning method, so
+ * browsing the list of namespaces provides some insight as to the breadth of
+ * the methods contained in the library.
+ *
+ * To generate this documentation in your own local copy of MLPACK, you can
+ * simply use Doxygen, from the root directory (@c /mlpack/trunk/ ):
+ *
+ * @code
+ * $ doxygen
+ * @endcode
+ *
+ * @section executables Executables
+ *
+ * MLPACK provides several executables so that MLPACK methods can be used
+ * without any need for knowledge of C++. These executables are all
+ * self-documented, and that documentation can be accessed by running the
+ * executables with the '-h' or '--help' flag.
+ *
+ * A full list of executables is given below:
+ *
+ * allkfn, allknn, emst, gmm, hmm_train, hmm_loglik, hmm_viterbi, hmm_generate,
+ * kernel_pca, kmeans, lars, linear_regression, local_coordinate_coding, mvu,
+ * nbc, nca, pca, radical, sparse_coding
+ *
+ * @section tutorial Tutorials
+ *
+ * A few short tutorials on how to use MLPACK are given below.
+ *
+ * - @ref build
+ * - @ref matrices
+ * - @ref iodoc
+ * - @ref timer
+ * - @ref sample
+ *
+ * Tutorials on specific methods are also available.
+ *
+ * - @ref nstutorial
+ * - @ref lrtutorial
+ * - @ref rstutorial
+ * - @ref dettutorial
+ * - @ref emst_tutorial
+ * - @ref kmtutorial
+ * - @ref fmkstutorial
+ *
+ * @section methods Methods in MLPACK
+ *
+ * The following methods are included in MLPACK:
+ *
+ * - Density Estimation Trees - mlpack::det::DTree
+ * - Euclidean Minimum Spanning Trees - mlpack::emst::DualTreeBoruvka
+ * - Gaussian Mixture Models (GMMs) - mlpack::gmm::GMM
+ * - Hidden Markov Models (HMMs) - mlpack::hmm::HMM
+ * - Kernel PCA - mlpack::kpca::KernelPCA
+ * - K-Means Clustering - mlpack::kmeans::KMeans
+ * - Least-Angle Regression (LARS/LASSO) - mlpack::regression::LARS
+ * - Local Coordinate Coding - mlpack::lcc::LocalCoordinateCoding
+ * - Locality-Sensitive Hashing - mlpack::neighbor::LSHSearch
+ * - Naive Bayes Classifier - mlpack::naive_bayes::NaiveBayesClassifier
+ * - Neighborhood Components Analysis (NCA) - mlpack::nca::NCA
+ * - Principal Components Analysis (PCA) - mlpack::pca::PCA
+ * - RADICAL (ICA) - mlpack::radical::Radical
+ * - Simple Least-Squares Linear Regression -
+ * mlpack::regression::LinearRegression
+ * - Sparse Coding - mlpack::sparse_coding::SparseCoding
+ * - Tree-based neighbor search (AllkNN, AllkFN) -
+ * mlpack::neighbor::NeighborSearch
+ * - Tree-based range search - mlpack::range::RangeSearch
+ *
+ * @section remarks Final Remarks
+ *
+ * This software was written in the FASTLab (http://www.fast-lab.org), which is
+ * in the School of Computational Science and Engineering at the Georgia
+ * Institute of Technology.
+ *
+ * MLPACK contributors include:
+ *
+ * - Ryan Curtin <gth671b at mail.gatech.edu>
+ * - James Cline <james.cline at gatech.edu>
+ * - Neil Slagle <nslagle3 at gatech.edu>
+ * - Matthew Amidon <mamidon at gatech.edu>
+ * - Vlad Grantcharov <vlad321 at gatech.edu>
+ * - Ajinkya Kale <kaleajinkya at gmail.com>
+ * - Bill March <march at gatech.edu>
+ * - Dongryeol Lee <dongryel at cc.gatech.edu>
+ * - Nishant Mehta <niche at cc.gatech.edu>
+ * - Parikshit Ram <p.ram at gatech.edu>
+ * - Rajendran Mohan <rmohan88 at gatech.edu>
+ * - Trironk Kiatkungwanglai <trironk at gmail.com>
+ * - Patrick Mason <patrick.s.mason at gmail.com>
+ * - Chip Mappus <cmappus at gatech.edu>
+ * - Hua Ouyang <houyang at gatech.edu>
+ * - Long Quoc Tran <tqlong at gmail.com>
+ * - Noah Kauffman <notoriousnoah at gmail.com>
+ * - Guillermo Colon <gcolon7 at mail.gatech.edu>
+ * - Wei Guan <wguan at cc.gatech.edu>
+ * - Ryan Riegel <rriegel at cc.gatech.edu>
+ * - Nikolaos Vasiloglou <nvasil at ieee.org>
+ * - Garry Boyer <garryb at gmail.com>
+ * - Andreas Löf <andreas.lof at cs.waikato.ac.nz>
+ */
+
+// First, standard includes.
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+#include <ctype.h>
+#include <limits.h>
+#include <float.h>
+#include <stdint.h>
+#include <iostream>
+
+// Defining _USE_MATH_DEFINES should set M_PI.
+#define _USE_MATH_DEFINES
+#include <math.h>
+
+// For tgamma().
+#include <boost/math/special_functions/gamma.hpp>
+
+// But if it's not defined, we'll do it.
+#ifndef M_PI
+ #define M_PI 3.141592653589793238462643383279
+#endif
+
+// Clean up unfortunate Windows preprocessor definitions.
+// Use std::min and std::max!
+#ifdef _WIN32
+ #ifdef min
+ #undef min
+ #endif
+
+ #ifdef max
+ #undef max
+ #endif
+#endif
+
+// Give ourselves a nice way to force functions to be inline if we need.
+#define force_inline
+#if defined(__GNUG__) && !defined(DEBUG)
+ #undef force_inline
+ #define force_inline __attribute__((always_inline))
+#elif defined(_MSC_VER)
+ #undef force_inline && !defined(DEBUG)
+ #define force_inline __forceinline
+#endif
+
+// Now MLPACK-specific includes.
+#include <mlpack/core/arma_extend/arma_extend.hpp> // Includes Armadillo.
+#include <mlpack/core/util/log.hpp>
+#include <mlpack/core/util/cli.hpp>
+#include <mlpack/core/data/load.hpp>
+#include <mlpack/core/data/save.hpp>
+#include <mlpack/core/math/clamp.hpp>
+#include <mlpack/core/math/random.hpp>
+#include <mlpack/core/math/lin_alg.hpp>
+#include <mlpack/core/math/range.hpp>
+#include <mlpack/core/math/round.hpp>
+#include <mlpack/core/util/save_restore_utility.hpp>
+#include <mlpack/core/dists/discrete_distribution.hpp>
+#include <mlpack/core/dists/gaussian_distribution.hpp>
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/det_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/det/det_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/det_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,216 +0,0 @@
-/**
- * @file dt_main.cpp
- * @ Parikshit Ram (pram at cc.gatech.edu)
- *
- * This file provides an example use of the DET
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include <mlpack/core.hpp>
-#include "dt_utils.hpp"
-
-using namespace mlpack;
-using namespace mlpack::det;
-using namespace std;
-
-PROGRAM_INFO("Density Estimation With Density Estimation Trees",
- "This program performs a number of functions related to Density Estimation "
- "Trees. The optimal Density Estimation Tree (DET) can be trained on a set "
- "of data (specified by --train_file) using cross-validation (with number of"
- " folds specified by --folds). In addition, the density of a set of test "
- "points (specified by --test_file) can be estimated, and the importance of "
- "each dimension can be computed. If class labels are given for the "
- "training points (with --labels_file), the class memberships of each leaf "
- "in the DET can be calculated."
- "\n\n"
- "The created DET can be saved to a file, along with the density estimates "
- "for the test set and the variable importances.");
-
-// Input data files.
-PARAM_STRING_REQ("train_file", "The data set on which to build a density "
- "estimation tree.", "t");
-PARAM_STRING("test_file", "A set of test points to estimate the density of.",
- "T", "");
-PARAM_STRING("labels_file", "The labels for the given training data to "
- "generate the class membership of each leaf (as an extra statistic)", "l",
- "");
-
-// Output data files.
-PARAM_STRING("unpruned_tree_estimates_file", "The file in which to output the "
- "density estimates on the training set from the large unpruned tree.", "u",
- "");
-PARAM_STRING("training_set_estimates_file", "The file in which to output the "
- "density estimates on the training set from the final optimally pruned "
- "tree.", "e", "");
-PARAM_STRING("test_set_estimates_file", "The file in which to output the "
- "estimates on the test set from the final optimally pruned tree.", "E", "");
-PARAM_STRING("leaf_class_table_file", "The file in which to output the leaf "
- "class membership table.", "L", "leaf_class_membership.txt");
-PARAM_STRING("tree_file", "The file in which to print the final optimally "
- "pruned tree.", "r", "");
-PARAM_STRING("vi_file", "The file to output the variable importance values "
- "for each feature.", "i", "");
-
-// Parameters for the algorithm.
-PARAM_INT("folds", "The number of folds of cross-validation to perform for the "
- "estimation (0 is LOOCV)", "f", 10);
-PARAM_INT("min_leaf_size", "The minimum size of a leaf in the unpruned, fully "
- "grown DET.", "N", 5);
-PARAM_INT("max_leaf_size", "The maximum size of a leaf in the unpruned, fully "
- "grown DET.", "M", 10);
-/*
-PARAM_FLAG("volume_regularization", "This flag gives the used the option to use"
- "a form of regularization similar to the usual alpha-pruning in decision "
- "tree. But instead of regularizing on the number of leaves, you regularize "
- "on the sum of the inverse of the volume of the leaves (meaning you "
- "penalize low volume leaves.", "R");
-*/
-
-// Some flags for output of some information about the tree.
-PARAM_FLAG("print_tree", "Print the tree out on the command line (or in the "
- "file specified with --tree_file).", "p");
-PARAM_FLAG("print_vi", "Print the variable importance of each feature out on "
- "the command line (or in the file specified with --vi_file).", "I");
-
-int main(int argc, char *argv[])
-{
- CLI::ParseCommandLine(argc, argv);
-
- string trainSetFile = CLI::GetParam<string>("train_file");
- arma::Mat<double> trainingData;
-
- data::Load(trainSetFile, trainingData, true);
-
- // Cross-validation here.
- size_t folds = CLI::GetParam<int>("folds");
- if (folds == 0)
- {
- folds = trainingData.n_cols;
- Log::Info << "Performing leave-one-out cross validation." << endl;
- }
- else
- {
- Log::Info << "Performing " << folds << "-fold cross validation." << endl;
- }
-
- const string unprunedTreeEstimateFile =
- CLI::GetParam<string>("unpruned_tree_estimates_file");
- const bool regularization = false;
-// const bool regularization = CLI::HasParam("volume_regularization");
- const int maxLeafSize = CLI::GetParam<int>("max_leaf_size");
- const int minLeafSize = CLI::GetParam<int>("min_leaf_size");
-
- // Obtain the optimal tree.
- Timer::Start("det_training");
- DTree *dtreeOpt = Trainer(trainingData, folds, regularization, maxLeafSize,
- minLeafSize, unprunedTreeEstimateFile);
- Timer::Stop("det_training");
-
- // Compute densities for the training points in the optimal tree.
- FILE *fp = NULL;
-
- if (CLI::GetParam<string>("training_set_estimate_file") != "")
- {
- fp = fopen(CLI::GetParam<string>("training_set_estimate_file").c_str(),
- "w");
-
- // Compute density estimates for each point in the training set.
- Timer::Start("det_estimation_time");
- for (size_t i = 0; i < trainingData.n_cols; i++)
- fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(trainingData.unsafe_col(i)));
- Timer::Stop("det_estimation_time");
-
- fclose(fp);
- }
-
- // Compute the density at the provided test points and output the density in
- // the given file.
- const string testFile = CLI::GetParam<string>("test_file");
- if (testFile != "")
- {
- arma::mat testData;
- data::Load(testFile, testData, true);
-
- fp = NULL;
-
- if (CLI::GetParam<string>("test_set_estimates_file") != "")
- {
- fp = fopen(CLI::GetParam<string>("test_set_estimates_file").c_str(), "w");
-
- Timer::Start("det_test_set_estimation");
- for (size_t i = 0; i < testData.n_cols; i++)
- fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(testData.unsafe_col(i)));
- Timer::Stop("det_test_set_estimation");
-
- fclose(fp);
- }
- }
-
- // Print the final tree.
- if (CLI::HasParam("print_tree"))
- {
- fp = NULL;
- if (CLI::GetParam<string>("tree_file") != "")
- {
- fp = fopen(CLI::GetParam<string>("tree_file").c_str(), "w");
-
- if (fp != NULL)
- {
- dtreeOpt->WriteTree(fp);
- fclose(fp);
- }
- }
- else
- {
- dtreeOpt->WriteTree(stdout);
- printf("\n");
- }
- }
-
- // Print the leaf memberships for the optimal tree.
- if (CLI::GetParam<string>("labels_file") != "")
- {
- std::string labelsFile = CLI::GetParam<string>("labels_file");
- arma::Mat<size_t> labels;
-
- data::Load(labelsFile, labels, true);
-
- size_t numClasses = 0;
- for (size_t i = 0; i < labels.n_elem; ++i)
- {
- if (labels[i] > numClasses)
- numClasses = labels[i];
- }
-
- Log::Info << numClasses << " found in labels file '" << labelsFile << "'."
- << std::endl;
-
- Log::Assert(trainingData.n_cols == labels.n_cols);
- Log::Assert(labels.n_rows == 1);
-
- PrintLeafMembership(dtreeOpt, trainingData, labels, numClasses,
- CLI::GetParam<string>("leaf_class_table_file"));
- }
-
- // Print variable importance.
- if (CLI::HasParam("print_vi"))
- {
- PrintVariableImportance(dtreeOpt, CLI::GetParam<string>("vi_file"));
- }
-
- delete dtreeOpt;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/det_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/det/det_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/det_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/det_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,216 @@
+/**
+ * @file dt_main.cpp
+ * @ Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * This file provides an example use of the DET
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <mlpack/core.hpp>
+#include "dt_utils.hpp"
+
+using namespace mlpack;
+using namespace mlpack::det;
+using namespace std;
+
+PROGRAM_INFO("Density Estimation With Density Estimation Trees",
+ "This program performs a number of functions related to Density Estimation "
+ "Trees. The optimal Density Estimation Tree (DET) can be trained on a set "
+ "of data (specified by --train_file) using cross-validation (with number of"
+ " folds specified by --folds). In addition, the density of a set of test "
+ "points (specified by --test_file) can be estimated, and the importance of "
+ "each dimension can be computed. If class labels are given for the "
+ "training points (with --labels_file), the class memberships of each leaf "
+ "in the DET can be calculated."
+ "\n\n"
+ "The created DET can be saved to a file, along with the density estimates "
+ "for the test set and the variable importances.");
+
+// Input data files.
+PARAM_STRING_REQ("train_file", "The data set on which to build a density "
+ "estimation tree.", "t");
+PARAM_STRING("test_file", "A set of test points to estimate the density of.",
+ "T", "");
+PARAM_STRING("labels_file", "The labels for the given training data to "
+ "generate the class membership of each leaf (as an extra statistic)", "l",
+ "");
+
+// Output data files.
+PARAM_STRING("unpruned_tree_estimates_file", "The file in which to output the "
+ "density estimates on the training set from the large unpruned tree.", "u",
+ "");
+PARAM_STRING("training_set_estimates_file", "The file in which to output the "
+ "density estimates on the training set from the final optimally pruned "
+ "tree.", "e", "");
+PARAM_STRING("test_set_estimates_file", "The file in which to output the "
+ "estimates on the test set from the final optimally pruned tree.", "E", "");
+PARAM_STRING("leaf_class_table_file", "The file in which to output the leaf "
+ "class membership table.", "L", "leaf_class_membership.txt");
+PARAM_STRING("tree_file", "The file in which to print the final optimally "
+ "pruned tree.", "r", "");
+PARAM_STRING("vi_file", "The file to output the variable importance values "
+ "for each feature.", "i", "");
+
+// Parameters for the algorithm.
+PARAM_INT("folds", "The number of folds of cross-validation to perform for the "
+ "estimation (0 is LOOCV)", "f", 10);
+PARAM_INT("min_leaf_size", "The minimum size of a leaf in the unpruned, fully "
+ "grown DET.", "N", 5);
+PARAM_INT("max_leaf_size", "The maximum size of a leaf in the unpruned, fully "
+ "grown DET.", "M", 10);
+/*
+PARAM_FLAG("volume_regularization", "This flag gives the used the option to use"
+ "a form of regularization similar to the usual alpha-pruning in decision "
+ "tree. But instead of regularizing on the number of leaves, you regularize "
+ "on the sum of the inverse of the volume of the leaves (meaning you "
+ "penalize low volume leaves.", "R");
+*/
+
+// Some flags for output of some information about the tree.
+PARAM_FLAG("print_tree", "Print the tree out on the command line (or in the "
+ "file specified with --tree_file).", "p");
+PARAM_FLAG("print_vi", "Print the variable importance of each feature out on "
+ "the command line (or in the file specified with --vi_file).", "I");
+
+int main(int argc, char *argv[])
+{
+ CLI::ParseCommandLine(argc, argv);
+
+ string trainSetFile = CLI::GetParam<string>("train_file");
+ arma::Mat<double> trainingData;
+
+ data::Load(trainSetFile, trainingData, true);
+
+ // Cross-validation here.
+ size_t folds = CLI::GetParam<int>("folds");
+ if (folds == 0)
+ {
+ folds = trainingData.n_cols;
+ Log::Info << "Performing leave-one-out cross validation." << endl;
+ }
+ else
+ {
+ Log::Info << "Performing " << folds << "-fold cross validation." << endl;
+ }
+
+ const string unprunedTreeEstimateFile =
+ CLI::GetParam<string>("unpruned_tree_estimates_file");
+ const bool regularization = false;
+// const bool regularization = CLI::HasParam("volume_regularization");
+ const int maxLeafSize = CLI::GetParam<int>("max_leaf_size");
+ const int minLeafSize = CLI::GetParam<int>("min_leaf_size");
+
+ // Obtain the optimal tree.
+ Timer::Start("det_training");
+ DTree *dtreeOpt = Trainer(trainingData, folds, regularization, maxLeafSize,
+ minLeafSize, unprunedTreeEstimateFile);
+ Timer::Stop("det_training");
+
+ // Compute densities for the training points in the optimal tree.
+ FILE *fp = NULL;
+
+ if (CLI::GetParam<string>("training_set_estimate_file") != "")
+ {
+ fp = fopen(CLI::GetParam<string>("training_set_estimate_file").c_str(),
+ "w");
+
+ // Compute density estimates for each point in the training set.
+ Timer::Start("det_estimation_time");
+ for (size_t i = 0; i < trainingData.n_cols; i++)
+ fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(trainingData.unsafe_col(i)));
+ Timer::Stop("det_estimation_time");
+
+ fclose(fp);
+ }
+
+ // Compute the density at the provided test points and output the density in
+ // the given file.
+ const string testFile = CLI::GetParam<string>("test_file");
+ if (testFile != "")
+ {
+ arma::mat testData;
+ data::Load(testFile, testData, true);
+
+ fp = NULL;
+
+ if (CLI::GetParam<string>("test_set_estimates_file") != "")
+ {
+ fp = fopen(CLI::GetParam<string>("test_set_estimates_file").c_str(), "w");
+
+ Timer::Start("det_test_set_estimation");
+ for (size_t i = 0; i < testData.n_cols; i++)
+ fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(testData.unsafe_col(i)));
+ Timer::Stop("det_test_set_estimation");
+
+ fclose(fp);
+ }
+ }
+
+ // Print the final tree.
+ if (CLI::HasParam("print_tree"))
+ {
+ fp = NULL;
+ if (CLI::GetParam<string>("tree_file") != "")
+ {
+ fp = fopen(CLI::GetParam<string>("tree_file").c_str(), "w");
+
+ if (fp != NULL)
+ {
+ dtreeOpt->WriteTree(fp);
+ fclose(fp);
+ }
+ }
+ else
+ {
+ dtreeOpt->WriteTree(stdout);
+ printf("\n");
+ }
+ }
+
+ // Print the leaf memberships for the optimal tree.
+ if (CLI::GetParam<string>("labels_file") != "")
+ {
+ std::string labelsFile = CLI::GetParam<string>("labels_file");
+ arma::Mat<size_t> labels;
+
+ data::Load(labelsFile, labels, true);
+
+ size_t numClasses = 0;
+ for (size_t i = 0; i < labels.n_elem; ++i)
+ {
+ if (labels[i] > numClasses)
+ numClasses = labels[i];
+ }
+
+ Log::Info << numClasses << " found in labels file '" << labelsFile << "'."
+ << std::endl;
+
+ Log::Assert(trainingData.n_cols == labels.n_cols);
+ Log::Assert(labels.n_rows == 1);
+
+ PrintLeafMembership(dtreeOpt, trainingData, labels, numClasses,
+ CLI::GetParam<string>("leaf_class_table_file"));
+ }
+
+ // Print variable importance.
+ if (CLI::HasParam("print_vi"))
+ {
+ PrintVariableImportance(dtreeOpt, CLI::GetParam<string>("vi_file"));
+ }
+
+ delete dtreeOpt;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/det/dt_utils.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,317 +0,0 @@
-/**
- * @file dt_utils.cpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * This file implements functions to perform different tasks with the Density
- * Tree class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "dt_utils.hpp"
-
-using namespace mlpack;
-using namespace det;
-
-void mlpack::det::PrintLeafMembership(DTree* dtree,
- const arma::mat& data,
- const arma::Mat<size_t>& labels,
- const size_t numClasses,
- const std::string leafClassMembershipFile)
-{
- // Tag the leaves with numbers.
- int numLeaves = dtree->TagTree();
-
- arma::Mat<size_t> table(numLeaves, numClasses);
- table.zeros();
-
- for (size_t i = 0; i < data.n_cols; i++)
- {
- const arma::vec testPoint = data.unsafe_col(i);
- const int leafTag = dtree->FindBucket(testPoint);
- const size_t label = labels[i];
- table(leafTag, label) += 1;
- }
-
- if (leafClassMembershipFile == "")
- {
- Log::Info << "Leaf membership; row represents leaf id, column represents "
- << "class id; value represents number of points in leaf in class."
- << std::endl << table;
- }
- else
- {
- // Create a stream for the file.
- std::ofstream outfile(leafClassMembershipFile.c_str());
- if (outfile.good())
- {
- outfile << table;
- Log::Info << "Leaf membership printed to '" << leafClassMembershipFile
- << "'." << std::endl;
- }
- else
- {
- Log::Warn << "Can't open '" << leafClassMembershipFile << "' to write "
- << "leaf membership to." << std::endl;
- }
- outfile.close();
- }
-
- return;
-}
-
-
-void mlpack::det::PrintVariableImportance(const DTree* dtree,
- const std::string viFile)
-{
- arma::vec imps;
- dtree->ComputeVariableImportance(imps);
-
- double max = 0.0;
- for (size_t i = 0; i < imps.n_elem; ++i)
- if (imps[i] > max)
- max = imps[i];
-
- Log::Info << "Maximum variable importance: " << max << "." << std::endl;
-
- if (viFile == "")
- {
- Log::Info << "Variable importance: " << std::endl << imps.t() << std::endl;
- }
- else
- {
- std::ofstream outfile(viFile.c_str());
- if (outfile.good())
- {
- outfile << imps;
- Log::Info << "Variable importance printed to '" << viFile << "'."
- << std::endl;
- }
- else
- {
- Log::Warn << "Can't open '" << viFile << "' to write variable importance "
- << "to." << std::endl;
- }
- outfile.close();
- }
-}
-
-
-// This function trains the optimal decision tree using the given number of
-// folds.
-DTree* mlpack::det::Trainer(arma::mat& dataset,
- const size_t folds,
- const bool useVolumeReg,
- const size_t maxLeafSize,
- const size_t minLeafSize,
- const std::string unprunedTreeOutput)
-{
- // Initialize the tree.
- DTree* dtree = new DTree(dataset);
-
- // Prepare to grow the tree...
- arma::Col<size_t> oldFromNew(dataset.n_cols);
- for (size_t i = 0; i < oldFromNew.n_elem; i++)
- oldFromNew[i] = i;
-
- // Save the dataset since it would be modified while growing the tree.
- arma::mat newDataset(dataset);
-
- // Growing the tree
- double oldAlpha = 0.0;
- double alpha = dtree->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
- minLeafSize);
-
- Log::Info << dtree->SubtreeLeaves() << " leaf nodes in the tree using full "
- << "dataset; minimum alpha: " << alpha << "." << std::endl;
-
- // Compute densities for the training points in the full tree, if we were
- // asked for this.
- if (unprunedTreeOutput != "")
- {
- std::ofstream outfile(unprunedTreeOutput.c_str());
- if (outfile.good())
- {
- for (size_t i = 0; i < dataset.n_cols; ++i)
- {
- arma::vec testPoint = dataset.unsafe_col(i);
- outfile << dtree->ComputeValue(testPoint) << std::endl;
- }
- }
- else
- {
- Log::Warn << "Can't open '" << unprunedTreeOutput << "' to write computed"
- << " densities to." << std::endl;
- }
-
- outfile.close();
- }
-
- // Sequentially prune and save the alpha values and the values of c_t^2 * r_t.
- std::vector<std::pair<double, double> > prunedSequence;
- while (dtree->SubtreeLeaves() > 1)
- {
- std::pair<double, double> treeSeq(oldAlpha,
- dtree->SubtreeLeavesLogNegError());
- prunedSequence.push_back(treeSeq);
- oldAlpha = alpha;
- alpha = dtree->PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);
-
- // Some sanity checks.
- Log::Assert((alpha < std::numeric_limits<double>::max()) ||
- (dtree->SubtreeLeaves() == 1));
- Log::Assert(alpha > oldAlpha);
- Log::Assert(dtree->SubtreeLeavesLogNegError() < treeSeq.second);
- }
-
- std::pair<double, double> treeSeq(oldAlpha,
- dtree->SubtreeLeavesLogNegError());
- prunedSequence.push_back(treeSeq);
-
- Log::Info << prunedSequence.size() << " trees in the sequence; maximum alpha:"
- << " " << oldAlpha << "." << std::endl;
-
- delete dtree;
-
- arma::mat cvData(dataset);
- size_t testSize = dataset.n_cols / folds;
-
- std::vector<double> regularizationConstants;
- regularizationConstants.resize(prunedSequence.size(), 0);
-
- // Go through each fold.
- for (size_t fold = 0; fold < folds; fold++)
- {
- // Break up data into train and test sets.
- size_t start = fold * testSize;
- size_t end = std::min((fold + 1) * testSize, (size_t) cvData.n_cols);
-
- arma::mat test = cvData.cols(start, end - 1);
- arma::mat train(cvData.n_rows, cvData.n_cols - test.n_cols);
-
- if (start == 0 && end < cvData.n_cols)
- {
- train.cols(0, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
- }
- else if (start > 0 && end == cvData.n_cols)
- {
- train.cols(0, train.n_cols - 1) = cvData.cols(0, start - 1);
- }
- else
- {
- train.cols(0, start - 1) = cvData.cols(0, start - 1);
- train.cols(start, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
- }
-
- // Initialize the tree.
- DTree* cvDTree = new DTree(train);
-
- // Getting ready to grow the tree...
- arma::Col<size_t> cvOldFromNew(train.n_cols);
- for (size_t i = 0; i < cvOldFromNew.n_elem; i++)
- cvOldFromNew[i] = i;
-
- // Grow the tree.
- oldAlpha = 0.0;
- alpha = cvDTree->Grow(train, cvOldFromNew, useVolumeReg, maxLeafSize,
- minLeafSize);
-
- // Sequentially prune with all the values of available alphas and adding
- // values for test values.
- for (size_t i = 0; i < prunedSequence.size() - 2; ++i)
- {
- // Compute test values for this state of the tree.
- double cvVal = 0.0;
- for (size_t j = 0; j < test.n_cols; j++)
- {
- arma::vec testPoint = test.unsafe_col(j);
- cvVal += cvDTree->ComputeValue(testPoint);
- }
-
- // Update the cv regularization constant.
- regularizationConstants[i] += 2.0 * cvVal / (double) dataset.n_cols;
-
- // Determine the new alpha value and prune accordingly.
- oldAlpha = 0.5 * (prunedSequence[i + 1].first +
- prunedSequence[i + 2].first);
- alpha = cvDTree->PruneAndUpdate(oldAlpha, train.n_cols, useVolumeReg);
- }
-
- // Compute test values for this state of the tree.
- double cvVal = 0.0;
- for (size_t i = 0; i < test.n_cols; ++i)
- {
- arma::vec testPoint = test.unsafe_col(i);
- cvVal += cvDTree->ComputeValue(testPoint);
- }
-
- regularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal /
- (double) dataset.n_cols;
-
- test.reset();
- delete cvDTree;
- }
-
- double optimalAlpha = -1.0;
- long double cvBestError = -std::numeric_limits<long double>::max();
-
- for (size_t i = 0; i < prunedSequence.size() - 1; ++i)
- {
- // We can no longer work in the log-space for this because we have no
- // guarantee the quantity will be positive.
- long double thisError = -std::exp((long double) prunedSequence[i].second) +
- (long double) regularizationConstants[i];
-
- if (thisError > cvBestError)
- {
- cvBestError = thisError;
- optimalAlpha = prunedSequence[i].first;
- }
- }
-
- Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;
-
- // Initialize the tree.
- DTree* dtreeOpt = new DTree(dataset);
-
- // Getting ready to grow the tree...
- for (size_t i = 0; i < oldFromNew.n_elem; i++)
- oldFromNew[i] = i;
-
- // Save the dataset since it would be modified while growing the tree.
- newDataset = dataset;
-
- // Grow the tree.
- oldAlpha = -DBL_MAX;
- alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
- minLeafSize);
-
- // Prune with optimal alpha.
- while ((oldAlpha < optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1))
- {
- oldAlpha = alpha;
- alpha = dtreeOpt->PruneAndUpdate(oldAlpha, newDataset.n_cols, useVolumeReg);
-
- // Some sanity checks.
- Log::Assert((alpha < std::numeric_limits<double>::max()) ||
- (dtreeOpt->SubtreeLeaves() == 1));
- Log::Assert(alpha > oldAlpha);
- }
-
- Log::Info << dtreeOpt->SubtreeLeaves() << " leaf nodes in the optimally "
- << "pruned tree; optimal alpha: " << oldAlpha << "." << std::endl;
-
- return dtreeOpt;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/det/dt_utils.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,317 @@
+/**
+ * @file dt_utils.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * This file implements functions to perform different tasks with the Density
+ * Tree class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "dt_utils.hpp"
+
+using namespace mlpack;
+using namespace det;
+
+void mlpack::det::PrintLeafMembership(DTree* dtree,
+ const arma::mat& data,
+ const arma::Mat<size_t>& labels,
+ const size_t numClasses,
+ const std::string leafClassMembershipFile)
+{
+ // Tag the leaves with numbers.
+ int numLeaves = dtree->TagTree();
+
+ arma::Mat<size_t> table(numLeaves, numClasses);
+ table.zeros();
+
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ const arma::vec testPoint = data.unsafe_col(i);
+ const int leafTag = dtree->FindBucket(testPoint);
+ const size_t label = labels[i];
+ table(leafTag, label) += 1;
+ }
+
+ if (leafClassMembershipFile == "")
+ {
+ Log::Info << "Leaf membership; row represents leaf id, column represents "
+ << "class id; value represents number of points in leaf in class."
+ << std::endl << table;
+ }
+ else
+ {
+ // Create a stream for the file.
+ std::ofstream outfile(leafClassMembershipFile.c_str());
+ if (outfile.good())
+ {
+ outfile << table;
+ Log::Info << "Leaf membership printed to '" << leafClassMembershipFile
+ << "'." << std::endl;
+ }
+ else
+ {
+ Log::Warn << "Can't open '" << leafClassMembershipFile << "' to write "
+ << "leaf membership to." << std::endl;
+ }
+ outfile.close();
+ }
+
+ return;
+}
+
+
+void mlpack::det::PrintVariableImportance(const DTree* dtree,
+ const std::string viFile)
+{
+ arma::vec imps;
+ dtree->ComputeVariableImportance(imps);
+
+ double max = 0.0;
+ for (size_t i = 0; i < imps.n_elem; ++i)
+ if (imps[i] > max)
+ max = imps[i];
+
+ Log::Info << "Maximum variable importance: " << max << "." << std::endl;
+
+ if (viFile == "")
+ {
+ Log::Info << "Variable importance: " << std::endl << imps.t() << std::endl;
+ }
+ else
+ {
+ std::ofstream outfile(viFile.c_str());
+ if (outfile.good())
+ {
+ outfile << imps;
+ Log::Info << "Variable importance printed to '" << viFile << "'."
+ << std::endl;
+ }
+ else
+ {
+ Log::Warn << "Can't open '" << viFile << "' to write variable importance "
+ << "to." << std::endl;
+ }
+ outfile.close();
+ }
+}
+
+
+// This function trains the optimal decision tree using the given number of
+// folds.
+DTree* mlpack::det::Trainer(arma::mat& dataset,
+ const size_t folds,
+ const bool useVolumeReg,
+ const size_t maxLeafSize,
+ const size_t minLeafSize,
+ const std::string unprunedTreeOutput)
+{
+ // Initialize the tree.
+ DTree* dtree = new DTree(dataset);
+
+ // Prepare to grow the tree...
+ arma::Col<size_t> oldFromNew(dataset.n_cols);
+ for (size_t i = 0; i < oldFromNew.n_elem; i++)
+ oldFromNew[i] = i;
+
+ // Save the dataset since it would be modified while growing the tree.
+ arma::mat newDataset(dataset);
+
+ // Growing the tree
+ double oldAlpha = 0.0;
+ double alpha = dtree->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
+ minLeafSize);
+
+ Log::Info << dtree->SubtreeLeaves() << " leaf nodes in the tree using full "
+ << "dataset; minimum alpha: " << alpha << "." << std::endl;
+
+ // Compute densities for the training points in the full tree, if we were
+ // asked for this.
+ if (unprunedTreeOutput != "")
+ {
+ std::ofstream outfile(unprunedTreeOutput.c_str());
+ if (outfile.good())
+ {
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ {
+ arma::vec testPoint = dataset.unsafe_col(i);
+ outfile << dtree->ComputeValue(testPoint) << std::endl;
+ }
+ }
+ else
+ {
+ Log::Warn << "Can't open '" << unprunedTreeOutput << "' to write computed"
+ << " densities to." << std::endl;
+ }
+
+ outfile.close();
+ }
+
+ // Sequentially prune and save the alpha values and the values of c_t^2 * r_t.
+ std::vector<std::pair<double, double> > prunedSequence;
+ while (dtree->SubtreeLeaves() > 1)
+ {
+ std::pair<double, double> treeSeq(oldAlpha,
+ dtree->SubtreeLeavesLogNegError());
+ prunedSequence.push_back(treeSeq);
+ oldAlpha = alpha;
+ alpha = dtree->PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);
+
+ // Some sanity checks.
+ Log::Assert((alpha < std::numeric_limits<double>::max()) ||
+ (dtree->SubtreeLeaves() == 1));
+ Log::Assert(alpha > oldAlpha);
+ Log::Assert(dtree->SubtreeLeavesLogNegError() < treeSeq.second);
+ }
+
+ std::pair<double, double> treeSeq(oldAlpha,
+ dtree->SubtreeLeavesLogNegError());
+ prunedSequence.push_back(treeSeq);
+
+ Log::Info << prunedSequence.size() << " trees in the sequence; maximum alpha:"
+ << " " << oldAlpha << "." << std::endl;
+
+ delete dtree;
+
+ arma::mat cvData(dataset);
+ size_t testSize = dataset.n_cols / folds;
+
+ std::vector<double> regularizationConstants;
+ regularizationConstants.resize(prunedSequence.size(), 0);
+
+ // Go through each fold.
+ for (size_t fold = 0; fold < folds; fold++)
+ {
+ // Break up data into train and test sets.
+ size_t start = fold * testSize;
+ size_t end = std::min((fold + 1) * testSize, (size_t) cvData.n_cols);
+
+ arma::mat test = cvData.cols(start, end - 1);
+ arma::mat train(cvData.n_rows, cvData.n_cols - test.n_cols);
+
+ if (start == 0 && end < cvData.n_cols)
+ {
+ train.cols(0, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
+ }
+ else if (start > 0 && end == cvData.n_cols)
+ {
+ train.cols(0, train.n_cols - 1) = cvData.cols(0, start - 1);
+ }
+ else
+ {
+ train.cols(0, start - 1) = cvData.cols(0, start - 1);
+ train.cols(start, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
+ }
+
+ // Initialize the tree.
+ DTree* cvDTree = new DTree(train);
+
+ // Getting ready to grow the tree...
+ arma::Col<size_t> cvOldFromNew(train.n_cols);
+ for (size_t i = 0; i < cvOldFromNew.n_elem; i++)
+ cvOldFromNew[i] = i;
+
+ // Grow the tree.
+ oldAlpha = 0.0;
+ alpha = cvDTree->Grow(train, cvOldFromNew, useVolumeReg, maxLeafSize,
+ minLeafSize);
+
+ // Sequentially prune with all the values of available alphas and adding
+ // values for test values.
+ for (size_t i = 0; i < prunedSequence.size() - 2; ++i)
+ {
+ // Compute test values for this state of the tree.
+ double cvVal = 0.0;
+ for (size_t j = 0; j < test.n_cols; j++)
+ {
+ arma::vec testPoint = test.unsafe_col(j);
+ cvVal += cvDTree->ComputeValue(testPoint);
+ }
+
+ // Update the cv regularization constant.
+ regularizationConstants[i] += 2.0 * cvVal / (double) dataset.n_cols;
+
+ // Determine the new alpha value and prune accordingly.
+ oldAlpha = 0.5 * (prunedSequence[i + 1].first +
+ prunedSequence[i + 2].first);
+ alpha = cvDTree->PruneAndUpdate(oldAlpha, train.n_cols, useVolumeReg);
+ }
+
+ // Compute test values for this state of the tree.
+ double cvVal = 0.0;
+ for (size_t i = 0; i < test.n_cols; ++i)
+ {
+ arma::vec testPoint = test.unsafe_col(i);
+ cvVal += cvDTree->ComputeValue(testPoint);
+ }
+
+ regularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal /
+ (double) dataset.n_cols;
+
+ test.reset();
+ delete cvDTree;
+ }
+
+ double optimalAlpha = -1.0;
+ long double cvBestError = -std::numeric_limits<long double>::max();
+
+ for (size_t i = 0; i < prunedSequence.size() - 1; ++i)
+ {
+ // We can no longer work in the log-space for this because we have no
+ // guarantee the quantity will be positive.
+ long double thisError = -std::exp((long double) prunedSequence[i].second) +
+ (long double) regularizationConstants[i];
+
+ if (thisError > cvBestError)
+ {
+ cvBestError = thisError;
+ optimalAlpha = prunedSequence[i].first;
+ }
+ }
+
+ Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;
+
+ // Initialize the tree.
+ DTree* dtreeOpt = new DTree(dataset);
+
+ // Getting ready to grow the tree...
+ for (size_t i = 0; i < oldFromNew.n_elem; i++)
+ oldFromNew[i] = i;
+
+ // Save the dataset since it would be modified while growing the tree.
+ newDataset = dataset;
+
+ // Grow the tree.
+ oldAlpha = -DBL_MAX;
+ alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
+ minLeafSize);
+
+ // Prune with optimal alpha.
+ while ((oldAlpha < optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1))
+ {
+ oldAlpha = alpha;
+ alpha = dtreeOpt->PruneAndUpdate(oldAlpha, newDataset.n_cols, useVolumeReg);
+
+ // Some sanity checks.
+ Log::Assert((alpha < std::numeric_limits<double>::max()) ||
+ (dtreeOpt->SubtreeLeaves() == 1));
+ Log::Assert(alpha > oldAlpha);
+ }
+
+ Log::Info << dtreeOpt->SubtreeLeaves() << " leaf nodes in the optimally "
+ << "pruned tree; optimal alpha: " << oldAlpha << "." << std::endl;
+
+ return dtreeOpt;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/det/dt_utils.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,84 +0,0 @@
-/**
- * @file dt_utils.hpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * This file implements functions to perform different tasks with the Density
- * Tree class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_DET_DT_UTILS_HPP
-#define __MLPACK_METHODS_DET_DT_UTILS_HPP
-
-#include <string>
-
-#include <mlpack/core.hpp>
-#include "dtree.hpp"
-
-namespace mlpack {
-namespace det {
-
-/**
- * Print the membership of leaves of a density estimation tree given the labels
- * and number of classes. Optionally, pass the name of a file to print this
- * information to (otherwise stdout is used).
- *
- * @param dtree Tree to print membership of.
- * @param data Dataset tree is built upon.
- * @param labels Class labels of dataset.
- * @param numClasses Number of classes in dataset.
- * @param leafClassMembershipFile Name of file to print to (optional).
- */
-void PrintLeafMembership(DTree* dtree,
- const arma::mat& data,
- const arma::Mat<size_t>& labels,
- const size_t numClasses,
- const std::string leafClassMembershipFile = "");
-
-/**
- * Print the variable importance of each dimension of a density estimation tree.
- * Optionally, pass the name of a file to print this information to (otherwise
- * stdout is used).
- *
- * @param dtree Density tree to use.
- * @param viFile Name of file to print to (optional).
- */
-void PrintVariableImportance(const DTree* dtree,
- const std::string viFile = "");
-
-/**
- * Train the optimal decision tree using cross-validation with the given number
- * of folds. Optionally, give a filename to print the unpruned tree to. This
- * initializes a tree on the heap, so you are responsible for deleting it.
- *
- * @param dataset Dataset for the tree to use.
- * @param folds Number of folds to use for cross-validation.
- * @param useVolumeReg If true, use volume regularization.
- * @param maxLeafSize Maximum number of points allowed in a leaf.
- * @param minLeafSize Minimum number of points allowed in a leaf.
- * @param unprunedTreeOutput Filename to print unpruned tree to (optional).
- */
-DTree* Trainer(arma::mat& dataset,
- const size_t folds,
- const bool useVolumeReg = false,
- const size_t maxLeafSize = 10,
- const size_t minLeafSize = 5,
- const std::string unprunedTreeOutput = "");
-
-}; // namespace det
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_DET_DT_UTILS_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/det/dt_utils.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dt_utils.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,84 @@
+/**
+ * @file dt_utils.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * This file implements functions to perform different tasks with the Density
+ * Tree class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_DET_DT_UTILS_HPP
+#define __MLPACK_METHODS_DET_DT_UTILS_HPP
+
+#include <string>
+
+#include <mlpack/core.hpp>
+#include "dtree.hpp"
+
+namespace mlpack {
+namespace det {
+
+/**
+ * Print the membership of leaves of a density estimation tree given the labels
+ * and number of classes. Optionally, pass the name of a file to print this
+ * information to (otherwise stdout is used).
+ *
+ * @param dtree Tree to print membership of.
+ * @param data Dataset tree is built upon.
+ * @param labels Class labels of dataset.
+ * @param numClasses Number of classes in dataset.
+ * @param leafClassMembershipFile Name of file to print to (optional).
+ */
+void PrintLeafMembership(DTree* dtree,
+ const arma::mat& data,
+ const arma::Mat<size_t>& labels,
+ const size_t numClasses,
+ const std::string leafClassMembershipFile = "");
+
+/**
+ * Print the variable importance of each dimension of a density estimation tree.
+ * Optionally, pass the name of a file to print this information to (otherwise
+ * stdout is used).
+ *
+ * @param dtree Density tree to use.
+ * @param viFile Name of file to print to (optional).
+ */
+void PrintVariableImportance(const DTree* dtree,
+ const std::string viFile = "");
+
+/**
+ * Train the optimal decision tree using cross-validation with the given number
+ * of folds. Optionally, give a filename to print the unpruned tree to. This
+ * initializes a tree on the heap, so you are responsible for deleting it.
+ *
+ * @param dataset Dataset for the tree to use.
+ * @param folds Number of folds to use for cross-validation.
+ * @param useVolumeReg If true, use volume regularization.
+ * @param maxLeafSize Maximum number of points allowed in a leaf.
+ * @param minLeafSize Minimum number of points allowed in a leaf.
+ * @param unprunedTreeOutput Filename to print unpruned tree to (optional).
+ */
+DTree* Trainer(arma::mat& dataset,
+ const size_t folds,
+ const bool useVolumeReg = false,
+ const size_t maxLeafSize = 10,
+ const size_t minLeafSize = 5,
+ const std::string unprunedTreeOutput = "");
+
+}; // namespace det
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_DET_DT_UTILS_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/det/dtree.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,682 +0,0 @@
- /**
- * @file dtree.cpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * Implementations of some declared functions in
- * the Density Estimation Tree class.
- *
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "dtree.hpp"
-#include <stack>
-
-using namespace mlpack;
-using namespace det;
-
-DTree::DTree() :
- start(0),
- end(0),
- logNegError(-DBL_MAX),
- root(true),
- bucketTag(-1),
- left(NULL),
- right(NULL)
-{ /* Nothing to do. */ }
-
-
-// Root node initializers
-DTree::DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints) :
- start(0),
- end(totalPoints),
- maxVals(maxVals),
- minVals(minVals),
- logNegError(LogNegativeError(totalPoints)),
- root(true),
- bucketTag(-1),
- left(NULL),
- right(NULL)
-{ /* Nothing to do. */ }
-
-DTree::DTree(arma::mat& data) :
- start(0),
- end(data.n_cols),
- left(NULL),
- right(NULL)
-{
- maxVals.set_size(data.n_rows);
- minVals.set_size(data.n_rows);
-
- // Initialize to first column; values will be overwritten if necessary.
- maxVals = data.col(0);
- minVals = data.col(0);
-
- // Loop over data to extract maximum and minimum values in each dimension.
- for (size_t i = 1; i < data.n_cols; ++i)
- {
- for (size_t j = 0; j < data.n_rows; ++j)
- {
- if (data(j, i) > maxVals[j])
- maxVals[j] = data(j, i);
- if (data(j, i) < minVals[j])
- minVals[j] = data(j, i);
- }
- }
-
- logNegError = LogNegativeError(data.n_cols);
-
- bucketTag = -1;
- root = true;
-}
-
-
-// Non-root node initializers
-DTree::DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t start,
- const size_t end,
- const double logNegError) :
- start(start),
- end(end),
- maxVals(maxVals),
- minVals(minVals),
- logNegError(logNegError),
- root(false),
- bucketTag(-1),
- left(NULL),
- right(NULL)
-{ /* Nothing to do. */ }
-
-DTree::DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints,
- const size_t start,
- const size_t end) :
- start(start),
- end(end),
- maxVals(maxVals),
- minVals(minVals),
- logNegError(LogNegativeError(totalPoints)),
- root(false),
- bucketTag(-1),
- left(NULL),
- right(NULL)
-{ /* Nothing to do. */ }
-
-DTree::~DTree()
-{
- if (left != NULL)
- delete left;
-
- if (right != NULL)
- delete right;
-}
-
-// This function computes the log-l2-negative-error of a given node from the
-// formula R(t) = log(|t|^2 / (N^2 V_t)).
-double DTree::LogNegativeError(const size_t totalPoints) const
-{
- // log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
- return 2 * std::log((double) (end - start)) -
- 2 * std::log((double) totalPoints) -
- arma::accu(arma::log(maxVals - minVals));
-}
-
-// This function finds the best split with respect to the L2-error, by trying
-// all possible splits. The dataset is the full data set but the start and
-// end are used to obtain the point in this node.
-bool DTree::FindSplit(const arma::mat& data,
- size_t& splitDim,
- double& splitValue,
- double& leftError,
- double& rightError,
- const size_t maxLeafSize,
- const size_t minLeafSize) const
-{
- // Ensure the dimensionality of the data is the same as the dimensionality of
- // the bounding rectangle.
- assert(data.n_rows == maxVals.n_elem);
- assert(data.n_rows == minVals.n_elem);
-
- const size_t points = end - start;
-
- double minError = logNegError;
- bool splitFound = false;
-
- // Loop through each dimension.
- for (size_t dim = 0; dim < maxVals.n_elem; dim++)
- {
- // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
- // think of how to do that...
- const double min = minVals[dim];
- const double max = maxVals[dim];
-
- // If there is nothing to split in this dimension, move on.
- if (max - min == 0.0)
- continue; // Skip to next dimension.
-
- // Initializing all the stuff for this dimension.
- bool dimSplitFound = false;
- // Take an error estimate for this dimension.
- double minDimError = std::pow(points, 2.0) / (max - min);
- double dimLeftError;
- double dimRightError;
- double dimSplitValue;
-
- // Find the log volume of all the other dimensions.
- double volumeWithoutDim = logVolume - std::log(max - min);
-
- // Get the values for the dimension.
- arma::rowvec dimVec = data.row(dim).subvec(start, end - 1);
-
- // Sort the values in ascending order.
- dimVec = arma::sort(dimVec);
-
- // Get ready to go through the sorted list and compute error.
- assert(dimVec.n_elem > maxLeafSize);
-
- // Find the best split for this dimension. We need to figure out why
- // there are spikes if this minLeafSize is enforced here...
- for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
- {
- // This makes sense for real continuous data. This kinda corrupts the
- // data and estimation if the data is ordinal.
- const double split = (dimVec[i] + dimVec[i + 1]) / 2.0;
-
- if (split == dimVec[i])
- continue; // We can't split here (two points are the same).
-
- // Another way of picking split is using this:
- // split = leftsplit;
- if ((split - min > 0.0) && (max - split > 0.0))
- {
- // Ensure that the right node will have at least the minimum number of
- // points.
- Log::Assert((points - i - 1) >= minLeafSize);
-
- // Now we have to see if the error will be reduced. Simple manipulation
- // of the error function gives us the condition we must satisfy:
- // |t_l|^2 / V_l + |t_r|^2 / V_r >= |t|^2 / (V_l + V_r)
- // and because the volume is only dependent on the dimension we are
- // splitting, we can assume V_l is just the range of the left and V_r is
- // just the range of the right.
- double negLeftError = std::pow(i + 1, 2.0) / (split - min);
- double negRightError = std::pow(points - i - 1, 2.0) / (max - split);
-
- // If this is better, take it.
- if ((negLeftError + negRightError) >= minDimError)
- {
- minDimError = negLeftError + negRightError;
- dimLeftError = negLeftError;
- dimRightError = negRightError;
- dimSplitValue = split;
- dimSplitFound = true;
- }
- }
- }
-
- double actualMinDimError = std::log(minDimError)
- - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
-
- if ((actualMinDimError > minError) && dimSplitFound)
- {
- // Calculate actual error (in logspace) by adding terms back to our
- // estimate.
- minError = actualMinDimError;
- splitDim = dim;
- splitValue = dimSplitValue;
- leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols)
- - volumeWithoutDim;
- rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols)
- - volumeWithoutDim;
- splitFound = true;
- } // end if better split found in this dimension.
- }
-
- return splitFound;
-}
-
-size_t DTree::SplitData(arma::mat& data,
- const size_t splitDim,
- const double splitValue,
- arma::Col<size_t>& oldFromNew) const
-{
- // Swap all columns such that any columns with value in dimension splitDim
- // less than or equal to splitValue are on the left side, and all others are
- // on the right side. A similar sort to this is also performed in
- // BinarySpaceTree construction (its comments are more detailed).
- size_t left = start;
- size_t right = end - 1;
- for (;;)
- {
- while (data(splitDim, left) <= splitValue)
- ++left;
- while (data(splitDim, right) > splitValue)
- --right;
-
- if (left > right)
- break;
-
- data.swap_cols(left, right);
-
- // Store the mapping from old to new.
- const size_t tmp = oldFromNew[left];
- oldFromNew[left] = oldFromNew[right];
- oldFromNew[right] = tmp;
- }
-
- // This now refers to the first index of the "right" side.
- return left;
-}
-
-// Greedily expand the tree
-double DTree::Grow(arma::mat& data,
- arma::Col<size_t>& oldFromNew,
- const bool useVolReg,
- const size_t maxLeafSize,
- const size_t minLeafSize)
-{
- assert(data.n_rows == maxVals.n_elem);
- assert(data.n_rows == minVals.n_elem);
-
- double leftG, rightG;
-
- // Compute points ratio.
- ratio = (double) (end - start) / (double) oldFromNew.n_elem;
-
- // Compute the log of the volume of the node.
- logVolume = 0;
- for (size_t i = 0; i < maxVals.n_elem; ++i)
- if (maxVals[i] - minVals[i] > 0.0)
- logVolume += std::log(maxVals[i] - minVals[i]);
-
- // Check if node is large enough to split.
- if ((size_t) (end - start) > maxLeafSize) {
-
- // Find the split.
- size_t dim;
- double splitValueTmp;
- double leftError, rightError;
- if (FindSplit(data, dim, splitValueTmp, leftError, rightError, maxLeafSize,
- minLeafSize))
- {
- // Move the data around for the children to have points in a node lie
- // contiguously (to increase efficiency during the training).
- const size_t splitIndex = SplitData(data, dim, splitValueTmp, oldFromNew);
-
- // Make max and min vals for the children.
- arma::vec maxValsL(maxVals);
- arma::vec maxValsR(maxVals);
- arma::vec minValsL(minVals);
- arma::vec minValsR(minVals);
-
- maxValsL[dim] = splitValueTmp;
- minValsR[dim] = splitValueTmp;
-
- // Store split dim and split val in the node.
- splitValue = splitValueTmp;
- splitDim = dim;
-
- // Recursively grow the children.
- left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
- right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);
-
- leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
- minLeafSize);
- rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
- minLeafSize);
-
- // Store values of R(T~) and |T~|.
- subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
-
- // Find the log negative error of the subtree leaves. This is kind of an
- // odd one because we don't want to represent the error in non-log-space,
- // but we have to calculate log(E_l + E_r). So we multiply E_l and E_r by
- // V_t (remember E_l has an inverse relationship to the volume of the
- // nodes) and then subtract log(V_t) at the end of the whole expression.
- // As a result we do leave log-space, but the largest quantity we
- // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
- // node below this node, which depends heavily on the depth of the tree.
- subtreeLeavesLogNegError = std::log(
- std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
- std::exp(logVolume + right->SubtreeLeavesLogNegError()))
- - logVolume;
- }
- else
- {
- // No split found so make a leaf out of it.
- subtreeLeaves = 1;
- subtreeLeavesLogNegError = logNegError;
- }
- }
- else
- {
- // We can make this a leaf node.
- assert((size_t) (end - start) >= minLeafSize);
- subtreeLeaves = 1;
- subtreeLeavesLogNegError = logNegError;
- }
-
- // If this is a leaf, do not compute g_k(t); otherwise compute, store, and
- // propagate min(g_k(t_L), g_k(t_R), g_k(t)), unless t_L and/or t_R are
- // leaves.
- if (subtreeLeaves == 1)
- {
- return std::numeric_limits<double>::max();
- }
- else
- {
- const double range = maxVals[splitDim] - minVals[splitDim];
- const double leftRatio = (splitValue - minVals[splitDim]) / range;
- const double rightRatio = (maxVals[splitDim] - splitValue) / range;
-
- const size_t leftPow = std::pow((double) (left->End() - left->Start()), 2);
- const size_t rightPow = std::pow((double) (right->End() - right->Start()),
- 2);
- const size_t thisPow = std::pow((double) (end - start), 2);
-
- double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio - thisPow;
-
- if (left->SubtreeLeaves() > 1)
- {
- const double exponent = 2 * std::log((double) data.n_cols) + logVolume +
- left->AlphaUpper();
-
- // Whether or not this will overflow is highly dependent on the depth of
- // the tree.
- tmpAlphaSum += std::exp(exponent);
- }
-
- if (right->SubtreeLeaves() > 1)
- {
- const double exponent = 2 * std::log((double) data.n_cols) + logVolume +
- right->AlphaUpper();
-
- tmpAlphaSum += std::exp(exponent);
- }
-
- alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols)
- - logVolume;
-
- double gT;
- if (useVolReg)
- {
- // This is wrong for now!
- gT = alphaUpper;// / (subtreeLeavesVTInv - vTInv);
- }
- else
- {
- gT = alphaUpper - std::log((double) (subtreeLeaves - 1));
- }
-
- return std::min(gT, std::min(leftG, rightG));
- }
-
- // We need to compute (c_t^2) * r_t for all subtree leaves; this is equal to
- // n_t ^ 2 / r_t * n ^ 2 = -error. Therefore the value we need is actually
- // -1.0 * subtreeLeavesError.
-}
-
-
-double DTree::PruneAndUpdate(const double oldAlpha,
- const size_t points,
- const bool useVolReg)
-
-{
- // Compute gT.
- if (subtreeLeaves == 1) // If we are a leaf...
- {
- return std::numeric_limits<double>::max();
- }
- else
- {
- // Compute gT value for node t.
- volatile double gT;
- if (useVolReg)
- gT = alphaUpper;// - std::log(subtreeLeavesVTInv - vTInv);
- else
- gT = alphaUpper - std::log((double) (subtreeLeaves - 1));
-
- if (gT > oldAlpha)
- {
- // Go down the tree and update accordingly. Traverse the children.
- double leftG = left->PruneAndUpdate(oldAlpha, points, useVolReg);
- double rightG = right->PruneAndUpdate(oldAlpha, points, useVolReg);
-
- // Update values.
- subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
-
- // Find the log negative error of the subtree leaves. This is kind of an
- // odd one because we don't want to represent the error in non-log-space,
- // but we have to calculate log(E_l + E_r). So we multiply E_l and E_r by
- // V_t (remember E_l has an inverse relationship to the volume of the
- // nodes) and then subtract log(V_t) at the end of the whole expression.
- // As a result we do leave log-space, but the largest quantity we
- // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
- // node below this node, which depends heavily on the depth of the tree.
- subtreeLeavesLogNegError = std::log(
- std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
- std::exp(logVolume + right->SubtreeLeavesLogNegError()))
- - logVolume;
-
- // Recalculate upper alpha.
- const double range = maxVals[splitDim] - minVals[splitDim];
- const double leftRatio = (splitValue - minVals[splitDim]) / range;
- const double rightRatio = (maxVals[splitDim] - splitValue) / range;
-
- const size_t leftPow = std::pow((double) (left->End() - left->Start()),
- 2);
- const size_t rightPow = std::pow((double) (right->End() - right->Start()),
- 2);
- const size_t thisPow = std::pow((double) (end - start), 2);
-
- double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio -
- thisPow;
-
- if (left->SubtreeLeaves() > 1)
- {
- const double exponent = 2 * std::log((double) points) + logVolume +
- left->AlphaUpper();
-
- // Whether or not this will overflow is highly dependent on the depth of
- // the tree.
- tmpAlphaSum += std::exp(exponent);
- }
-
- if (right->SubtreeLeaves() > 1)
- {
- const double exponent = 2 * std::log((double) points) + logVolume +
- right->AlphaUpper();
-
- tmpAlphaSum += std::exp(exponent);
- }
-
- alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) points) -
- logVolume;
-
- // Update gT value.
- if (useVolReg)
- {
- // This is incorrect.
- gT = alphaUpper; // / (subtreeLeavesVTInv - vTInv);
- }
- else
- {
- gT = alphaUpper - std::log((double) (subtreeLeaves - 1));
- }
-
- Log::Assert(gT < std::numeric_limits<double>::max());
-
- return std::min((double) gT, std::min(leftG, rightG));
- }
- else
- {
- // Prune this subtree.
- // First, make this node a leaf node.
- subtreeLeaves = 1;
- subtreeLeavesLogNegError = logNegError;
-
- delete left;
- delete right;
-
- left = NULL;
- right = NULL;
-
- // Pass information upward.
- return std::numeric_limits<double>::max();
- }
- }
-}
-
-// Check whether a given point is within the bounding box of this node (check
-// generally done at the root, so its the bounding box of the data).
-//
-// Future improvement: Open up the range with epsilons on both sides where
-// epsilon depends on the density near the boundary.
-bool DTree::WithinRange(const arma::vec& query) const
-{
- for (size_t i = 0; i < query.n_elem; ++i)
- if ((query[i] < minVals[i]) || (query[i] > maxVals[i]))
- return false;
-
- return true;
-}
-
-
-double DTree::ComputeValue(const arma::vec& query) const
-{
- Log::Assert(query.n_elem == maxVals.n_elem);
-
- if (root == 1) // If we are the root...
- {
- // Check if the query is within range.
- if (!WithinRange(query))
- return 0.0;
- }
-
- if (subtreeLeaves == 1) // If we are a leaf...
- {
- return std::exp(std::log(ratio) - logVolume);
- }
- else
- {
- if (query[splitDim] <= splitValue)
- {
- // If left subtree, go to left child.
- return left->ComputeValue(query);
- }
- else // If right subtree, go to right child
- {
- return right->ComputeValue(query);
- }
- }
-
- return 0.0;
-}
-
-
-void DTree::WriteTree(FILE *fp, const size_t level) const
-{
- if (subtreeLeaves > 1)
- {
- fprintf(fp, "\n");
- for (size_t i = 0; i < level; ++i)
- fprintf(fp, "|\t");
- fprintf(fp, "Var. %zu > %lg", splitDim, splitValue);
-
- right->WriteTree(fp, level + 1);
-
- fprintf(fp, "\n");
- for (size_t i = 0; i < level; ++i)
- fprintf(fp, "|\t");
- fprintf(fp, "Var. %zu <= %lg ", splitDim, splitValue);
-
- left->WriteTree(fp, level);
- }
- else // If we are a leaf...
- {
- fprintf(fp, ": f(x)=%lg", std::exp(std::log(ratio) - logVolume));
- if (bucketTag != -1)
- fprintf(fp, " BT:%d", bucketTag);
- }
-}
-
-
-// Index the buckets for possible usage later.
-int DTree::TagTree(const int tag)
-{
- if (subtreeLeaves == 1)
- {
- // Only label leaves.
- bucketTag = tag;
- return (tag + 1);
- }
- else
- {
- return right->TagTree(left->TagTree(tag));
- }
-}
-
-
-int DTree::FindBucket(const arma::vec& query) const
-{
- Log::Assert(query.n_elem == maxVals.n_elem);
-
- if (subtreeLeaves == 1) // If we are a leaf...
- {
- return bucketTag;
- }
- else if (query[splitDim] <= splitValue)
- {
- // If left subtree, go to left child.
- return left->FindBucket(query);
- }
- else
- {
- // If right subtree, go to right child.
- return right->FindBucket(query);
- }
-}
-
-
-void DTree::ComputeVariableImportance(arma::vec& importances) const
-{
- // Clear and set to right size.
- importances.zeros(maxVals.n_elem);
-
- std::stack<const DTree*> nodes;
- nodes.push(this);
-
- while(!nodes.empty())
- {
- const DTree& curNode = *nodes.top();
- nodes.pop();
-
- if (curNode.subtreeLeaves == 1)
- continue; // Do nothing for leaves.
-
- // The way to do this entirely in log-space is (at this time) somewhat
- // unclear. So this risks overflow.
- importances[curNode.SplitDim()] += (-std::exp(curNode.LogNegError()) -
- (-std::exp(curNode.Left()->LogNegError()) +
- -std::exp(curNode.Right()->LogNegError())));
-
- nodes.push(curNode.Left());
- nodes.push(curNode.Right());
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/det/dtree.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,682 @@
+ /**
+ * @file dtree.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * Implementations of some declared functions in
+ * the Density Estimation Tree class.
+ *
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "dtree.hpp"
+#include <stack>
+
+using namespace mlpack;
+using namespace det;
+
+DTree::DTree() :
+ start(0),
+ end(0),
+ logNegError(-DBL_MAX),
+ root(true),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
+{ /* Nothing to do. */ }
+
+
+// Root node initializers
+DTree::DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t totalPoints) :
+ start(0),
+ end(totalPoints),
+ maxVals(maxVals),
+ minVals(minVals),
+ logNegError(LogNegativeError(totalPoints)),
+ root(true),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
+{ /* Nothing to do. */ }
+
+DTree::DTree(arma::mat& data) :
+ start(0),
+ end(data.n_cols),
+ left(NULL),
+ right(NULL)
+{
+ maxVals.set_size(data.n_rows);
+ minVals.set_size(data.n_rows);
+
+ // Initialize to first column; values will be overwritten if necessary.
+ maxVals = data.col(0);
+ minVals = data.col(0);
+
+ // Loop over data to extract maximum and minimum values in each dimension.
+ for (size_t i = 1; i < data.n_cols; ++i)
+ {
+ for (size_t j = 0; j < data.n_rows; ++j)
+ {
+ if (data(j, i) > maxVals[j])
+ maxVals[j] = data(j, i);
+ if (data(j, i) < minVals[j])
+ minVals[j] = data(j, i);
+ }
+ }
+
+ logNegError = LogNegativeError(data.n_cols);
+
+ bucketTag = -1;
+ root = true;
+}
+
+
+// Non-root node initializers
+DTree::DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t start,
+ const size_t end,
+ const double logNegError) :
+ start(start),
+ end(end),
+ maxVals(maxVals),
+ minVals(minVals),
+ logNegError(logNegError),
+ root(false),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
+{ /* Nothing to do. */ }
+
+DTree::DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t totalPoints,
+ const size_t start,
+ const size_t end) :
+ start(start),
+ end(end),
+ maxVals(maxVals),
+ minVals(minVals),
+ logNegError(LogNegativeError(totalPoints)),
+ root(false),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
+{ /* Nothing to do. */ }
+
+DTree::~DTree()
+{
+ if (left != NULL)
+ delete left;
+
+ if (right != NULL)
+ delete right;
+}
+
+// This function computes the log-l2-negative-error of a given node from the
+// formula R(t) = log(|t|^2 / (N^2 V_t)).
+double DTree::LogNegativeError(const size_t totalPoints) const
+{
+ // log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
+ return 2 * std::log((double) (end - start)) -
+ 2 * std::log((double) totalPoints) -
+ arma::accu(arma::log(maxVals - minVals));
+}
+
+// This function finds the best split with respect to the L2-error, by trying
+// all possible splits. The dataset is the full data set but the start and
+// end are used to obtain the point in this node.
+bool DTree::FindSplit(const arma::mat& data,
+ size_t& splitDim,
+ double& splitValue,
+ double& leftError,
+ double& rightError,
+ const size_t maxLeafSize,
+ const size_t minLeafSize) const
+{
+ // Ensure the dimensionality of the data is the same as the dimensionality of
+ // the bounding rectangle.
+ assert(data.n_rows == maxVals.n_elem);
+ assert(data.n_rows == minVals.n_elem);
+
+ const size_t points = end - start;
+
+ double minError = logNegError;
+ bool splitFound = false;
+
+ // Loop through each dimension.
+ for (size_t dim = 0; dim < maxVals.n_elem; dim++)
+ {
+ // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
+ // think of how to do that...
+ const double min = minVals[dim];
+ const double max = maxVals[dim];
+
+ // If there is nothing to split in this dimension, move on.
+ if (max - min == 0.0)
+ continue; // Skip to next dimension.
+
+ // Initializing all the stuff for this dimension.
+ bool dimSplitFound = false;
+ // Take an error estimate for this dimension.
+ double minDimError = std::pow(points, 2.0) / (max - min);
+ double dimLeftError;
+ double dimRightError;
+ double dimSplitValue;
+
+ // Find the log volume of all the other dimensions.
+ double volumeWithoutDim = logVolume - std::log(max - min);
+
+ // Get the values for the dimension.
+ arma::rowvec dimVec = data.row(dim).subvec(start, end - 1);
+
+ // Sort the values in ascending order.
+ dimVec = arma::sort(dimVec);
+
+ // Get ready to go through the sorted list and compute error.
+ assert(dimVec.n_elem > maxLeafSize);
+
+ // Find the best split for this dimension. We need to figure out why
+ // there are spikes if this minLeafSize is enforced here...
+ for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
+ {
+ // This makes sense for real continuous data. This kinda corrupts the
+ // data and estimation if the data is ordinal.
+ const double split = (dimVec[i] + dimVec[i + 1]) / 2.0;
+
+ if (split == dimVec[i])
+ continue; // We can't split here (two points are the same).
+
+ // Another way of picking split is using this:
+ // split = leftsplit;
+ if ((split - min > 0.0) && (max - split > 0.0))
+ {
+ // Ensure that the right node will have at least the minimum number of
+ // points.
+ Log::Assert((points - i - 1) >= minLeafSize);
+
+ // Now we have to see if the error will be reduced. Simple manipulation
+ // of the error function gives us the condition we must satisfy:
+ // |t_l|^2 / V_l + |t_r|^2 / V_r >= |t|^2 / (V_l + V_r)
+ // and because the volume is only dependent on the dimension we are
+ // splitting, we can assume V_l is just the range of the left and V_r is
+ // just the range of the right.
+ double negLeftError = std::pow(i + 1, 2.0) / (split - min);
+ double negRightError = std::pow(points - i - 1, 2.0) / (max - split);
+
+ // If this is better, take it.
+ if ((negLeftError + negRightError) >= minDimError)
+ {
+ minDimError = negLeftError + negRightError;
+ dimLeftError = negLeftError;
+ dimRightError = negRightError;
+ dimSplitValue = split;
+ dimSplitFound = true;
+ }
+ }
+ }
+
+ double actualMinDimError = std::log(minDimError)
+ - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+
+ if ((actualMinDimError > minError) && dimSplitFound)
+ {
+ // Calculate actual error (in logspace) by adding terms back to our
+ // estimate.
+ minError = actualMinDimError;
+ splitDim = dim;
+ splitValue = dimSplitValue;
+ leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols)
+ - volumeWithoutDim;
+ rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols)
+ - volumeWithoutDim;
+ splitFound = true;
+ } // end if better split found in this dimension.
+ }
+
+ return splitFound;
+}
+
+size_t DTree::SplitData(arma::mat& data,
+ const size_t splitDim,
+ const double splitValue,
+ arma::Col<size_t>& oldFromNew) const
+{
+ // Rearrange the columns of 'data' in [start, end) so that all points with
+ // data(splitDim, i) <= splitValue come first, recording every swap in
+ // oldFromNew so original point indices can be recovered. Returns the index
+ // of the first point belonging to the right child.
+ //
+ // NOTE(review): the two inner while loops carry no bounds checks; this
+ // relies on at least one point falling strictly on each side of splitValue
+ // (FindSplit() only reports splits strictly inside (min, max)) -- confirm.
+ //
+ // Swap all columns such that any columns with value in dimension splitDim
+ // less than or equal to splitValue are on the left side, and all others are
+ // on the right side. A similar sort to this is also performed in
+ // BinarySpaceTree construction (its comments are more detailed).
+ size_t left = start;
+ size_t right = end - 1;
+ for (;;)
+ {
+ // Advance past points that already belong to the left partition.
+ while (data(splitDim, left) <= splitValue)
+ ++left;
+ // Retreat past points that already belong to the right partition.
+ while (data(splitDim, right) > splitValue)
+ --right;
+
+ // Indices crossed: the range is fully partitioned.
+ if (left > right)
+ break;
+
+ data.swap_cols(left, right);
+
+ // Store the mapping from old to new.
+ const size_t tmp = oldFromNew[left];
+ oldFromNew[left] = oldFromNew[right];
+ oldFromNew[right] = tmp;
+ }
+
+ // This now refers to the first index of the "right" side.
+ return left;
+}
+
+// Greedily expand the tree
+//
+// Recursively builds the subtree rooted at this node, reordering 'data' (with
+// the permutation recorded in oldFromNew) so each node's points lie
+// contiguously. Returns the minimum pruning statistic g(t) over this subtree,
+// or DBL_MAX if this node becomes a leaf; PruneAndUpdate() consumes this.
+double DTree::Grow(arma::mat& data,
+ arma::Col<size_t>& oldFromNew,
+ const bool useVolReg,
+ const size_t maxLeafSize,
+ const size_t minLeafSize)
+{
+ assert(data.n_rows == maxVals.n_elem);
+ assert(data.n_rows == minVals.n_elem);
+
+ // NOTE(review): leftG/rightG are assigned only when a split is found; they
+ // are read only in the (subtreeLeaves > 1) branch below, which is reachable
+ // only after a successful split -- confirm no path reads them unset.
+ double leftG, rightG;
+
+ // Compute points ratio.
+ ratio = (double) (end - start) / (double) oldFromNew.n_elem;
+
+ // Compute the log of the volume of the node.
+ logVolume = 0;
+ for (size_t i = 0; i < maxVals.n_elem; ++i)
+ if (maxVals[i] - minVals[i] > 0.0)
+ logVolume += std::log(maxVals[i] - minVals[i]);
+
+ // Check if node is large enough to split.
+ if ((size_t) (end - start) > maxLeafSize) {
+
+ // Find the split.
+ size_t dim;
+ double splitValueTmp;
+ double leftError, rightError;
+ if (FindSplit(data, dim, splitValueTmp, leftError, rightError, maxLeafSize,
+ minLeafSize))
+ {
+ // Move the data around for the children to have points in a node lie
+ // contiguously (to increase efficiency during the training).
+ const size_t splitIndex = SplitData(data, dim, splitValueTmp, oldFromNew);
+
+ // Make max and min vals for the children.
+ arma::vec maxValsL(maxVals);
+ arma::vec maxValsR(maxVals);
+ arma::vec minValsL(minVals);
+ arma::vec minValsR(minVals);
+
+ // The children share the parent's box except along the split dimension.
+ maxValsL[dim] = splitValueTmp;
+ minValsR[dim] = splitValueTmp;
+
+ // Store split dim and split val in the node.
+ splitValue = splitValueTmp;
+ splitDim = dim;
+
+ // Recursively grow the children.
+ left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
+ right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);
+
+ leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+ minLeafSize);
+ rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+ minLeafSize);
+
+ // Store values of R(T~) and |T~|.
+ subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
+
+ // Find the log negative error of the subtree leaves. This is kind of an
+ // odd one because we don't want to represent the error in non-log-space,
+ // but we have to calculate log(E_l + E_r). So we multiply E_l and E_r by
+ // V_t (remember E_l has an inverse relationship to the volume of the
+ // nodes) and then subtract log(V_t) at the end of the whole expression.
+ // As a result we do leave log-space, but the largest quantity we
+ // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
+ // node below this node, which depends heavily on the depth of the tree.
+ subtreeLeavesLogNegError = std::log(
+ std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
+ std::exp(logVolume + right->SubtreeLeavesLogNegError()))
+ - logVolume;
+ }
+ else
+ {
+ // No split found so make a leaf out of it.
+ subtreeLeaves = 1;
+ subtreeLeavesLogNegError = logNegError;
+ }
+ }
+ else
+ {
+ // We can make this a leaf node.
+ assert((size_t) (end - start) >= minLeafSize);
+ subtreeLeaves = 1;
+ subtreeLeavesLogNegError = logNegError;
+ }
+
+ // If this is a leaf, do not compute g_k(t); otherwise compute, store, and
+ // propagate min(g_k(t_L), g_k(t_R), g_k(t)), unless t_L and/or t_R are
+ // leaves.
+ if (subtreeLeaves == 1)
+ {
+ // Leaves never get pruned directly, so report an infinite g(t).
+ return std::numeric_limits<double>::max();
+ }
+ else
+ {
+ // Fractions of this node's extent (along the split dimension) that each
+ // child occupies; used to express the children's volumes relative to ours.
+ const double range = maxVals[splitDim] - minVals[splitDim];
+ const double leftRatio = (splitValue - minVals[splitDim]) / range;
+ const double rightRatio = (maxVals[splitDim] - splitValue) / range;
+
+ // NOTE(review): std::pow() returns double but these are stored as size_t,
+ // truncating the value and risking overflow for very large nodes --
+ // confirm this narrowing is intended.
+ const size_t leftPow = std::pow((double) (left->End() - left->Start()), 2);
+ const size_t rightPow = std::pow((double) (right->End() - right->Start()),
+ 2);
+ const size_t thisPow = std::pow((double) (end - start), 2);
+
+ double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio - thisPow;
+
+ if (left->SubtreeLeaves() > 1)
+ {
+ const double exponent = 2 * std::log((double) data.n_cols) + logVolume +
+ left->AlphaUpper();
+
+ // Whether or not this will overflow is highly dependent on the depth of
+ // the tree.
+ tmpAlphaSum += std::exp(exponent);
+ }
+
+ if (right->SubtreeLeaves() > 1)
+ {
+ const double exponent = 2 * std::log((double) data.n_cols) + logVolume +
+ right->AlphaUpper();
+
+ tmpAlphaSum += std::exp(exponent);
+ }
+
+ // Store the upper part of the alpha sum back in log-space.
+ alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols)
+ - logVolume;
+
+ double gT;
+ if (useVolReg)
+ {
+ // This is wrong for now!
+ gT = alphaUpper;// / (subtreeLeavesVTInv - vTInv);
+ }
+ else
+ {
+ // g(t) = alphaUpper / (|leaves(t)| - 1), computed in log-space.
+ gT = alphaUpper - std::log((double) (subtreeLeaves - 1));
+ }
+
+ // Propagate the smallest g over this node and both subtrees upward.
+ return std::min(gT, std::min(leftG, rightG));
+ }
+
+ // We need to compute (c_t^2) * r_t for all subtree leaves; this is equal to
+ // n_t ^ 2 / r_t * n ^ 2 = -error. Therefore the value we need is actually
+ // -1.0 * subtreeLeavesError.
+}
+
+
+double DTree::PruneAndUpdate(const double oldAlpha,
+ const size_t points,
+ const bool useVolReg)
+
+{
+ // Prune every subtree whose pruning statistic g(t) does not exceed
+ // oldAlpha, collapsing it into a leaf; otherwise recurse, refresh the
+ // subtree statistics, and return the new minimum g(t) of this subtree
+ // (DBL_MAX for leaves/pruned nodes). Mirrors the bookkeeping in Grow().
+ //
+ // Compute gT.
+ if (subtreeLeaves == 1) // If we are a leaf...
+ {
+ return std::numeric_limits<double>::max();
+ }
+ else
+ {
+ // Compute gT value for node t.
+ // NOTE(review): 'volatile' here presumably forces the stored (not
+ // extended-precision register) value to be used in the comparison with
+ // oldAlpha -- confirm that is the intent.
+ volatile double gT;
+ if (useVolReg)
+ gT = alphaUpper;// - std::log(subtreeLeavesVTInv - vTInv);
+ else
+ gT = alphaUpper - std::log((double) (subtreeLeaves - 1));
+
+ if (gT > oldAlpha)
+ {
+ // Go down the tree and update accordingly. Traverse the children.
+ double leftG = left->PruneAndUpdate(oldAlpha, points, useVolReg);
+ double rightG = right->PruneAndUpdate(oldAlpha, points, useVolReg);
+
+ // Update values.
+ subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
+
+ // Find the log negative error of the subtree leaves. This is kind of an
+ // odd one because we don't want to represent the error in non-log-space,
+ // but we have to calculate log(E_l + E_r). So we multiply E_l and E_r by
+ // V_t (remember E_l has an inverse relationship to the volume of the
+ // nodes) and then subtract log(V_t) at the end of the whole expression.
+ // As a result we do leave log-space, but the largest quantity we
+ // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
+ // node below this node, which depends heavily on the depth of the tree.
+ subtreeLeavesLogNegError = std::log(
+ std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
+ std::exp(logVolume + right->SubtreeLeavesLogNegError()))
+ - logVolume;
+
+ // Recalculate upper alpha.
+ const double range = maxVals[splitDim] - minVals[splitDim];
+ const double leftRatio = (splitValue - minVals[splitDim]) / range;
+ const double rightRatio = (maxVals[splitDim] - splitValue) / range;
+
+ // NOTE(review): std::pow() result narrowed into size_t (truncation /
+ // possible overflow), same pattern as in Grow() -- confirm intended.
+ const size_t leftPow = std::pow((double) (left->End() - left->Start()),
+ 2);
+ const size_t rightPow = std::pow((double) (right->End() - right->Start()),
+ 2);
+ const size_t thisPow = std::pow((double) (end - start), 2);
+
+ double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio -
+ thisPow;
+
+ if (left->SubtreeLeaves() > 1)
+ {
+ const double exponent = 2 * std::log((double) points) + logVolume +
+ left->AlphaUpper();
+
+ // Whether or not this will overflow is highly dependent on the depth of
+ // the tree.
+ tmpAlphaSum += std::exp(exponent);
+ }
+
+ if (right->SubtreeLeaves() > 1)
+ {
+ const double exponent = 2 * std::log((double) points) + logVolume +
+ right->AlphaUpper();
+
+ tmpAlphaSum += std::exp(exponent);
+ }
+
+ alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) points) -
+ logVolume;
+
+ // Update gT value.
+ if (useVolReg)
+ {
+ // This is incorrect.
+ gT = alphaUpper; // / (subtreeLeavesVTInv - vTInv);
+ }
+ else
+ {
+ gT = alphaUpper - std::log((double) (subtreeLeaves - 1));
+ }
+
+ Log::Assert(gT < std::numeric_limits<double>::max());
+
+ // Propagate the smallest g over this node and both (updated) subtrees.
+ return std::min((double) gT, std::min(leftG, rightG));
+ }
+ else
+ {
+ // Prune this subtree.
+ // First, make this node a leaf node.
+ subtreeLeaves = 1;
+ subtreeLeavesLogNegError = logNegError;
+
+ // Deleting a child recursively frees its whole subtree (see ~DTree()).
+ delete left;
+ delete right;
+
+ left = NULL;
+ right = NULL;
+
+ // Pass information upward.
+ return std::numeric_limits<double>::max();
+ }
+ }
+}
+
+// Check whether a given point is within the bounding box of this node (check
+// generally done at the root, so its the bounding box of the data).
+//
+// Future improvement: Open up the range with epsilons on both sides where
+// epsilon depends on the density near the boundary.
+bool DTree::WithinRange(const arma::vec& query) const
+{
+ // Inside iff query lies in [minVals[i], maxVals[i]] for every dimension
+ // (bounds inclusive on both ends).
+ for (size_t i = 0; i < query.n_elem; ++i)
+ if ((query[i] < minVals[i]) || (query[i] > maxVals[i]))
+ return false;
+
+ return true;
+}
+
+
+double DTree::ComputeValue(const arma::vec& query) const
+{
+ // Density estimate at 'query': descend to the leaf containing the point and
+ // return ratio / volume = exp(log(ratio) - logVolume) for that leaf.
+ // NOTE(review): the header documents this as "the logarithm of the density
+ // estimate", but the leaf case below returns exp(...) -- i.e. the density
+ // itself. Confirm which contract is intended.
+ Log::Assert(query.n_elem == maxVals.n_elem);
+
+ if (root == 1) // If we are the root...
+ {
+ // Check if the query is within range.
+ if (!WithinRange(query))
+ return 0.0;
+ }
+
+ if (subtreeLeaves == 1) // If we are a leaf...
+ {
+ return std::exp(std::log(ratio) - logVolume);
+ }
+ else
+ {
+ if (query[splitDim] <= splitValue)
+ {
+ // If left subtree, go to left child.
+ return left->ComputeValue(query);
+ }
+ else // If right subtree, go to right child
+ {
+ return right->ComputeValue(query);
+ }
+ }
+
+ // Unreachable: every path above has already returned.
+ return 0.0;
+}
+
+
+void DTree::WriteTree(FILE *fp, const size_t level) const
+{
+ // Pretty-print the tree to 'fp', one node per line, indenting "|\t" once
+ // per level; the right subtree is printed before the left.
+ if (subtreeLeaves > 1)
+ {
+ fprintf(fp, "\n");
+ for (size_t i = 0; i < level; ++i)
+ fprintf(fp, "|\t");
+ fprintf(fp, "Var. %zu > %lg", splitDim, splitValue);
+
+ right->WriteTree(fp, level + 1);
+
+ fprintf(fp, "\n");
+ for (size_t i = 0; i < level; ++i)
+ fprintf(fp, "|\t");
+ fprintf(fp, "Var. %zu <= %lg ", splitDim, splitValue);
+
+ // NOTE(review): the left child is recursed with 'level' while the right
+ // child above used 'level + 1' -- confirm this asymmetry is intended
+ // output formatting and not an off-by-one.
+ left->WriteTree(fp, level);
+ }
+ else // If we are a leaf...
+ {
+ // Print the leaf's constant density estimate, and its bucket tag if
+ // TagTree() has been run.
+ fprintf(fp, ": f(x)=%lg", std::exp(std::log(ratio) - logVolume));
+ if (bucketTag != -1)
+ fprintf(fp, " BT:%d", bucketTag);
+ }
+}
+
+
+// Index the buckets for possible usage later.
+//
+// Assigns consecutive integer tags to the leaves in depth-first (left before
+// right) order and returns the next unused tag; the root call's return value
+// is therefore the total number of leaves.
+int DTree::TagTree(const int tag)
+{
+ if (subtreeLeaves == 1)
+ {
+ // Only label leaves.
+ bucketTag = tag;
+ return (tag + 1);
+ }
+ else
+ {
+ // Tag the left subtree first; its returned value is the first tag free
+ // for the right subtree.
+ return right->TagTree(left->TagTree(tag));
+ }
+}
+
+
+int DTree::FindBucket(const arma::vec& query) const
+{
+ // Descend the tree to the leaf containing 'query' and return its bucket
+ // tag (as assigned by TagTree(); -1 if the tree was never tagged).
+ Log::Assert(query.n_elem == maxVals.n_elem);
+
+ if (subtreeLeaves == 1) // If we are a leaf...
+ {
+ return bucketTag;
+ }
+ else if (query[splitDim] <= splitValue)
+ {
+ // If left subtree, go to left child.
+ return left->FindBucket(query);
+ }
+ else
+ {
+ // If right subtree, go to right child.
+ return right->FindBucket(query);
+ }
+}
+
+
+void DTree::ComputeVariableImportance(arma::vec& importances) const
+{
+ // For each dimension, accumulate (over every internal node splitting on
+ // that dimension) the error reduction achieved by the split:
+ // -error(node) - (-error(left) + -error(right)), with errors taken out of
+ // log-space. Iterative depth-first traversal using an explicit stack.
+ //
+ // Clear and set to right size.
+ importances.zeros(maxVals.n_elem);
+
+ std::stack<const DTree*> nodes;
+ nodes.push(this);
+
+ while(!nodes.empty())
+ {
+ const DTree& curNode = *nodes.top();
+ nodes.pop();
+
+ if (curNode.subtreeLeaves == 1)
+ continue; // Do nothing for leaves.
+
+ // The way to do this entirely in log-space is (at this time) somewhat
+ // unclear. So this risks overflow.
+ importances[curNode.SplitDim()] += (-std::exp(curNode.LogNegError()) -
+ (-std::exp(curNode.Left()->LogNegError()) +
+ -std::exp(curNode.Right()->LogNegError())));
+
+ // Internal node: visit both children.
+ nodes.push(curNode.Left());
+ nodes.push(curNode.Right());
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/det/dtree.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,326 +0,0 @@
-/**
- * @file dtree.hpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * Density Estimation Tree class
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef __MLPACK_METHODS_DET_DTREE_HPP
-#define __MLPACK_METHODS_DET_DTREE_HPP
-
-#include <assert.h>
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace det /** Density Estimation Trees */ {
-
-/**
- * A density estimation tree is similar to both a decision tree and a space
- * partitioning tree (like a kd-tree). Each leaf represents a constant-density
- * hyper-rectangle. The tree is constructed in such a way as to minimize the
- * integrated square error between the probability distribution of the tree and
- * the observed probability distribution of the data. Because the tree is
- * similar to a decision tree, the density estimation tree can provide very fast
- * density estimates for a given point.
- *
- * For more information, see the following paper:
- *
- * @code
- * @incollection{ram2011,
- * author = {Ram, Parikshit and Gray, Alexander G.},
- * title = {Density estimation trees},
- * booktitle = {{Proceedings of the 17th ACM SIGKDD International Conference
- * on Knowledge Discovery and Data Mining}},
- * series = {KDD '11},
- * year = {2011},
- * pages = {627--635}
- * }
- * @endcode
- */
-class DTree
-{
- public:
- /**
- * Create an empty density estimation tree.
- */
- DTree();
-
- /**
- * Create a density estimation tree with the given bounds and the given number
- * of total points. Children will not be created.
- *
- * @param maxVals Maximum values of the bounding box.
- * @param minVals Minimum values of the bounding box.
- * @param totalPoints Total number of points in the dataset.
- */
- DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints);
-
- /**
- * Create a density estimation tree on the given data. Children will be
- * created following the procedure outlined in the paper. The data will be
- * modified; it will be reordered similar to the way BinarySpaceTree modifies
- * datasets.
- *
- * @param data Dataset to build tree on.
- */
- DTree(arma::mat& data);
-
- /**
- * Create a child node of a density estimation tree given the bounding box
- * specified by maxVals and minVals, using the size given in start and end and
- * the specified error. Children of this node will not be created
- * recursively.
- *
- * @param maxVals Upper bound of bounding box.
- * @param minVals Lower bound of bounding box.
- * @param start Start of points represented by this node in the data matrix.
- * @param end End of points represented by this node in the data matrix.
- * @param error log-negative error of this node.
- */
- DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t start,
- const size_t end,
- const double logNegError);
-
- /**
- * Create a child node of a density estimation tree given the bounding box
- * specified by maxVals and minVals, using the size given in start and end,
- * and calculating the error with the total number of points given. Children
- * of this node will not be created recursively.
- *
- * @param maxVals Upper bound of bounding box.
- * @param minVals Lower bound of bounding box.
- * @param start Start of points represented by this node in the data matrix.
- * @param end End of points represented by this node in the data matrix.
- */
- DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints,
- const size_t start,
- const size_t end);
-
- //! Clean up memory allocated by the tree.
- ~DTree();
-
- /**
- * Greedily expand the tree. The points in the dataset will be reordered
- * during tree growth.
- *
- * @param data Dataset to build tree on.
- * @param oldFromNew Mappings from old points to new points.
- * @param useVolReg If true, volume regularization is used.
- * @param maxLeafSize Maximum size of a leaf.
- * @param minLeafSize Minimum size of a leaf.
- */
- double Grow(arma::mat& data,
- arma::Col<size_t>& oldFromNew,
- const bool useVolReg = false,
- const size_t maxLeafSize = 10,
- const size_t minLeafSize = 5);
-
- /**
- * Perform alpha pruning on a tree. Returns the new value of alpha.
- *
- * @param oldAlpha Old value of alpha.
- * @param points Total number of points in dataset.
- * @param useVolReg If true, volume regularization is used.
- * @return New value of alpha.
- */
- double PruneAndUpdate(const double oldAlpha,
- const size_t points,
- const bool useVolReg = false);
-
- /**
- * Compute the logarithm of the density estimate of a given query point.
- *
- * @param query Point to estimate density of.
- */
- double ComputeValue(const arma::vec& query) const;
-
- /**
- * Print the tree in a depth-first manner (this function is called
- * recursively).
- *
- * @param fp File to write the tree to.
- * @param level Level of the tree (should start at 0).
- */
- void WriteTree(FILE *fp, const size_t level = 0) const;
-
- /**
- * Index the buckets for possible usage later; this results in every leaf in
- * the tree having a specific tag (accessible with BucketTag()). This
- * function calls itself recursively.
- *
- * @param tag Tag for the next leaf; leave at 0 for the initial call.
- */
- int TagTree(const int tag = 0);
-
- /**
- * Return the tag of the leaf containing the query. This is useful for
- * generating class memberships.
- *
- * @param query Query to search for.
- */
- int FindBucket(const arma::vec& query) const;
-
- /**
- * Compute the variable importance of each dimension in the learned tree.
- *
- * @param importances Vector to store the calculated importances in.
- */
- void ComputeVariableImportance(arma::vec& importances) const;
-
- /**
- * Compute the log-negative-error for this point, given the total number of
- * points in the dataset.
- *
- * @param totalPoints Total number of points in the dataset.
- */
- double LogNegativeError(const size_t totalPoints) const;
-
- /**
- * Return whether a query point is within the range of this node.
- */
- bool WithinRange(const arma::vec& query) const;
-
- private:
- // The indices in the complete set of points
- // (after all forms of swapping in the original data
- // matrix to align all the points in a node
- // consecutively in the matrix. The 'old_from_new' array
- // maps the points back to their original indices.
-
- //! The index of the first point in the dataset contained in this node (and
- //! its children).
- size_t start;
- //! The index of the last point in the dataset contained in this node (and its
- //! children).
- size_t end;
-
- //! Upper half of bounding box for this node.
- arma::vec maxVals;
- //! Lower half of bounding box for this node.
- arma::vec minVals;
-
- //! The splitting dimension for this node.
- size_t splitDim;
-
- //! The split value on the splitting dimension for this node.
- double splitValue;
-
- //! log-negative-L2-error of the node.
- double logNegError;
-
- //! Sum of the error of the leaves of the subtree.
- double subtreeLeavesLogNegError;
-
- //! Number of leaves of the subtree.
- size_t subtreeLeaves;
-
- //! If true, this node is the root of the tree.
- bool root;
-
- //! Ratio of the number of points in the node to the total number of points.
- double ratio;
-
- //! The logarithm of the volume of the node.
- double logVolume;
-
- //! The tag for the leaf, used for hashing points.
- int bucketTag;
-
- //! Upper part of alpha sum; used for pruning.
- double alphaUpper;
-
- //! The left child.
- DTree* left;
- //! The right child.
- DTree* right;
-
- public:
- //! Return the starting index of points contained in this node.
- size_t Start() const { return start; }
- //! Return the first index of a point not contained in this node.
- size_t End() const { return end; }
- //! Return the split dimension of this node.
- size_t SplitDim() const { return splitDim; }
- //! Return the split value of this node.
- double SplitValue() const { return splitValue; }
- //! Return the log negative error of this node.
- double LogNegError() const { return logNegError; }
- //! Return the log negative error of all descendants of this node.
- double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
- //! Return the number of leaves which are descendants of this node.
- size_t SubtreeLeaves() const { return subtreeLeaves; }
- //! Return the ratio of points in this node to the points in the whole
- //! dataset.
- double Ratio() const { return ratio; }
- //! Return the inverse of the volume of this node.
- double LogVolume() const { return logVolume; }
- //! Return the left child.
- DTree* Left() const { return left; }
- //! Return the right child.
- DTree* Right() const { return right; }
- //! Return whether or not this is the root of the tree.
- bool Root() const { return root; }
- //! Return the upper part of the alpha sum.
- double AlphaUpper() const { return alphaUpper; }
-
- //! Return the maximum values.
- const arma::vec& MaxVals() const { return maxVals; }
- //! Modify the maximum values.
- arma::vec& MaxVals() { return maxVals; }
-
- //! Return the minimum values.
- const arma::vec& MinVals() const { return minVals; }
- //! Modify the minimum values.
- arma::vec& MinVals() { return minVals; }
-
- private:
-
- // Utility methods.
-
- /**
- * Find the dimension to split on.
- */
- bool FindSplit(const arma::mat& data,
- size_t& splitDim,
- double& splitValue,
- double& leftError,
- double& rightError,
- const size_t maxLeafSize = 10,
- const size_t minLeafSize = 5) const;
-
- /**
- * Split the data, returning the number of points left of the split.
- */
- size_t SplitData(arma::mat& data,
- const size_t splitDim,
- const double splitValue,
- arma::Col<size_t>& oldFromNew) const;
-
-};
-
-}; // namespace det
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_DET_DTREE_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/det/dtree.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/det/dtree.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,326 @@
+/**
+ * @file dtree.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * Density Estimation Tree class
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef __MLPACK_METHODS_DET_DTREE_HPP
+#define __MLPACK_METHODS_DET_DTREE_HPP
+
+#include <assert.h>
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace det /** Density Estimation Trees */ {
+
+/**
+ * A density estimation tree is similar to both a decision tree and a space
+ * partitioning tree (like a kd-tree). Each leaf represents a constant-density
+ * hyper-rectangle. The tree is constructed in such a way as to minimize the
+ * integrated square error between the probability distribution of the tree and
+ * the observed probability distribution of the data. Because the tree is
+ * similar to a decision tree, the density estimation tree can provide very fast
+ * density estimates for a given point.
+ *
+ * For more information, see the following paper:
+ *
+ * @code
+ * @incollection{ram2011,
+ * author = {Ram, Parikshit and Gray, Alexander G.},
+ * title = {Density estimation trees},
+ * booktitle = {{Proceedings of the 17th ACM SIGKDD International Conference
+ * on Knowledge Discovery and Data Mining}},
+ * series = {KDD '11},
+ * year = {2011},
+ * pages = {627--635}
+ * }
+ * @endcode
+ */
+class DTree
+{
+ public:
+ /**
+ * Create an empty density estimation tree.
+ */
+ DTree();
+
+ /**
+ * Create a density estimation tree with the given bounds and the given number
+ * of total points. Children will not be created.
+ *
+ * @param maxVals Maximum values of the bounding box.
+ * @param minVals Minimum values of the bounding box.
+ * @param totalPoints Total number of points in the dataset.
+ */
+ DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t totalPoints);
+
+ /**
+ * Create a density estimation tree on the given data. Children will be
+ * created following the procedure outlined in the paper. The data will be
+ * modified; it will be reordered similar to the way BinarySpaceTree modifies
+ * datasets.
+ *
+ * @param data Dataset to build tree on.
+ */
+ DTree(arma::mat& data);
+
+ /**
+ * Create a child node of a density estimation tree given the bounding box
+ * specified by maxVals and minVals, using the size given in start and end and
+ * the specified error. Children of this node will not be created
+ * recursively.
+ *
+ * @param maxVals Upper bound of bounding box.
+ * @param minVals Lower bound of bounding box.
+ * @param start Start of points represented by this node in the data matrix.
+ * @param end End of points represented by this node in the data matrix.
+ * @param error log-negative error of this node.
+ */
+ DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t start,
+ const size_t end,
+ const double logNegError);
+
+ /**
+ * Create a child node of a density estimation tree given the bounding box
+ * specified by maxVals and minVals, using the size given in start and end,
+ * and calculating the error with the total number of points given. Children
+ * of this node will not be created recursively.
+ *
+ * @param maxVals Upper bound of bounding box.
+ * @param minVals Lower bound of bounding box.
+ * @param start Start of points represented by this node in the data matrix.
+ * @param end End of points represented by this node in the data matrix.
+ */
+ DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t totalPoints,
+ const size_t start,
+ const size_t end);
+
+ //! Clean up memory allocated by the tree.
+ ~DTree();
+
+ /**
+ * Greedily expand the tree. The points in the dataset will be reordered
+ * during tree growth.
+ *
+ * @param data Dataset to build tree on.
+ * @param oldFromNew Mappings from old points to new points.
+ * @param useVolReg If true, volume regularization is used.
+ * @param maxLeafSize Maximum size of a leaf.
+ * @param minLeafSize Minimum size of a leaf.
+ */
+ double Grow(arma::mat& data,
+ arma::Col<size_t>& oldFromNew,
+ const bool useVolReg = false,
+ const size_t maxLeafSize = 10,
+ const size_t minLeafSize = 5);
+
+ /**
+ * Perform alpha pruning on a tree. Returns the new value of alpha.
+ *
+ * @param oldAlpha Old value of alpha.
+ * @param points Total number of points in dataset.
+ * @param useVolReg If true, volume regularization is used.
+ * @return New value of alpha.
+ */
+ double PruneAndUpdate(const double oldAlpha,
+ const size_t points,
+ const bool useVolReg = false);
+
+ /**
+ * Compute the logarithm of the density estimate of a given query point.
+ *
+ * @param query Point to estimate density of.
+ */
+ double ComputeValue(const arma::vec& query) const;
+
+ /**
+ * Print the tree in a depth-first manner (this function is called
+ * recursively).
+ *
+ * @param fp File to write the tree to.
+ * @param level Level of the tree (should start at 0).
+ */
+ void WriteTree(FILE *fp, const size_t level = 0) const;
+
+ /**
+ * Index the buckets for possible usage later; this results in every leaf in
+ * the tree having a specific tag (accessible with BucketTag()). This
+ * function calls itself recursively.
+ *
+ * @param tag Tag for the next leaf; leave at 0 for the initial call.
+ */
+ int TagTree(const int tag = 0);
+
+ /**
+ * Return the tag of the leaf containing the query. This is useful for
+ * generating class memberships.
+ *
+ * @param query Query to search for.
+ */
+ int FindBucket(const arma::vec& query) const;
+
+ /**
+ * Compute the variable importance of each dimension in the learned tree.
+ *
+ * @param importances Vector to store the calculated importances in.
+ */
+ void ComputeVariableImportance(arma::vec& importances) const;
+
+ /**
+ * Compute the log-negative-error for this point, given the total number of
+ * points in the dataset.
+ *
+ * @param totalPoints Total number of points in the dataset.
+ */
+ double LogNegativeError(const size_t totalPoints) const;
+
+ /**
+ * Return whether a query point is within the range of this node.
+ */
+ bool WithinRange(const arma::vec& query) const;
+
+ private:
+ // The indices in the complete set of points
+ // (after all forms of swapping in the original data
+ // matrix to align all the points in a node
+ // consecutively in the matrix. The 'old_from_new' array
+ // maps the points back to their original indices.
+
+ //! The index of the first point in the dataset contained in this node (and
+ //! its children).
+ size_t start;
+ //! The index of the last point in the dataset contained in this node (and its
+ //! children).
+ size_t end;
+
+ //! Upper half of bounding box for this node.
+ arma::vec maxVals;
+ //! Lower half of bounding box for this node.
+ arma::vec minVals;
+
+ //! The splitting dimension for this node.
+ size_t splitDim;
+
+ //! The split value on the splitting dimension for this node.
+ double splitValue;
+
+ //! log-negative-L2-error of the node.
+ double logNegError;
+
+ //! Sum of the error of the leaves of the subtree.
+ double subtreeLeavesLogNegError;
+
+ //! Number of leaves of the subtree.
+ size_t subtreeLeaves;
+
+ //! If true, this node is the root of the tree.
+ bool root;
+
+ //! Ratio of the number of points in the node to the total number of points.
+ double ratio;
+
+ //! The logarithm of the volume of the node.
+ double logVolume;
+
+ //! The tag for the leaf, used for hashing points.
+ int bucketTag;
+
+ //! Upper part of alpha sum; used for pruning.
+ double alphaUpper;
+
+ //! The left child.
+ DTree* left;
+ //! The right child.
+ DTree* right;
+
+ public:
+ //! Return the starting index of points contained in this node.
+ size_t Start() const { return start; }
+ //! Return the first index of a point not contained in this node.
+ size_t End() const { return end; }
+ //! Return the split dimension of this node.
+ size_t SplitDim() const { return splitDim; }
+ //! Return the split value of this node.
+ double SplitValue() const { return splitValue; }
+ //! Return the log negative error of this node.
+ double LogNegError() const { return logNegError; }
+ //! Return the log negative error of all descendants of this node.
+ double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
+ //! Return the number of leaves which are descendants of this node.
+ size_t SubtreeLeaves() const { return subtreeLeaves; }
+ //! Return the ratio of points in this node to the points in the whole
+ //! dataset.
+ double Ratio() const { return ratio; }
+ //! Return the inverse of the volume of this node.
+ double LogVolume() const { return logVolume; }
+ //! Return the left child.
+ DTree* Left() const { return left; }
+ //! Return the right child.
+ DTree* Right() const { return right; }
+ //! Return whether or not this is the root of the tree.
+ bool Root() const { return root; }
+ //! Return the upper part of the alpha sum.
+ double AlphaUpper() const { return alphaUpper; }
+
+ //! Return the maximum values.
+ const arma::vec& MaxVals() const { return maxVals; }
+ //! Modify the maximum values.
+ arma::vec& MaxVals() { return maxVals; }
+
+ //! Return the minimum values.
+ const arma::vec& MinVals() const { return minVals; }
+ //! Modify the minimum values.
+ arma::vec& MinVals() { return minVals; }
+
+ private:
+
+ // Utility methods.
+
+ /**
+ * Find the dimension to split on.
+ */
+ bool FindSplit(const arma::mat& data,
+ size_t& splitDim,
+ double& splitValue,
+ double& leftError,
+ double& rightError,
+ const size_t maxLeafSize = 10,
+ const size_t minLeafSize = 5) const;
+
+ /**
+ * Split the data, returning the number of points left of the split.
+ */
+ size_t SplitData(arma::mat& data,
+ const size_t splitDim,
+ const double splitValue,
+ arma::Col<size_t>& oldFromNew) const;
+
+};
+
+}; // namespace det
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_DET_DTREE_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/dtb.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,265 +0,0 @@
-/**
- * @file dtb.hpp
- * @author Bill March (march at gatech.edu)
- *
- * Contains an implementation of the DualTreeBoruvka algorithm for finding a
- * Euclidean Minimum Spanning Tree using the kd-tree data structure.
- *
- * @code
- * @inproceedings{
- * author = {March, W.B., Ram, P., and Gray, A.G.},
- * title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
- * Applications.}},
- * booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
- * on Knowledge Discovery and Data Mining}
- * series = {KDD 2010},
- * year = {2010}
- * }
- * @endcode
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_EMST_DTB_HPP
-#define __MLPACK_METHODS_EMST_DTB_HPP
-
-#include "edge_pair.hpp"
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include <mlpack/core/tree/binary_space_tree.hpp>
-
-namespace mlpack {
-namespace emst /** Euclidean Minimum Spanning Trees. */ {
-
-/**
- * A statistic for use with MLPACK trees, which stores the upper bound on
- * distance to nearest neighbors and the component which this node belongs to.
- */
-class DTBStat
-{
- private:
- //! Upper bound on the distance to the nearest neighbor of any point in this
- //! node.
- double maxNeighborDistance;
- //! The index of the component that all points in this node belong to. This
- //! is the same index returned by UnionFind for all points in this node. If
- //! points in this node are in different components, this value will be
- //! negative.
- int componentMembership;
-
- public:
- /**
- * A generic initializer. Sets the maximum neighbor distance to its default,
- * and the component membership to -1 (no component).
- */
- DTBStat();
-
- /**
- * This is called when a node is finished initializing. We set the maximum
- * neighbor distance to its default, and if possible, we set the component
- * membership of the node (if it has only one point and no children).
- *
- * @param node Node that has been finished.
- */
- template<typename TreeType>
- DTBStat(const TreeType& node);
-
- //! Get the maximum neighbor distance.
- double MaxNeighborDistance() const { return maxNeighborDistance; }
- //! Modify the maximum neighbor distance.
- double& MaxNeighborDistance() { return maxNeighborDistance; }
-
- //! Get the component membership of this node.
- int ComponentMembership() const { return componentMembership; }
- //! Modify the component membership of this node.
- int& ComponentMembership() { return componentMembership; }
-
-}; // class DTBStat
-
-/**
- * Performs the MST calculation using the Dual-Tree Boruvka algorithm, using any
- * type of tree.
- *
- * For more information on the algorithm, see the following citation:
- *
- * @code
- * @inproceedings{
- * author = {March, W.B., Ram, P., and Gray, A.G.},
- * title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
- * Applications.}},
- * booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
- * on Knowledge Discovery and Data Mining}
- * series = {KDD 2010},
- * year = {2010}
- * }
- * @endcode
- *
- * General usage of this class might be like this:
- *
- * @code
- * extern arma::mat data; // We want to find the MST of this dataset.
- * DualTreeBoruvka<> dtb(data); // Create the tree with default options.
- *
- * // Find the MST.
- * arma::mat mstResults;
- * dtb.ComputeMST(mstResults);
- * @endcode
- *
- * More advanced usage of the class can use different types of trees, pass in an
- * already-built tree, or compute the MST using the O(n^2) naive algorithm.
- *
- * @tparam MetricType The metric to use. IMPORTANT: this hasn't really been
- * tested with anything other than the L2 metric, so user beware. Note that the
- * tree type needs to compute bounds using the same metric as the type
- * specified here.
- * @tparam TreeType Type of tree to use. Should use DTBStat as a statistic.
- */
-template<
- typename MetricType = metric::SquaredEuclideanDistance,
- typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
->
-class DualTreeBoruvka
-{
- private:
- //! Copy of the data (if necessary).
- typename TreeType::Mat dataCopy;
- //! Reference to the data (this is what should be used for accessing data).
- typename TreeType::Mat& data;
-
- //! Pointer to the root of the tree.
- TreeType* tree;
- //! Indicates whether or not we "own" the tree.
- bool ownTree;
-
- //! Indicates whether or not O(n^2) naive mode will be used.
- bool naive;
-
- //! Edges.
- std::vector<EdgePair> edges; // We must use vector with non-numerical types.
-
- //! Connections.
- UnionFind connections;
-
- //! Permutations of points during tree building.
- std::vector<size_t> oldFromNew;
- //! List of edge nodes.
- arma::Col<size_t> neighborsInComponent;
- //! List of edge nodes.
- arma::Col<size_t> neighborsOutComponent;
- //! List of edge distances.
- arma::vec neighborsDistances;
-
- //! Total distance of the tree.
- double totalDist;
-
- //! The metric
- MetricType metric;
-
- // For sorting the edge list after the computation.
- struct SortEdgesHelper
- {
- bool operator()(const EdgePair& pairA, const EdgePair& pairB)
- {
- return (pairA.Distance() < pairB.Distance());
- }
- } SortFun;
-
- public:
- /**
- * Create the tree from the given dataset. This copies the dataset to an
- * internal copy, because tree-building modifies the dataset.
- *
- * @param data Dataset to build a tree for.
- * @param naive Whether the computation should be done in O(n^2) naive mode.
- * @param leafSize The leaf size to be used during tree construction.
- */
- DualTreeBoruvka(const typename TreeType::Mat& dataset,
- const bool naive = false,
- const size_t leafSize = 1,
- const MetricType metric = MetricType());
-
- /**
- * Create the DualTreeBoruvka object with an already initialized tree. This
- * will not copy the dataset, and can save a little processing power. Naive
- * mode is not available as an option for this constructor; instead, to run
- * naive computation, construct a tree with all the points in one leaf (i.e.
- * leafSize = number of points).
- *
- * @note
- * Because tree-building (at least with BinarySpaceTree) modifies the ordering
- * of a matrix, be sure you pass the modified matrix to this object! In
- * addition, mapping the points of the matrix back to their original indices
- * is not done when this constructor is used.
- * @endnote
- *
- * @param tree Pre-built tree.
- * @param dataset Dataset corresponding to the pre-built tree.
- */
- DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset,
- const MetricType metric = MetricType());
-
- /**
- * Delete the tree, if it was created inside the object.
- */
- ~DualTreeBoruvka();
-
- /**
- * Iteratively find the nearest neighbor of each component until the MST is
- * complete. The results will be a 3xN matrix (with N equal to the number of
- * edges in the minimum spanning tree). The first row will contain the lesser
- * index of the edge; the second row will contain the greater index of the
- * edge; and the third row will contain the distance between the two edges.
- *
- * @param results Matrix which results will be stored in.
- */
- void ComputeMST(arma::mat& results);
-
- private:
- /**
- * Adds a single edge to the edge list
- */
- void AddEdge(const size_t e1, const size_t e2, const double distance);
-
- /**
- * Adds all the edges found in one iteration to the list of neighbors.
- */
- void AddAllEdges();
-
- /**
- * Unpermute the edge list and output it to results.
- */
- void EmitResults(arma::mat& results);
-
- /**
- * This function resets the values in the nodes of the tree nearest neighbor
- * distance, and checks for fully connected nodes.
- */
- void CleanupHelper(TreeType* tree);
-
- /**
- * The values stored in the tree must be reset on each iteration.
- */
- void Cleanup();
-
-}; // class DualTreeBoruvka
-
-}; // namespace emst
-}; // namespace mlpack
-
-#include "dtb_impl.hpp"
-
-#endif // __MLPACK_METHODS_EMST_DTB_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/dtb.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,265 @@
+/**
+ * @file dtb.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Contains an implementation of the DualTreeBoruvka algorithm for finding a
+ * Euclidean Minimum Spanning Tree using the kd-tree data structure.
+ *
+ * @code
+ * @inproceedings{
+ * author = {March, W.B., Ram, P., and Gray, A.G.},
+ * title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
+ * Applications.}},
+ * booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
+ * on Knowledge Discovery and Data Mining}
+ * series = {KDD 2010},
+ * year = {2010}
+ * }
+ * @endcode
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_EMST_DTB_HPP
+#define __MLPACK_METHODS_EMST_DTB_HPP
+
+#include "edge_pair.hpp"
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
+namespace mlpack {
+namespace emst /** Euclidean Minimum Spanning Trees. */ {
+
+/**
+ * A statistic for use with MLPACK trees, which stores the upper bound on
+ * distance to nearest neighbors and the component which this node belongs to.
+ */
+class DTBStat
+{
+ private:
+ //! Upper bound on the distance to the nearest neighbor of any point in this
+ //! node.
+ double maxNeighborDistance;
+ //! The index of the component that all points in this node belong to. This
+ //! is the same index returned by UnionFind for all points in this node. If
+ //! points in this node are in different components, this value will be
+ //! negative.
+ int componentMembership;
+
+ public:
+ /**
+ * A generic initializer. Sets the maximum neighbor distance to its default,
+ * and the component membership to -1 (no component).
+ */
+ DTBStat();
+
+ /**
+ * This is called when a node is finished initializing. We set the maximum
+ * neighbor distance to its default, and if possible, we set the component
+ * membership of the node (if it has only one point and no children).
+ *
+ * @param node Node that has been finished.
+ */
+ template<typename TreeType>
+ DTBStat(const TreeType& node);
+
+ //! Get the maximum neighbor distance.
+ double MaxNeighborDistance() const { return maxNeighborDistance; }
+ //! Modify the maximum neighbor distance.
+ double& MaxNeighborDistance() { return maxNeighborDistance; }
+
+ //! Get the component membership of this node.
+ int ComponentMembership() const { return componentMembership; }
+ //! Modify the component membership of this node.
+ int& ComponentMembership() { return componentMembership; }
+
+}; // class DTBStat
+
+/**
+ * Performs the MST calculation using the Dual-Tree Boruvka algorithm, using any
+ * type of tree.
+ *
+ * For more information on the algorithm, see the following citation:
+ *
+ * @code
+ * @inproceedings{
+ * author = {March, W.B., Ram, P., and Gray, A.G.},
+ * title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
+ * Applications.}},
+ * booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
+ * on Knowledge Discovery and Data Mining}
+ * series = {KDD 2010},
+ * year = {2010}
+ * }
+ * @endcode
+ *
+ * General usage of this class might be like this:
+ *
+ * @code
+ * extern arma::mat data; // We want to find the MST of this dataset.
+ * DualTreeBoruvka<> dtb(data); // Create the tree with default options.
+ *
+ * // Find the MST.
+ * arma::mat mstResults;
+ * dtb.ComputeMST(mstResults);
+ * @endcode
+ *
+ * More advanced usage of the class can use different types of trees, pass in an
+ * already-built tree, or compute the MST using the O(n^2) naive algorithm.
+ *
+ * @tparam MetricType The metric to use. IMPORTANT: this hasn't really been
+ * tested with anything other than the L2 metric, so user beware. Note that the
+ * tree type needs to compute bounds using the same metric as the type
+ * specified here.
+ * @tparam TreeType Type of tree to use. Should use DTBStat as a statistic.
+ */
+template<
+ typename MetricType = metric::SquaredEuclideanDistance,
+ typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
+>
+class DualTreeBoruvka
+{
+ private:
+ //! Copy of the data (if necessary).
+ typename TreeType::Mat dataCopy;
+ //! Reference to the data (this is what should be used for accessing data).
+ typename TreeType::Mat& data;
+
+ //! Pointer to the root of the tree.
+ TreeType* tree;
+ //! Indicates whether or not we "own" the tree.
+ bool ownTree;
+
+ //! Indicates whether or not O(n^2) naive mode will be used.
+ bool naive;
+
+ //! Edges.
+ std::vector<EdgePair> edges; // We must use vector with non-numerical types.
+
+ //! Connections.
+ UnionFind connections;
+
+ //! Permutations of points during tree building.
+ std::vector<size_t> oldFromNew;
+ //! List of edge nodes.
+ arma::Col<size_t> neighborsInComponent;
+ //! List of edge nodes.
+ arma::Col<size_t> neighborsOutComponent;
+ //! List of edge distances.
+ arma::vec neighborsDistances;
+
+ //! Total distance of the tree.
+ double totalDist;
+
+ //! The metric
+ MetricType metric;
+
+ // For sorting the edge list after the computation.
+ struct SortEdgesHelper
+ {
+ bool operator()(const EdgePair& pairA, const EdgePair& pairB)
+ {
+ return (pairA.Distance() < pairB.Distance());
+ }
+ } SortFun;
+
+ public:
+ /**
+ * Create the tree from the given dataset. This copies the dataset to an
+ * internal copy, because tree-building modifies the dataset.
+ *
+ * @param data Dataset to build a tree for.
+ * @param naive Whether the computation should be done in O(n^2) naive mode.
+ * @param leafSize The leaf size to be used during tree construction.
+ */
+ DualTreeBoruvka(const typename TreeType::Mat& dataset,
+ const bool naive = false,
+ const size_t leafSize = 1,
+ const MetricType metric = MetricType());
+
+ /**
+ * Create the DualTreeBoruvka object with an already initialized tree. This
+ * will not copy the dataset, and can save a little processing power. Naive
+ * mode is not available as an option for this constructor; instead, to run
+ * naive computation, construct a tree with all the points in one leaf (i.e.
+ * leafSize = number of points).
+ *
+ * @note
+ * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+ * of a matrix, be sure you pass the modified matrix to this object! In
+ * addition, mapping the points of the matrix back to their original indices
+ * is not done when this constructor is used.
+ * @endnote
+ *
+ * @param tree Pre-built tree.
+ * @param dataset Dataset corresponding to the pre-built tree.
+ */
+ DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset,
+ const MetricType metric = MetricType());
+
+ /**
+ * Delete the tree, if it was created inside the object.
+ */
+ ~DualTreeBoruvka();
+
+ /**
+ * Iteratively find the nearest neighbor of each component until the MST is
+ * complete. The results will be a 3xN matrix (with N equal to the number of
+ * edges in the minimum spanning tree). The first row will contain the lesser
+ * index of the edge; the second row will contain the greater index of the
+ * edge; and the third row will contain the distance between the two edges.
+ *
+ * @param results Matrix which results will be stored in.
+ */
+ void ComputeMST(arma::mat& results);
+
+ private:
+ /**
+ * Adds a single edge to the edge list
+ */
+ void AddEdge(const size_t e1, const size_t e2, const double distance);
+
+ /**
+ * Adds all the edges found in one iteration to the list of neighbors.
+ */
+ void AddAllEdges();
+
+ /**
+ * Unpermute the edge list and output it to results.
+ */
+ void EmitResults(arma::mat& results);
+
+ /**
+ * This function resets the values in the nodes of the tree nearest neighbor
+ * distance, and checks for fully connected nodes.
+ */
+ void CleanupHelper(TreeType* tree);
+
+ /**
+ * The values stored in the tree must be reset on each iteration.
+ */
+ void Cleanup();
+
+}; // class DualTreeBoruvka
+
+}; // namespace emst
+}; // namespace mlpack
+
+#include "dtb_impl.hpp"
+
+#endif // __MLPACK_METHODS_EMST_DTB_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/dtb_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,308 +0,0 @@
-/**
- * @file dtb_impl.hpp
- * @author Bill March (march at gatech.edu)
- *
- * Implementation of DTB.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
-#define __MLPACK_METHODS_EMST_DTB_IMPL_HPP
-
-#include "dtb_rules.hpp"
-
-namespace mlpack {
-namespace emst {
-
-// DTBStat
-
-/**
- * A generic initializer.
- */
-DTBStat::DTBStat() : maxNeighborDistance(DBL_MAX), componentMembership(-1)
-{
- // Nothing to do.
-}
-
-/**
- * An initializer for leaves.
- */
-template<typename TreeType>
-DTBStat::DTBStat(const TreeType& node) :
- maxNeighborDistance(DBL_MAX),
- componentMembership(((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
- node.Point(0) : -1)
-{
- // Nothing to do.
-}
-
-// DualTreeBoruvka
-
-/**
- * Takes in a reference to the data set. Copies the data, builds the tree,
- * and initializes all of the member variables.
- */
-template<typename MetricType, typename TreeType>
-DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
- const typename TreeType::Mat& dataset,
- const bool naive,
- const size_t leafSize,
- const MetricType metric) :
- dataCopy(dataset),
- data(dataCopy), // The reference points to our copy of the data.
- ownTree(true),
- naive(naive),
- connections(data.n_cols),
- totalDist(0.0),
- metric(metric)
-{
- Timer::Start("emst/tree_building");
-
- if (!naive)
- {
- // Default leaf size is 1; this gives the best pruning, empirically. Use
- // leaf_size = 1 unless space is a big concern.
- tree = new TreeType(data, oldFromNew, leafSize);
- }
- else
- {
- // Naive tree holds all data in one leaf.
- tree = new TreeType(data, oldFromNew, data.n_cols);
- }
-
- Timer::Stop("emst/tree_building");
-
- edges.reserve(data.n_cols - 1); // Set size.
-
- neighborsInComponent.set_size(data.n_cols);
- neighborsOutComponent.set_size(data.n_cols);
- neighborsDistances.set_size(data.n_cols);
- neighborsDistances.fill(DBL_MAX);
-} // Constructor
-
-template<typename MetricType, typename TreeType>
-DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
- TreeType* tree,
- const typename TreeType::Mat& dataset,
- const MetricType metric) :
- data(dataset),
- tree(tree),
- ownTree(true),
- naive(false),
- connections(data.n_cols),
- totalDist(0.0),
- metric(metric)
-{
- edges.reserve(data.n_cols - 1); // fill with EdgePairs
-
- neighborsInComponent.set_size(data.n_cols);
- neighborsOutComponent.set_size(data.n_cols);
- neighborsDistances.set_size(data.n_cols);
- neighborsDistances.fill(DBL_MAX);
-}
-
-template<typename MetricType, typename TreeType>
-DualTreeBoruvka<MetricType, TreeType>::~DualTreeBoruvka()
-{
- if (ownTree)
- delete tree;
-}
-
-/**
- * Iteratively find the nearest neighbor of each component until the MST is
- * complete.
- */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
-{
- Timer::Start("emst/mst_computation");
-
- totalDist = 0; // Reset distance.
-
- typedef DTBRules<MetricType, TreeType> RuleType;
- RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
- neighborsOutComponent, metric);
-
- while (edges.size() < (data.n_cols - 1))
- {
-
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
- traverser.Traverse(*tree, *tree);
-
- AddAllEdges();
-
- Cleanup();
-
- Log::Info << edges.size() << " edges found so far.\n";
- }
-
- Timer::Stop("emst/mst_computation");
-
- EmitResults(results);
-
- Log::Info << "Total squared length: " << totalDist << std::endl;
-} // ComputeMST
-
-/**
- * Adds a single edge to the edge list
- */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::AddEdge(const size_t e1,
- const size_t e2,
- const double distance)
-{
- Log::Assert((distance >= 0.0),
- "DualTreeBoruvka::AddEdge(): distance cannot be negative.");
-
- if (e1 < e2)
- edges.push_back(EdgePair(e1, e2, distance));
- else
- edges.push_back(EdgePair(e2, e1, distance));
-} // AddEdge
-
-/**
- * Adds all the edges found in one iteration to the list of neighbors.
- */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::AddAllEdges()
-{
- for (size_t i = 0; i < data.n_cols; i++)
- {
- size_t component = connections.Find(i);
- size_t inEdge = neighborsInComponent[component];
- size_t outEdge = neighborsOutComponent[component];
- if (connections.Find(inEdge) != connections.Find(outEdge))
- {
- //totalDist = totalDist + dist;
- // changed to make this agree with the cover tree code
- totalDist += sqrt(neighborsDistances[component]);
- AddEdge(inEdge, outEdge, neighborsDistances[component]);
- connections.Union(inEdge, outEdge);
- }
- }
-} // AddAllEdges
-
-/**
- * Unpermute the edge list (if necessary) and output it to results.
- */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
-{
- // Sort the edges.
- std::sort(edges.begin(), edges.end(), SortFun);
-
- Log::Assert(edges.size() == data.n_cols - 1);
- results.set_size(3, edges.size());
-
- // Need to unpermute the point labels.
- if (!naive && ownTree)
- {
- for (size_t i = 0; i < (data.n_cols - 1); i++)
- {
- // Make sure the edge list stores the smaller index first to
- // make checking correctness easier
- size_t ind1 = oldFromNew[edges[i].Lesser()];
- size_t ind2 = oldFromNew[edges[i].Greater()];
-
- if (ind1 < ind2)
- {
- edges[i].Lesser() = ind1;
- edges[i].Greater() = ind2;
- }
- else
- {
- edges[i].Lesser() = ind2;
- edges[i].Greater() = ind1;
- }
-
- results(0, i) = edges[i].Lesser();
- results(1, i) = edges[i].Greater();
- results(2, i) = sqrt(edges[i].Distance());
- }
- }
- else
- {
- for (size_t i = 0; i < edges.size(); i++)
- {
- results(0, i) = edges[i].Lesser();
- results(1, i) = edges[i].Greater();
- results(2, i) = sqrt(edges[i].Distance());
- }
- }
-} // EmitResults
-
-/**
- * This function resets the values in the nodes of the tree nearest neighbor
- * distance and checks for fully connected nodes.
- */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
-{
- tree->Stat().MaxNeighborDistance() = DBL_MAX;
-
- if (!tree->IsLeaf())
- {
- CleanupHelper(tree->Left());
- CleanupHelper(tree->Right());
-
- if ((tree->Left()->Stat().ComponentMembership() >= 0)
- && (tree->Left()->Stat().ComponentMembership() ==
- tree->Right()->Stat().ComponentMembership()))
- {
- tree->Stat().ComponentMembership() =
- tree->Left()->Stat().ComponentMembership();
- }
- }
- else
- {
- size_t newMembership = connections.Find(tree->Begin());
-
- for (size_t i = tree->Begin(); i < tree->End(); ++i)
- {
- if (newMembership != connections.Find(i))
- {
- newMembership = -1;
- Log::Assert(tree->Stat().ComponentMembership() < 0);
- return;
- }
- }
- tree->Stat().ComponentMembership() = newMembership;
- }
-} // CleanupHelper
-
-/**
- * The values stored in the tree must be reset on each iteration.
- */
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
-{
- for (size_t i = 0; i < data.n_cols; i++)
- {
- neighborsDistances[i] = DBL_MAX;
- }
-
- if (!naive)
- {
- CleanupHelper(tree);
- }
-}
-
-}; // namespace emst
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/dtb_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,308 @@
+/**
+ * @file dtb_impl.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Implementation of DTB.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
+#define __MLPACK_METHODS_EMST_DTB_IMPL_HPP
+
+#include "dtb_rules.hpp"
+
+namespace mlpack {
+namespace emst {
+
+// DTBStat
+
+/**
+ * A generic initializer.
+ */
+DTBStat::DTBStat() : maxNeighborDistance(DBL_MAX), componentMembership(-1)
+{
+ // Nothing to do.
+}
+
+/**
+ * An initializer for leaves.
+ */
+template<typename TreeType>
+DTBStat::DTBStat(const TreeType& node) :
+ maxNeighborDistance(DBL_MAX),
+ componentMembership(((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
+ node.Point(0) : -1)
+{
+ // Nothing to do.
+}
+
+// DualTreeBoruvka
+
+/**
+ * Takes in a reference to the data set. Copies the data, builds the tree,
+ * and initializes all of the member variables.
+ */
+template<typename MetricType, typename TreeType>
+DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
+ const typename TreeType::Mat& dataset,
+ const bool naive,
+ const size_t leafSize,
+ const MetricType metric) :
+ dataCopy(dataset),
+ data(dataCopy), // The reference points to our copy of the data.
+ ownTree(true),
+ naive(naive),
+ connections(data.n_cols),
+ totalDist(0.0),
+ metric(metric)
+{
+ Timer::Start("emst/tree_building");
+
+ if (!naive)
+ {
+ // Default leaf size is 1; this gives the best pruning, empirically. Use
+ // leaf_size = 1 unless space is a big concern.
+ tree = new TreeType(data, oldFromNew, leafSize);
+ }
+ else
+ {
+ // Naive tree holds all data in one leaf.
+ tree = new TreeType(data, oldFromNew, data.n_cols);
+ }
+
+ Timer::Stop("emst/tree_building");
+
+ edges.reserve(data.n_cols - 1); // Set size.
+
+ neighborsInComponent.set_size(data.n_cols);
+ neighborsOutComponent.set_size(data.n_cols);
+ neighborsDistances.set_size(data.n_cols);
+ neighborsDistances.fill(DBL_MAX);
+} // Constructor
+
+template<typename MetricType, typename TreeType>
+DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
+ TreeType* tree,
+ const typename TreeType::Mat& dataset,
+ const MetricType metric) :
+ data(dataset),
+ tree(tree),
+ ownTree(true),
+ naive(false),
+ connections(data.n_cols),
+ totalDist(0.0),
+ metric(metric)
+{
+ edges.reserve(data.n_cols - 1); // fill with EdgePairs
+
+ neighborsInComponent.set_size(data.n_cols);
+ neighborsOutComponent.set_size(data.n_cols);
+ neighborsDistances.set_size(data.n_cols);
+ neighborsDistances.fill(DBL_MAX);
+}
+
+template<typename MetricType, typename TreeType>
+DualTreeBoruvka<MetricType, TreeType>::~DualTreeBoruvka()
+{
+ if (ownTree)
+ delete tree;
+}
+
+/**
+ * Iteratively find the nearest neighbor of each component until the MST is
+ * complete.
+ */
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
+{
+ Timer::Start("emst/mst_computation");
+
+ totalDist = 0; // Reset distance.
+
+ typedef DTBRules<MetricType, TreeType> RuleType;
+ RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
+ neighborsOutComponent, metric);
+
+ while (edges.size() < (data.n_cols - 1))
+ {
+
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ traverser.Traverse(*tree, *tree);
+
+ AddAllEdges();
+
+ Cleanup();
+
+ Log::Info << edges.size() << " edges found so far.\n";
+ }
+
+ Timer::Stop("emst/mst_computation");
+
+ EmitResults(results);
+
+ Log::Info << "Total squared length: " << totalDist << std::endl;
+} // ComputeMST
+
+/**
+ * Adds a single edge to the edge list
+ */
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::AddEdge(const size_t e1,
+ const size_t e2,
+ const double distance)
+{
+ Log::Assert((distance >= 0.0),
+ "DualTreeBoruvka::AddEdge(): distance cannot be negative.");
+
+ if (e1 < e2)
+ edges.push_back(EdgePair(e1, e2, distance));
+ else
+ edges.push_back(EdgePair(e2, e1, distance));
+} // AddEdge
+
+/**
+ * Adds all the edges found in one iteration to the list of neighbors.
+ */
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::AddAllEdges()
+{
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ size_t component = connections.Find(i);
+ size_t inEdge = neighborsInComponent[component];
+ size_t outEdge = neighborsOutComponent[component];
+ if (connections.Find(inEdge) != connections.Find(outEdge))
+ {
+ //totalDist = totalDist + dist;
+ // changed to make this agree with the cover tree code
+ totalDist += sqrt(neighborsDistances[component]);
+ AddEdge(inEdge, outEdge, neighborsDistances[component]);
+ connections.Union(inEdge, outEdge);
+ }
+ }
+} // AddAllEdges
+
+/**
+ * Unpermute the edge list (if necessary) and output it to results.
+ */
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
+{
+ // Sort the edges.
+ std::sort(edges.begin(), edges.end(), SortFun);
+
+ Log::Assert(edges.size() == data.n_cols - 1);
+ results.set_size(3, edges.size());
+
+ // Need to unpermute the point labels.
+ if (!naive && ownTree)
+ {
+ for (size_t i = 0; i < (data.n_cols - 1); i++)
+ {
+ // Make sure the edge list stores the smaller index first to
+ // make checking correctness easier
+ size_t ind1 = oldFromNew[edges[i].Lesser()];
+ size_t ind2 = oldFromNew[edges[i].Greater()];
+
+ if (ind1 < ind2)
+ {
+ edges[i].Lesser() = ind1;
+ edges[i].Greater() = ind2;
+ }
+ else
+ {
+ edges[i].Lesser() = ind2;
+ edges[i].Greater() = ind1;
+ }
+
+ results(0, i) = edges[i].Lesser();
+ results(1, i) = edges[i].Greater();
+ results(2, i) = sqrt(edges[i].Distance());
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < edges.size(); i++)
+ {
+ results(0, i) = edges[i].Lesser();
+ results(1, i) = edges[i].Greater();
+ results(2, i) = sqrt(edges[i].Distance());
+ }
+ }
+} // EmitResults
+
+/**
+ * This function resets the values in the nodes of the tree nearest neighbor
+ * distance and checks for fully connected nodes.
+ */
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
+{
+ tree->Stat().MaxNeighborDistance() = DBL_MAX;
+
+ if (!tree->IsLeaf())
+ {
+ CleanupHelper(tree->Left());
+ CleanupHelper(tree->Right());
+
+ if ((tree->Left()->Stat().ComponentMembership() >= 0)
+ && (tree->Left()->Stat().ComponentMembership() ==
+ tree->Right()->Stat().ComponentMembership()))
+ {
+ tree->Stat().ComponentMembership() =
+ tree->Left()->Stat().ComponentMembership();
+ }
+ }
+ else
+ {
+ size_t newMembership = connections.Find(tree->Begin());
+
+ for (size_t i = tree->Begin(); i < tree->End(); ++i)
+ {
+ if (newMembership != connections.Find(i))
+ {
+ newMembership = -1;
+ Log::Assert(tree->Stat().ComponentMembership() < 0);
+ return;
+ }
+ }
+ tree->Stat().ComponentMembership() = newMembership;
+ }
+} // CleanupHelper
+
+/**
+ * The values stored in the tree must be reset on each iteration.
+ */
+template<typename MetricType, typename TreeType>
+void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
+{
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ neighborsDistances[i] = DBL_MAX;
+ }
+
+ if (!naive)
+ {
+ CleanupHelper(tree);
+ }
+}
+
+}; // namespace emst
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/dtb_rules.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,165 +0,0 @@
-/**
- * @file dtb.hpp
- * @author Bill March (march at gatech.edu)
- *
- * Tree traverser rules for the DualTreeBoruvka algorithm.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-
-#ifndef __MLPACK_METHODS_EMST_DTB_RULES_HPP
-#define __MLPACK_METHODS_EMST_DTB_RULES_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace emst {
-
-template<typename MetricType, typename TreeType>
-class DTBRules
-{
- public:
-
- DTBRules(const arma::mat& dataSet,
- UnionFind& connections,
- arma::vec& neighborsDistances,
- arma::Col<size_t>& neighborsInComponent,
- arma::Col<size_t>& neighborsOutComponent,
- MetricType& metric);
-
- double BaseCase(const size_t queryIndex, const size_t referenceIndex);
-
- // Update bounds. Needs a better name.
- void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- */
- double Score(const size_t queryIndex, TreeType& referenceNode);
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult);
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param oldScore Old score produced by Score() (or Rescore()).
- */
- double Rescore(const size_t queryIndex,
- TreeType& referenceNode,
- const double oldScore);
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursionm while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- */
- double Score(TreeType& queryNode, TreeType& referenceNode) const;
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult) const;
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- * @param oldScore Old score produced by Socre() (or Rescore()).
- */
- double Rescore(TreeType& queryNode,
- TreeType& referenceNode,
- const double oldScore) const;
-
- private:
-
- // This class needs to know what points are connected to one another
-
-
- // Things I need
- // UnionFind storing the tree structure at this iteration
- // neighborDistances
- // neighborInComponent
- // neighborOutComponent
-
- //! The data points.
- const arma::mat& dataSet;
-
- //! Stores the tree structure so far
- UnionFind& connections;
-
- //! The distance to the candidate nearest neighbor for each component.
- arma::vec& neighborsDistances;
-
- //! The index of the point in the component that is an endpoint of the
- //! candidate edge.
- arma::Col<size_t>& neighborsInComponent;
-
- //! The index of the point outside of the component that is an endpoint
- //! of the candidate edge.
- arma::Col<size_t>& neighborsOutComponent;
-
- //! The metric
- MetricType& metric;
-
-}; // class DTBRules
-
-} // emst namespace
-} // mlpack namespace
-
-#include "dtb_rules_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/dtb_rules.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,165 @@
+/**
+ * @file dtb_rules.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Tree traverser rules for the DualTreeBoruvka algorithm.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+
+#ifndef __MLPACK_METHODS_EMST_DTB_RULES_HPP
+#define __MLPACK_METHODS_EMST_DTB_RULES_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace emst {
+
+template<typename MetricType, typename TreeType>
+class DTBRules
+{
+ public:
+
+ DTBRules(const arma::mat& dataSet,
+ UnionFind& connections,
+ arma::vec& neighborsDistances,
+ arma::Col<size_t>& neighborsInComponent,
+ arma::Col<size_t>& neighborsOutComponent,
+ MetricType& metric);
+
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+ // Update bounds. Needs a better name.
+ void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ */
+ double Score(const size_t queryIndex, TreeType& referenceNode);
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult);
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore);
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ */
+ double Score(TreeType& queryNode, TreeType& referenceNode) const;
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ private:
+
+ // This class needs to know what points are connected to one another
+
+
+ // Things I need
+ // UnionFind storing the tree structure at this iteration
+ // neighborDistances
+ // neighborInComponent
+ // neighborOutComponent
+
+ //! The data points.
+ const arma::mat& dataSet;
+
+ //! Stores the tree structure so far
+ UnionFind& connections;
+
+ //! The distance to the candidate nearest neighbor for each component.
+ arma::vec& neighborsDistances;
+
+ //! The index of the point in the component that is an endpoint of the
+ //! candidate edge.
+ arma::Col<size_t>& neighborsInComponent;
+
+ //! The index of the point outside of the component that is an endpoint
+ //! of the candidate edge.
+ arma::Col<size_t>& neighborsOutComponent;
+
+ //! The metric
+ MetricType& metric;
+
+}; // class DTBRules
+
+} // emst namespace
+} // mlpack namespace
+
+#include "dtb_rules_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/dtb_rules_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,260 +0,0 @@
-/**
- * @file dtb_impl.hpp
- * @author Bill March (march at gatech.edu)
- *
- * Tree traverser rules for the DualTreeBoruvka algorithm.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-
-#ifndef __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
-#define __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
-
-namespace mlpack {
-namespace emst {
-
-template<typename MetricType, typename TreeType>
-DTBRules<MetricType, TreeType>::
-DTBRules(const arma::mat& dataSet,
- UnionFind& connections,
- arma::vec& neighborsDistances,
- arma::Col<size_t>& neighborsInComponent,
- arma::Col<size_t>& neighborsOutComponent,
- MetricType& metric)
-:
- dataSet(dataSet),
- connections(connections),
- neighborsDistances(neighborsDistances),
- neighborsInComponent(neighborsInComponent),
- neighborsOutComponent(neighborsOutComponent),
- metric(metric)
-{
- // nothing else to do
-} // constructor
-
-template<typename MetricType, typename TreeType>
-double DTBRules<MetricType, TreeType>::BaseCase(const size_t queryIndex,
- const size_t referenceIndex)
-{
-
- // Check if the points are in the same component at this iteration.
- // If not, return the distance between them.
- // Also responsible for storing this as the current neighbor
-
- double newUpperBound = -1.0;
-
- // Find the index of the component the query is in.
- size_t queryComponentIndex = connections.Find(queryIndex);
-
- size_t referenceComponentIndex = connections.Find(referenceIndex);
-
- if (queryComponentIndex != referenceComponentIndex)
- {
- double distance = metric.Evaluate(dataSet.col(queryIndex),
- dataSet.col(referenceIndex));
-
- if (distance < neighborsDistances[queryComponentIndex])
- {
- Log::Assert(queryIndex != referenceIndex);
-
- neighborsDistances[queryComponentIndex] = distance;
- neighborsInComponent[queryComponentIndex] = queryIndex;
- neighborsOutComponent[queryComponentIndex] = referenceIndex;
-
- } // if distance
- } // if indices not equal
-
- if (newUpperBound < neighborsDistances[queryComponentIndex])
- newUpperBound = neighborsDistances[queryComponentIndex];
-
- Log::Assert(newUpperBound >= 0.0);
-
- return newUpperBound;
-
-} // BaseCase()
-
-template<typename MetricType, typename TreeType>
-void DTBRules<MetricType, TreeType>::UpdateAfterRecursion(
- TreeType& queryNode,
- TreeType& /* referenceNode */)
-{
-
- // Find the worst distance that the children found (including any points), and
- // update the bound accordingly.
- double newUpperBound = 0.0;
-
- // First look through children nodes.
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- if (newUpperBound < queryNode.Child(i).Stat().MaxNeighborDistance())
- newUpperBound = queryNode.Child(i).Stat().MaxNeighborDistance();
- }
-
- // Now look through children points.
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- size_t pointComponent = connections.Find(queryNode.Point(i));
- if (newUpperBound < neighborsDistances[pointComponent])
- newUpperBound = neighborsDistances[pointComponent];
- }
-
- // Update the bound in the query's stat
- queryNode.Stat().MaxNeighborDistance() = newUpperBound;
-
-} // UpdateAfterRecursion
-
-template<typename MetricType, typename TreeType>
-double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
- TreeType& referenceNode)
-{
-
- size_t queryComponentIndex = connections.Find(queryIndex);
-
- // If the query belongs to the same component as all of the references,
- // then prune.
- // Casting this to stop a warning about comparing unsigned to signed
- // values.
- if (queryComponentIndex == (size_t)referenceNode.Stat().ComponentMembership())
- return DBL_MAX;
-
- const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
-
- const double distance = referenceNode.MinDistance(queryPoint);
-
- // If all the points in the reference node are farther than the candidate
- // nearest neighbor for the query's component, we prune.
- return neighborsDistances[queryComponentIndex] < distance
- ? DBL_MAX : distance;
-
-} // Score()
-
-template<typename MetricType, typename TreeType>
-double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult)
-{
- // I don't really understand the last argument here
- // It just gets passed in the distance call, otherwise this function
- // is the same as the one above
-
- size_t queryComponentIndex = connections.Find(queryIndex);
-
- // if the query belongs to the same component as all of the references,
- // then prune
- if (queryComponentIndex == referenceNode.Stat().ComponentMembership())
- return DBL_MAX;
-
- const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
-
- const double distance = referenceNode.MinDistance(queryPoint,
- baseCaseResult);
-
- // If all the points in the reference node are farther than the candidate
- // nearest neighbor for the query's component, we prune.
- return neighborsDistances[queryComponentIndex] < distance
- ? DBL_MAX : distance;
-
-} // Score()
-
-template<typename MetricType, typename TreeType>
-double DTBRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
- TreeType& referenceNode,
- const double oldScore)
-{
- // We don't need to check component membership again, because it can't
- // change inside a single iteration.
-
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
-
- if (oldScore > neighborsDistances[connections.Find(queryIndex)])
- return DBL_MAX;
- else
- return oldScore;
-
-} // Rescore
-
-template<typename MetricType, typename TreeType>
-double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
- TreeType& referenceNode) const
-{
- // If all the queries belong to the same component as all the references
- // then we prune.
- if ((queryNode.Stat().ComponentMembership() >= 0)
- && (queryNode.Stat().ComponentMembership() ==
- referenceNode.Stat().ComponentMembership()))
- return DBL_MAX;
-
- double distance = queryNode.MinDistance(&referenceNode);
-
- // If all the points in the reference node are farther than the candidate
- // nearest neighbor for all queries in the node, we prune.
- return queryNode.Stat().MaxNeighborDistance() < distance
- ? DBL_MAX : distance;
-
-} // Score()
-
-template<typename MetricType, typename TreeType>
-double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
-
- // If all the queries belong to the same component as all the references
- // then we prune.
- if ((queryNode.Stat().ComponentMembership() >= 0)
- && (queryNode.Stat().ComponentMembership() ==
- referenceNode.Stat().ComponentMembership()))
- return DBL_MAX;
-
- const double distance = queryNode.MinDistance(referenceNode,
- baseCaseResult);
-
- // If all the points in the reference node are farther than the candidate
- // nearest neighbor for all queries in the node, we prune.
- return queryNode.Stat().MaxNeighborDistance() < distance
- ? DBL_MAX : distance;
-
-} // Score()
-
-template<typename MetricType, typename TreeType>
-double DTBRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
- TreeType& /* referenceNode */,
- const double oldScore) const
-{
-
- // Same as above, but for nodes,
-
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
-
- if (oldScore > queryNode.Stat().MaxNeighborDistance())
- return DBL_MAX;
- else
- return oldScore;
-
-} // Rescore
-
-} // namespace emst
-} // namespace mlpack
-
-
-
-#endif
-
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/dtb_rules_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/dtb_rules_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,260 @@
+/**
+ * @file dtb_rules_impl.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Tree traverser rules for the DualTreeBoruvka algorithm.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+
+#ifndef __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
+#define __MLPACK_METHODS_EMST_DTB_RULES_IMPL_HPP
+
+namespace mlpack {
+namespace emst {
+
+template<typename MetricType, typename TreeType>
+DTBRules<MetricType, TreeType>::
+DTBRules(const arma::mat& dataSet,
+ UnionFind& connections,
+ arma::vec& neighborsDistances,
+ arma::Col<size_t>& neighborsInComponent,
+ arma::Col<size_t>& neighborsOutComponent,
+ MetricType& metric)
+:
+ dataSet(dataSet),
+ connections(connections),
+ neighborsDistances(neighborsDistances),
+ neighborsInComponent(neighborsInComponent),
+ neighborsOutComponent(neighborsOutComponent),
+ metric(metric)
+{
+ // nothing else to do
+} // constructor
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::BaseCase(const size_t queryIndex,
+ const size_t referenceIndex)
+{
+
+ // Check if the points are in the same component at this iteration.
+ // If not, return the distance between them.
+ // Also responsible for storing this as the current neighbor
+
+ double newUpperBound = -1.0;
+
+ // Find the index of the component the query is in.
+ size_t queryComponentIndex = connections.Find(queryIndex);
+
+ size_t referenceComponentIndex = connections.Find(referenceIndex);
+
+ if (queryComponentIndex != referenceComponentIndex)
+ {
+ double distance = metric.Evaluate(dataSet.col(queryIndex),
+ dataSet.col(referenceIndex));
+
+ if (distance < neighborsDistances[queryComponentIndex])
+ {
+ Log::Assert(queryIndex != referenceIndex);
+
+ neighborsDistances[queryComponentIndex] = distance;
+ neighborsInComponent[queryComponentIndex] = queryIndex;
+ neighborsOutComponent[queryComponentIndex] = referenceIndex;
+
+ } // if distance
+ } // if indices not equal
+
+ if (newUpperBound < neighborsDistances[queryComponentIndex])
+ newUpperBound = neighborsDistances[queryComponentIndex];
+
+ Log::Assert(newUpperBound >= 0.0);
+
+ return newUpperBound;
+
+} // BaseCase()
+
+template<typename MetricType, typename TreeType>
+void DTBRules<MetricType, TreeType>::UpdateAfterRecursion(
+ TreeType& queryNode,
+ TreeType& /* referenceNode */)
+{
+
+ // Find the worst distance that the children found (including any points), and
+ // update the bound accordingly.
+ double newUpperBound = 0.0;
+
+ // First look through children nodes.
+ for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+ {
+ if (newUpperBound < queryNode.Child(i).Stat().MaxNeighborDistance())
+ newUpperBound = queryNode.Child(i).Stat().MaxNeighborDistance();
+ }
+
+ // Now look through children points.
+ for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+ {
+ size_t pointComponent = connections.Find(queryNode.Point(i));
+ if (newUpperBound < neighborsDistances[pointComponent])
+ newUpperBound = neighborsDistances[pointComponent];
+ }
+
+ // Update the bound in the query's stat
+ queryNode.Stat().MaxNeighborDistance() = newUpperBound;
+
+} // UpdateAfterRecursion
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
+ TreeType& referenceNode)
+{
+
+ size_t queryComponentIndex = connections.Find(queryIndex);
+
+ // If the query belongs to the same component as all of the references,
+ // then prune.
+ // Casting this to stop a warning about comparing unsigned to signed
+ // values.
+ if (queryComponentIndex == (size_t)referenceNode.Stat().ComponentMembership())
+ return DBL_MAX;
+
+ const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
+
+ const double distance = referenceNode.MinDistance(queryPoint);
+
+ // If all the points in the reference node are farther than the candidate
+ // nearest neighbor for the query's component, we prune.
+ return neighborsDistances[queryComponentIndex] < distance
+ ? DBL_MAX : distance;
+
+} // Score()
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult)
+{
+ // I don't really understand the last argument here
+ // It just gets passed in the distance call, otherwise this function
+ // is the same as the one above
+
+ size_t queryComponentIndex = connections.Find(queryIndex);
+
+ // if the query belongs to the same component as all of the references,
+ // then prune
+ if (queryComponentIndex == referenceNode.Stat().ComponentMembership())
+ return DBL_MAX;
+
+ const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
+
+ const double distance = referenceNode.MinDistance(queryPoint,
+ baseCaseResult);
+
+ // If all the points in the reference node are farther than the candidate
+ // nearest neighbor for the query's component, we prune.
+ return neighborsDistances[queryComponentIndex] < distance
+ ? DBL_MAX : distance;
+
+} // Score()
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore)
+{
+ // We don't need to check component membership again, because it can't
+ // change inside a single iteration.
+
+ // If we are already pruning, still prune.
+ if (oldScore == DBL_MAX)
+ return oldScore;
+
+ if (oldScore > neighborsDistances[connections.Find(queryIndex)])
+ return DBL_MAX;
+ else
+ return oldScore;
+
+} // Rescore
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
+ TreeType& referenceNode) const
+{
+ // If all the queries belong to the same component as all the references
+ // then we prune.
+ if ((queryNode.Stat().ComponentMembership() >= 0)
+ && (queryNode.Stat().ComponentMembership() ==
+ referenceNode.Stat().ComponentMembership()))
+ return DBL_MAX;
+
+ double distance = queryNode.MinDistance(&referenceNode);
+
+ // If all the points in the reference node are farther than the candidate
+ // nearest neighbor for all queries in the node, we prune.
+ return queryNode.Stat().MaxNeighborDistance() < distance
+ ? DBL_MAX : distance;
+
+} // Score()
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const
+{
+
+ // If all the queries belong to the same component as all the references
+ // then we prune.
+ if ((queryNode.Stat().ComponentMembership() >= 0)
+ && (queryNode.Stat().ComponentMembership() ==
+ referenceNode.Stat().ComponentMembership()))
+ return DBL_MAX;
+
+ const double distance = queryNode.MinDistance(referenceNode,
+ baseCaseResult);
+
+ // If all the points in the reference node are farther than the candidate
+ // nearest neighbor for all queries in the node, we prune.
+ return queryNode.Stat().MaxNeighborDistance() < distance
+ ? DBL_MAX : distance;
+
+} // Score()
+
+template<typename MetricType, typename TreeType>
+double DTBRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
+ TreeType& /* referenceNode */,
+ const double oldScore) const
+{
+
+ // Same as above, but for nodes,
+
+ // If we are already pruning, still prune.
+ if (oldScore == DBL_MAX)
+ return oldScore;
+
+ if (oldScore > queryNode.Stat().MaxNeighborDistance())
+ return DBL_MAX;
+ else
+ return oldScore;
+
+} // Rescore
+
+} // namespace emst
+} // namespace mlpack
+
+
+
+#endif
+
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/edge_pair.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/edge_pair.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/edge_pair.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,82 +0,0 @@
-/**
- * @file edge_pair.hpp
- *
- * @author Bill March (march at gatech.edu)
- *
- * This file contains utilities necessary for all of the minimum spanning tree
- * algorithms.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
-#define __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
-
-#include <mlpack/core.hpp>
-
-#include "union_find.hpp"
-
-namespace mlpack {
-namespace emst {
-
-/**
- * An edge pair is simply two indices and a distance. It is used as the
- * basic element of an edge list when computing a minimum spanning tree.
- */
-class EdgePair
-{
- private:
- //! Lesser index.
- size_t lesser;
- //! Greater index.
- size_t greater;
- //! Distance between two indices.
- double distance;
-
- public:
- /**
- * Initialize an EdgePair with two indices and a distance. The indices are
- * called lesser and greater, implying that they be sorted before calling
- * Init. However, this is not necessary for functionality; it is just a way
- * to keep the edge list organized in other code.
- */
- EdgePair(const size_t lesser, const size_t greater, const double dist) :
- lesser(lesser), greater(greater), distance(dist)
- {
- Log::Assert(lesser != greater,
- "EdgePair::EdgePair(): indices cannot be equal.");
- }
-
- //! Get the lesser index.
- size_t Lesser() const { return lesser; }
- //! Modify the lesser index.
- size_t& Lesser() { return lesser; }
-
- //! Get the greater index.
- size_t Greater() const { return greater; }
- //! Modify the greater index.
- size_t& Greater() { return greater; }
-
- //! Get the distance.
- double Distance() const { return distance; }
- //! Modify the distance.
- double& Distance() { return distance; }
-
-}; // class EdgePair
-
-}; // namespace emst
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/edge_pair.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/edge_pair.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/edge_pair.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/edge_pair.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,82 @@
+/**
+ * @file edge_pair.hpp
+ *
+ * @author Bill March (march at gatech.edu)
+ *
+ * This file contains utilities necessary for all of the minimum spanning tree
+ * algorithms.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
+#define __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
+
+#include <mlpack/core.hpp>
+
+#include "union_find.hpp"
+
+namespace mlpack {
+namespace emst {
+
+/**
+ * An edge pair is simply two indices and a distance. It is used as the
+ * basic element of an edge list when computing a minimum spanning tree.
+ */
+class EdgePair
+{
+ private:
+ //! Lesser index.
+ size_t lesser;
+ //! Greater index.
+ size_t greater;
+ //! Distance between two indices.
+ double distance;
+
+ public:
+ /**
+ * Initialize an EdgePair with two indices and a distance. The indices are
+ * called lesser and greater, implying that they be sorted before calling
+ * Init. However, this is not necessary for functionality; it is just a way
+ * to keep the edge list organized in other code.
+ */
+ EdgePair(const size_t lesser, const size_t greater, const double dist) :
+ lesser(lesser), greater(greater), distance(dist)
+ {
+ Log::Assert(lesser != greater,
+ "EdgePair::EdgePair(): indices cannot be equal.");
+ }
+
+ //! Get the lesser index.
+ size_t Lesser() const { return lesser; }
+ //! Modify the lesser index.
+ size_t& Lesser() { return lesser; }
+
+ //! Get the greater index.
+ size_t Greater() const { return greater; }
+ //! Modify the greater index.
+ size_t& Greater() { return greater; }
+
+ //! Get the distance.
+ double Distance() const { return distance; }
+ //! Modify the distance.
+ double& Distance() { return distance; }
+
+}; // class EdgePair
+
+}; // namespace emst
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/emst_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/emst_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/emst_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,118 +0,0 @@
-/**
- * @file emst_main.cpp
- * @author Bill March (march at gatech.edu)
- *
- * Calls the DualTreeBoruvka algorithm from dtb.hpp.
- * Can optionally call naive Boruvka's method.
- *
- * For algorithm details, see:
- *
- * @code
- * @inproceedings{
- * author = {March, W.B., Ram, P., and Gray, A.G.},
- * title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
- * Applications.}},
- * booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
- * on Knowledge Discovery and Data Mining}
- * series = {KDD 2010},
- * year = {2010}
- * }
- * @endcode
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include "dtb.hpp"
-
-#include <mlpack/core.hpp>
-
-PROGRAM_INFO("Fast Euclidean Minimum Spanning Tree", "This program can compute "
- "the Euclidean minimum spanning tree of a set of input points using the "
- "dual-tree Boruvka algorithm."
- "\n\n"
- "The output is saved in a three-column matrix, where each row indicates an "
- "edge. The first column corresponds to the lesser index of the edge; the "
- "second column corresponds to the greater index of the edge; and the third "
- "column corresponds to the distance between the two points.");
-
-PARAM_STRING_REQ("input_file", "Data input file.", "i");
-PARAM_STRING("output_file", "Data output file. Stored as an edge list.", "o",
- "emst_output.csv");
-PARAM_FLAG("naive", "Compute the MST using O(n^2) naive algorithm.", "n");
-PARAM_INT("leaf_size", "Leaf size in the kd-tree. One-element leaves give the "
- "empirically best performance, but at the cost of greater memory "
- "requirements.", "l", 1);
-
-using namespace mlpack;
-using namespace mlpack::emst;
-using namespace mlpack::tree;
-
-int main(int argc, char* argv[])
-{
- CLI::ParseCommandLine(argc, argv);
-
- ///////////////// READ IN DATA //////////////////////////////////
- std::string dataFilename = CLI::GetParam<std::string>("input_file");
-
- Log::Info << "Reading in data.\n";
-
- arma::mat dataPoints;
- data::Load(dataFilename.c_str(), dataPoints, true);
-
- // Do naive.
- if (CLI::GetParam<bool>("naive"))
- {
- Log::Info << "Running naive algorithm.\n";
-
- DualTreeBoruvka<> naive(dataPoints, true);
-
- arma::mat naiveResults;
- naive.ComputeMST(naiveResults);
-
- std::string outputFilename = CLI::GetParam<std::string>("output_file");
-
- data::Save(outputFilename.c_str(), naiveResults, true);
- }
- else
- {
- Log::Info << "Data read, building tree.\n";
-
- /////////////// Initialize DTB //////////////////////
- if (CLI::GetParam<int>("leaf_size") <= 0)
- {
- Log::Fatal << "Invalid leaf size (" << CLI::GetParam<int>("leaf_size")
- << ")! Must be greater than or equal to 1." << std::endl;
- }
-
- size_t leafSize = CLI::GetParam<int>("leaf_size");
-
- DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
-
- Log::Info << "Tree built, running algorithm.\n";
-
- ////////////// Run DTB /////////////////////
- arma::mat results;
-
- dtb.ComputeMST(results);
-
- //////////////// Output the Results ////////////////
- std::string outputFilename = CLI::GetParam<std::string>("output_file");
-
- data::Save(outputFilename.c_str(), results, true);
- }
-
- return 0;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/emst_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/emst_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/emst_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/emst_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,118 @@
+/**
+ * @file emst_main.cpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Calls the DualTreeBoruvka algorithm from dtb.hpp.
+ * Can optionally call naive Boruvka's method.
+ *
+ * For algorithm details, see:
+ *
+ * @code
+ * @inproceedings{
+ * author = {March, W.B., Ram, P., and Gray, A.G.},
+ * title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
+ * Applications.}},
+ * booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
+ * on Knowledge Discovery and Data Mining}
+ * series = {KDD 2010},
+ * year = {2010}
+ * }
+ * @endcode
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "dtb.hpp"
+
+#include <mlpack/core.hpp>
+
+PROGRAM_INFO("Fast Euclidean Minimum Spanning Tree", "This program can compute "
+ "the Euclidean minimum spanning tree of a set of input points using the "
+ "dual-tree Boruvka algorithm."
+ "\n\n"
+ "The output is saved in a three-column matrix, where each row indicates an "
+ "edge. The first column corresponds to the lesser index of the edge; the "
+ "second column corresponds to the greater index of the edge; and the third "
+ "column corresponds to the distance between the two points.");
+
+PARAM_STRING_REQ("input_file", "Data input file.", "i");
+PARAM_STRING("output_file", "Data output file. Stored as an edge list.", "o",
+ "emst_output.csv");
+PARAM_FLAG("naive", "Compute the MST using O(n^2) naive algorithm.", "n");
+PARAM_INT("leaf_size", "Leaf size in the kd-tree. One-element leaves give the "
+ "empirically best performance, but at the cost of greater memory "
+ "requirements.", "l", 1);
+
+using namespace mlpack;
+using namespace mlpack::emst;
+using namespace mlpack::tree;
+
+int main(int argc, char* argv[])
+{
+ CLI::ParseCommandLine(argc, argv);
+
+ ///////////////// READ IN DATA //////////////////////////////////
+ std::string dataFilename = CLI::GetParam<std::string>("input_file");
+
+ Log::Info << "Reading in data.\n";
+
+ arma::mat dataPoints;
+ data::Load(dataFilename.c_str(), dataPoints, true);
+
+ // Do naive.
+ if (CLI::GetParam<bool>("naive"))
+ {
+ Log::Info << "Running naive algorithm.\n";
+
+ DualTreeBoruvka<> naive(dataPoints, true);
+
+ arma::mat naiveResults;
+ naive.ComputeMST(naiveResults);
+
+ std::string outputFilename = CLI::GetParam<std::string>("output_file");
+
+ data::Save(outputFilename.c_str(), naiveResults, true);
+ }
+ else
+ {
+ Log::Info << "Data read, building tree.\n";
+
+ /////////////// Initialize DTB //////////////////////
+ if (CLI::GetParam<int>("leaf_size") <= 0)
+ {
+ Log::Fatal << "Invalid leaf size (" << CLI::GetParam<int>("leaf_size")
+ << ")! Must be greater than or equal to 1." << std::endl;
+ }
+
+ size_t leafSize = CLI::GetParam<int>("leaf_size");
+
+ DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
+
+ Log::Info << "Tree built, running algorithm.\n";
+
+ ////////////// Run DTB /////////////////////
+ arma::mat results;
+
+ dtb.ComputeMST(results);
+
+ //////////////// Output the Results ////////////////
+ std::string outputFilename = CLI::GetParam<std::string>("output_file");
+
+ data::Save(outputFilename.c_str(), results, true);
+ }
+
+ return 0;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/union_find.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/union_find.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/union_find.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,115 +0,0 @@
-/**
- * @file union_find.hpp
- * @author Bill March (march at gatech.edu)
- *
- * Implements a union-find data structure. This structure tracks the components
- * of a graph. Each point in the graph is initially in its own component.
- * Calling unionfind.Union(x, y) unites the components indexed by x and y.
- * unionfind.Find(x) returns the index of the component containing point x.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_EMST_UNION_FIND_HPP
-#define __MLPACK_METHODS_EMST_UNION_FIND_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace emst {
-
-/**
- * A Union-Find data structure. See Cormen, Rivest, & Stein for details. The
- * structure tracks the components of a graph. Each point in the graph is
- * initially in its own component. Calling Union(x, y) unites the components
- * indexed by x and y. Find(x) returns the index of the component containing
- * point x.
- */
-class UnionFind
-{
- private:
- size_t size;
- arma::Col<size_t> parent;
- arma::ivec rank;
-
- public:
- //! Construct the object with the given size.
- UnionFind(const size_t size) : size(size), parent(size), rank(size)
- {
- for (size_t i = 0; i < size; ++i)
- {
- parent[i] = i;
- rank[i] = 0;
- }
- }
-
- //! Destroy the object (nothing to do).
- ~UnionFind() { }
-
- /**
- * Returns the component containing an element.
- *
- * @param x the component to be found
- * @return The index of the component containing x
- */
- size_t Find(const size_t x)
- {
- if (parent[x] == x)
- {
- return x;
- }
- else
- {
- // This ensures that the tree has a small depth
- parent[x] = Find(parent[x]);
- return parent[x];
- }
- }
-
- /**
- * Union the components containing x and y.
- *
- * @param x one component
- * @param y the other component
- */
- void Union(const size_t x, const size_t y)
- {
- const size_t xRoot = Find(x);
- const size_t yRoot = Find(y);
-
- if (xRoot == yRoot)
- {
- return;
- }
- else if (rank[xRoot] == rank[yRoot])
- {
- parent[yRoot] = parent[xRoot];
- rank[xRoot] = rank[xRoot] + 1;
- }
- else if (rank[xRoot] > rank[yRoot])
- {
- parent[yRoot] = xRoot;
- }
- else
- {
- parent[xRoot] = yRoot;
- }
- }
-}; // class UnionFind
-
-}; // namespace emst
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_EMST_UNION_FIND_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/union_find.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/emst/union_find.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/union_find.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/emst/union_find.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,115 @@
+/**
+ * @file union_find.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Implements a union-find data structure. This structure tracks the components
+ * of a graph. Each point in the graph is initially in its own component.
+ * Calling unionfind.Union(x, y) unites the components indexed by x and y.
+ * unionfind.Find(x) returns the index of the component containing point x.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_EMST_UNION_FIND_HPP
+#define __MLPACK_METHODS_EMST_UNION_FIND_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace emst {
+
+/**
+ * A Union-Find data structure. See Cormen, Rivest, & Stein for details. The
+ * structure tracks the components of a graph. Each point in the graph is
+ * initially in its own component. Calling Union(x, y) unites the components
+ * indexed by x and y. Find(x) returns the index of the component containing
+ * point x.
+ */
+class UnionFind
+{
+ private:
+ size_t size;
+ arma::Col<size_t> parent;
+ arma::ivec rank;
+
+ public:
+ //! Construct the object with the given size.
+ UnionFind(const size_t size) : size(size), parent(size), rank(size)
+ {
+ for (size_t i = 0; i < size; ++i)
+ {
+ parent[i] = i;
+ rank[i] = 0;
+ }
+ }
+
+ //! Destroy the object (nothing to do).
+ ~UnionFind() { }
+
+ /**
+ * Returns the component containing an element.
+ *
+ * @param x the component to be found
+ * @return The index of the component containing x
+ */
+ size_t Find(const size_t x)
+ {
+ if (parent[x] == x)
+ {
+ return x;
+ }
+ else
+ {
+ // This ensures that the tree has a small depth
+ parent[x] = Find(parent[x]);
+ return parent[x];
+ }
+ }
+
+ /**
+ * Union the components containing x and y.
+ *
+ * @param x one component
+ * @param y the other component
+ */
+ void Union(const size_t x, const size_t y)
+ {
+ const size_t xRoot = Find(x);
+ const size_t yRoot = Find(y);
+
+ if (xRoot == yRoot)
+ {
+ return;
+ }
+ else if (rank[xRoot] == rank[yRoot])
+ {
+ parent[yRoot] = parent[xRoot];
+ rank[xRoot] = rank[xRoot] + 1;
+ }
+ else if (rank[xRoot] > rank[yRoot])
+ {
+ parent[yRoot] = xRoot;
+ }
+ else
+ {
+ parent[xRoot] = yRoot;
+ }
+ }
+}; // class UnionFind
+
+}; // namespace emst
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_EMST_UNION_FIND_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,235 +0,0 @@
-/**
- * @file fastmks.hpp
- * @author Ryan Curtin
- *
- * Definition of the FastMKS class, which implements fast exact max-kernel
- * search.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_HPP
-#define __MLPACK_METHODS_FASTMKS_FASTMKS_HPP
-
-#include <mlpack/core.hpp>
-#include "ip_metric.hpp"
-#include "fastmks_stat.hpp"
-#include <mlpack/core/tree/cover_tree.hpp>
-
-namespace mlpack {
-namespace fastmks {
-
-/**
- * An implementation of fast exact max-kernel search. Given a query dataset and
- * a reference dataset (or optionally just a reference dataset which is also
- * used as the query dataset), fast exact max-kernel search finds, for each
- * point in the query dataset, the k points in the reference set with maximum
- * kernel value K(p_q, p_r), where k is a specified parameter and K() is a
- * Mercer kernel.
- *
- * For more information, see the following paper.
- *
- * @code
- * @inproceedings{curtin2013fast,
- * title={Fast Exact Max-Kernel Search},
- * author={Curtin, Ryan R. and Ram, Parikshit and Gray, Alexander G.},
- * booktitle={Proceedings of the 2013 SIAM International Conference on Data
- * Mining (SDM 13)},
- * year={2013}
- * }
- * @endcode
- *
- * This class allows specification of the type of kernel and also of the type of
- * tree. FastMKS can be run on kernels that work on arbitrary objects --
- * however, this only works with cover trees and other trees that are built only
- * on points in the dataset (and not centroids of regions or anything like
- * that).
- *
- * @tparam KernelType Type of kernel to run FastMKS with.
- * @tparam TreeType Type of tree to run FastMKS with; it must have metric
- * IPMetric<KernelType>.
- */
-template<
- typename KernelType,
- typename TreeType = tree::CoverTree<IPMetric<KernelType>,
- tree::FirstPointIsRoot, FastMKSStat>
->
-class FastMKS
-{
- public:
- /**
- * Create the FastMKS object using the reference set as the query set.
- * Optionally, specify whether or not single-tree search or naive
- * (brute-force) search should be used.
- *
- * @param referenceSet Set of data to run FastMKS on.
- * @param single Whether or not to run single-tree search.
- * @param naive Whether or not to run brute-force (naive) search.
- */
- FastMKS(const arma::mat& referenceSet,
- const bool single = false,
- const bool naive = false);
-
- /**
- * Create the FastMKS object using separate reference and query sets.
- * Optionally, specify whether or not single-tree search or naive
- * (brute-force) search should be used.
- *
- * @param referenceSet Reference set of data for FastMKS.
- * @param querySet Set of query points for FastMKS.
- * @param single Whether or not to run single-tree search.
- * @param naive Whether or not to run brute-force (naive) search.
- */
- FastMKS(const arma::mat& referenceSet,
- const arma::mat& querySet,
- const bool single = false,
- const bool naive = false);
-
- /**
- * Create the FastMKS object using the reference set as the query set, and
- * with an initialized kernel. This is useful for when the kernel stores
- * state. Optionally, specify whether or not single-tree search or naive
- * (brute-force) search should be used.
- *
- * @param referenceSet Reference set of data for FastMKS.
- * @param kernel Initialized kernel.
- * @param single Whether or not to run single-tree search.
- * @param naive Whether or not to run brute-force (naive) search.
- */
- FastMKS(const arma::mat& referenceSet,
- KernelType& kernel,
- const bool single = false,
- const bool naive = false);
-
- /**
- * Create the FastMKS object using separate reference and query sets, and with
- * an initialized kernel. This is useful for when the kernel stores state.
- * Optionally, specify whether or not single-tree search or naive
- * (brute-force) search should be used.
- *
- * @param referenceSet Reference set of data for FastMKS.
- * @param querySet Set of query points for FastMKS.
- * @param kernel Initialized kernel.
- * @param single Whether or not to run single-tree search.
- * @param naive Whether or not to run brute-force (naive) search.
- */
- FastMKS(const arma::mat& referenceSet,
- const arma::mat& querySet,
- KernelType& kernel,
- const bool single = false,
- const bool naive = false);
-
- /**
- * Create the FastMKS object with an already-initialized tree built on the
- * reference points. Be sure that the tree is built with the metric type
- * IPMetric<KernelType>. For this constructor, the reference set and the
- * query set are the same points. Optionally, whether or not to run
- * single-tree search or brute-force (naive) search can be specified.
- *
- * @param referenceSet Reference set of data for FastMKS.
- * @param referenceTree Tree built on reference data.
- * @param single Whether or not to run single-tree search.
- * @param naive Whether or not to run brute-force (naive) search.
- */
- FastMKS(const arma::mat& referenceSet,
- TreeType* referenceTree,
- const bool single = false,
- const bool naive = false);
-
- /**
- * Create the FastMKS object with already-initialized trees built on the
- * reference and query points. Be sure that the trees are built with the
- * metric type IPMetric<KernelType>. Optionally, whether or not to run
- * single-tree search or naive (brute-force) search can be specified.
- *
- * @param referenceSet Reference set of data for FastMKS.
- * @param referenceTree Tree built on reference data.
- * @param querySet Set of query points for FastMKS.
- * @param queryTree Tree built on query data.
- * @param single Whether or not to use single-tree search.
- * @param naive Whether or not to use naive (brute-force) search.
- */
- FastMKS(const arma::mat& referenceSet,
- TreeType* referenceTree,
- const arma::mat& querySet,
- TreeType* queryTree,
- const bool single = false,
- const bool naive = false);
-
- //! Destructor for the FastMKS object.
- ~FastMKS();
-
- /**
- * Search for the maximum inner products of the query set (or if no query set
- * was passed, the reference set is used). The resulting maximum inner
- * products are stored in the products matrix and the corresponding point
- * indices are stores in the indices matrix. The results for each point in
- * the query set are stored in the corresponding column of the indices and
- * products matrices; for instance, the index of the point with maximum inner
- * product to point 4 in the query set will be stored in row 0 and column 4 of
- * the indices matrix.
- *
- * @param k The number of maximum kernels to find.
- * @param indices Matrix to store resulting indices of max-kernel search in.
- * @param products Matrix to store resulting max-kernel values in.
- */
- void Search(const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& products);
-
- //! Get the inner-product metric induced by the given kernel.
- const IPMetric<KernelType>& Metric() const { return metric; }
- //! Modify the inner-product metric induced by the given kernel.
- IPMetric<KernelType>& Metric() { return metric; }
-
- private:
- //! The reference dataset.
- const arma::mat& referenceSet;
- //! The query dataset.
- const arma::mat& querySet;
-
- //! The tree built on the reference dataset.
- TreeType* referenceTree;
- //! The tree built on the query dataset. This is NULL if there is no query
- //! set.
- TreeType* queryTree;
-
- //! If true, this object created the trees and is responsible for them.
- bool treeOwner;
-
- //! If true, single-tree search is used.
- bool single;
- //! If true, naive (brute-force) search is used.
- bool naive;
-
- //! The instantiated inner-product metric induced by the given kernel.
- IPMetric<KernelType> metric;
-
- //! Utility function. Copied too many times from too many places.
- void InsertNeighbor(arma::Mat<size_t>& indices,
- arma::mat& products,
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance);
-};
-
-}; // namespace fastmks
-}; // namespace mlpack
-
-// Include implementation.
-#include "fastmks_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,235 @@
+/**
+ * @file fastmks.hpp
+ * @author Ryan Curtin
+ *
+ * Definition of the FastMKS class, which implements fast exact max-kernel
+ * search.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_HPP
+#define __MLPACK_METHODS_FASTMKS_FASTMKS_HPP
+
+#include <mlpack/core.hpp>
+#include "ip_metric.hpp"
+#include "fastmks_stat.hpp"
+#include <mlpack/core/tree/cover_tree.hpp>
+
+namespace mlpack {
+namespace fastmks {
+
+/**
+ * An implementation of fast exact max-kernel search. Given a query dataset and
+ * a reference dataset (or optionally just a reference dataset which is also
+ * used as the query dataset), fast exact max-kernel search finds, for each
+ * point in the query dataset, the k points in the reference set with maximum
+ * kernel value K(p_q, p_r), where k is a specified parameter and K() is a
+ * Mercer kernel.
+ *
+ * For more information, see the following paper.
+ *
+ * @code
+ * @inproceedings{curtin2013fast,
+ * title={Fast Exact Max-Kernel Search},
+ * author={Curtin, Ryan R. and Ram, Parikshit and Gray, Alexander G.},
+ * booktitle={Proceedings of the 2013 SIAM International Conference on Data
+ * Mining (SDM 13)},
+ * year={2013}
+ * }
+ * @endcode
+ *
+ * This class allows specification of the type of kernel and also of the type of
+ * tree. FastMKS can be run on kernels that work on arbitrary objects --
+ * however, this only works with cover trees and other trees that are built only
+ * on points in the dataset (and not centroids of regions or anything like
+ * that).
+ *
+ * @tparam KernelType Type of kernel to run FastMKS with.
+ * @tparam TreeType Type of tree to run FastMKS with; it must have metric
+ * IPMetric<KernelType>.
+ */
+template<
+ typename KernelType,
+ typename TreeType = tree::CoverTree<IPMetric<KernelType>,
+ tree::FirstPointIsRoot, FastMKSStat>
+>
+class FastMKS
+{
+ public:
+ /**
+ * Create the FastMKS object using the reference set as the query set.
+ * Optionally, specify whether or not single-tree search or naive
+ * (brute-force) search should be used.
+ *
+ * @param referenceSet Set of data to run FastMKS on.
+ * @param single Whether or not to run single-tree search.
+ * @param naive Whether or not to run brute-force (naive) search.
+ */
+ FastMKS(const arma::mat& referenceSet,
+ const bool single = false,
+ const bool naive = false);
+
+ /**
+ * Create the FastMKS object using separate reference and query sets.
+ * Optionally, specify whether or not single-tree search or naive
+ * (brute-force) search should be used.
+ *
+ * @param referenceSet Reference set of data for FastMKS.
+ * @param querySet Set of query points for FastMKS.
+ * @param single Whether or not to run single-tree search.
+ * @param naive Whether or not to run brute-force (naive) search.
+ */
+ FastMKS(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ const bool single = false,
+ const bool naive = false);
+
+ /**
+ * Create the FastMKS object using the reference set as the query set, and
+ * with an initialized kernel. This is useful for when the kernel stores
+ * state. Optionally, specify whether or not single-tree search or naive
+ * (brute-force) search should be used.
+ *
+ * @param referenceSet Reference set of data for FastMKS.
+ * @param kernel Initialized kernel.
+ * @param single Whether or not to run single-tree search.
+ * @param naive Whether or not to run brute-force (naive) search.
+ */
+ FastMKS(const arma::mat& referenceSet,
+ KernelType& kernel,
+ const bool single = false,
+ const bool naive = false);
+
+ /**
+ * Create the FastMKS object using separate reference and query sets, and with
+ * an initialized kernel. This is useful for when the kernel stores state.
+ * Optionally, specify whether or not single-tree search or naive
+ * (brute-force) search should be used.
+ *
+ * @param referenceSet Reference set of data for FastMKS.
+ * @param querySet Set of query points for FastMKS.
+ * @param kernel Initialized kernel.
+ * @param single Whether or not to run single-tree search.
+ * @param naive Whether or not to run brute-force (naive) search.
+ */
+ FastMKS(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ KernelType& kernel,
+ const bool single = false,
+ const bool naive = false);
+
+ /**
+ * Create the FastMKS object with an already-initialized tree built on the
+ * reference points. Be sure that the tree is built with the metric type
+ * IPMetric<KernelType>. For this constructor, the reference set and the
+ * query set are the same points. Optionally, whether or not to run
+ * single-tree search or brute-force (naive) search can be specified.
+ *
+ * @param referenceSet Reference set of data for FastMKS.
+ * @param referenceTree Tree built on reference data.
+ * @param single Whether or not to run single-tree search.
+ * @param naive Whether or not to run brute-force (naive) search.
+ */
+ FastMKS(const arma::mat& referenceSet,
+ TreeType* referenceTree,
+ const bool single = false,
+ const bool naive = false);
+
+ /**
+ * Create the FastMKS object with already-initialized trees built on the
+ * reference and query points. Be sure that the trees are built with the
+ * metric type IPMetric<KernelType>. Optionally, whether or not to run
+ * single-tree search or naive (brute-force) search can be specified.
+ *
+ * @param referenceSet Reference set of data for FastMKS.
+ * @param referenceTree Tree built on reference data.
+ * @param querySet Set of query points for FastMKS.
+ * @param queryTree Tree built on query data.
+ * @param single Whether or not to use single-tree search.
+ * @param naive Whether or not to use naive (brute-force) search.
+ */
+ FastMKS(const arma::mat& referenceSet,
+ TreeType* referenceTree,
+ const arma::mat& querySet,
+ TreeType* queryTree,
+ const bool single = false,
+ const bool naive = false);
+
+ //! Destructor for the FastMKS object.
+ ~FastMKS();
+
+ /**
+ * Search for the maximum inner products of the query set (or if no query set
+ * was passed, the reference set is used). The resulting maximum inner
+ * products are stored in the products matrix and the corresponding point
+ * indices are stores in the indices matrix. The results for each point in
+ * the query set are stored in the corresponding column of the indices and
+ * products matrices; for instance, the index of the point with maximum inner
+ * product to point 4 in the query set will be stored in row 0 and column 4 of
+ * the indices matrix.
+ *
+ * @param k The number of maximum kernels to find.
+ * @param indices Matrix to store resulting indices of max-kernel search in.
+ * @param products Matrix to store resulting max-kernel values in.
+ */
+ void Search(const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& products);
+
+ //! Get the inner-product metric induced by the given kernel.
+ const IPMetric<KernelType>& Metric() const { return metric; }
+ //! Modify the inner-product metric induced by the given kernel.
+ IPMetric<KernelType>& Metric() { return metric; }
+
+ private:
+ //! The reference dataset.
+ const arma::mat& referenceSet;
+ //! The query dataset.
+ const arma::mat& querySet;
+
+ //! The tree built on the reference dataset.
+ TreeType* referenceTree;
+ //! The tree built on the query dataset. This is NULL if there is no query
+ //! set.
+ TreeType* queryTree;
+
+ //! If true, this object created the trees and is responsible for them.
+ bool treeOwner;
+
+ //! If true, single-tree search is used.
+ bool single;
+ //! If true, naive (brute-force) search is used.
+ bool naive;
+
+ //! The instantiated inner-product metric induced by the given kernel.
+ IPMetric<KernelType> metric;
+
+ //! Utility function. Copied too many times from too many places.
+ void InsertNeighbor(arma::Mat<size_t>& indices,
+ arma::mat& products,
+ const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance);
+};
+
+}; // namespace fastmks
+}; // namespace mlpack
+
+// Include implementation.
+#include "fastmks_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,521 +0,0 @@
-/**
- * @file fastmks_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of the FastMKS class (fast max-kernel search).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_IMPL_HPP
-#define __MLPACK_METHODS_FASTMKS_FASTMKS_IMPL_HPP
-
-// In case it hasn't yet been included.
-#include "fastmks.hpp"
-
-#include "fastmks_rules.hpp"
-
-#include <mlpack/core/kernels/gaussian_kernel.hpp>
-#include <queue>
-
-namespace mlpack {
-namespace fastmks {
-
-// Single dataset, no instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- referenceTree(NULL),
- queryTree(NULL),
- treeOwner(true),
- single(single),
- naive(naive)
-{
- Timer::Start("tree_building");
-
- if (!naive)
- referenceTree = new TreeType(referenceSet);
-
- Timer::Stop("tree_building");
-}
-
-// Two datasets, no instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- const arma::mat& querySet,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(querySet),
- referenceTree(NULL),
- queryTree(NULL),
- treeOwner(true),
- single(single),
- naive(naive)
-{
- Timer::Start("tree_building");
-
- // If necessary, the trees should be built.
- if (!naive)
- referenceTree = new TreeType(referenceSet);
-
- if (!naive && !single)
- queryTree = new TreeType(querySet);
-
- Timer::Stop("tree_building");
-}
-
-// One dataset, instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- KernelType& kernel,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- referenceTree(NULL),
- queryTree(NULL),
- treeOwner(true),
- single(single),
- naive(naive),
- metric(kernel)
-{
- Timer::Start("tree_building");
-
- // If necessary, the reference tree should be built. There is no query tree.
- if (!naive)
- referenceTree = new TreeType(referenceSet, metric);
-
- Timer::Stop("tree_building");
-}
-
-// Two datasets, instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- const arma::mat& querySet,
- KernelType& kernel,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(querySet),
- referenceTree(NULL),
- queryTree(NULL),
- treeOwner(true),
- single(single),
- naive(naive),
- metric(kernel)
-{
- Timer::Start("tree_building");
-
- // If necessary, the trees should be built.
- if (!naive)
- referenceTree = new TreeType(referenceSet, metric);
-
- if (!naive && !single)
- queryTree = new TreeType(querySet, metric);
-
- Timer::Stop("tree_building");
-}
-
-// One dataset, pre-built tree.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- TreeType* referenceTree,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- referenceTree(referenceTree),
- queryTree(NULL),
- treeOwner(false),
- single(single),
- naive(naive),
- metric(referenceTree->Metric())
-{
- // Nothing to do.
-}
-
-// Two datasets, pre-built trees.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- TreeType* referenceTree,
- const arma::mat& querySet,
- TreeType* queryTree,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(querySet),
- referenceTree(referenceTree),
- queryTree(queryTree),
- treeOwner(false),
- single(single),
- naive(naive),
- metric(referenceTree->Metric())
-{
- // Nothing to do.
-}
-
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::~FastMKS()
-{
- // If we created the trees, we must delete them.
- if (treeOwner)
- {
- if (queryTree)
- delete queryTree;
- if (referenceTree)
- delete referenceTree;
- }
-}
-
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::Search(const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& products)
-{
- // No remapping will be necessary because we are using the cover tree.
- indices.set_size(k, querySet.n_cols);
- products.set_size(k, querySet.n_cols);
- products.fill(-DBL_MAX);
-
- Timer::Start("computing_products");
-
- // Naive implementation.
- if (naive)
- {
- // Simple double loop. Stupid, slow, but a good benchmark.
- for (size_t q = 0; q < querySet.n_cols; ++q)
- {
- for (size_t r = 0; r < referenceSet.n_cols; ++r)
- {
- if ((&querySet == &referenceSet) && (q == r))
- continue;
-
- const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
- referenceSet.unsafe_col(r));
-
- size_t insertPosition;
- for (insertPosition = 0; insertPosition < indices.n_rows;
- ++insertPosition)
- if (eval > products(insertPosition, q))
- break;
-
- if (insertPosition < indices.n_rows)
- InsertNeighbor(indices, products, q, insertPosition, r, eval);
- }
- }
-
- Timer::Stop("computing_products");
-
- return;
- }
-
- // Single-tree implementation.
- if (single)
- {
- // Create rules object (this will store the results). This constructor
- // precalculates each self-kernel value.
- typedef FastMKSRules<KernelType, TreeType> RuleType;
- RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
-
- typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
-
- for (size_t i = 0; i < querySet.n_cols; ++i)
- traverser.Traverse(i, *referenceTree);
-
- // Save the number of pruned nodes.
- const size_t numPrunes = traverser.NumPrunes();
-
- Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
-
- Timer::Stop("computing_products");
- return;
- }
-
- // Dual-tree implementation.
- typedef FastMKSRules<KernelType, TreeType> RuleType;
- RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
-
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
- if (queryTree)
- traverser.Traverse(*queryTree, *referenceTree);
- else
- traverser.Traverse(*referenceTree, *referenceTree);
-
- const size_t numPrunes = traverser.NumPrunes();
-
- Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
-
- Timer::Stop("computing_products");
- return;
-}
-
-/**
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::InsertNeighbor(arma::Mat<size_t>& indices,
- arma::mat& products,
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
-{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (products.n_rows - 1))
- {
- int len = (products.n_rows - 1) - pos;
- memmove(products.colptr(queryIndex) + (pos + 1),
- products.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(indices.colptr(queryIndex) + (pos + 1),
- indices.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- products(pos, queryIndex) = distance;
- indices(pos, queryIndex) = neighbor;
-}
-
-// Specialized implementation for tighter bounds for Gaussian.
-/*
-template<>
-void FastMKS<kernel::GaussianKernel>::Search(const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& products)
-{
- Log::Warn << "Alternate implementation!" << std::endl;
-
- // Terrible copypasta. Bad bad bad.
- // No remapping will be necessary.
- indices.set_size(k, querySet.n_cols);
- products.set_size(k, querySet.n_cols);
- products.fill(-1.0);
-
- Timer::Start("computing_products");
-
- size_t kernelEvaluations = 0;
-
- // Naive implementation.
- if (naive)
- {
- // Simple double loop. Stupid, slow, but a good benchmark.
- for (size_t q = 0; q < querySet.n_cols; ++q)
- {
- for (size_t r = 0; r < referenceSet.n_cols; ++r)
- {
- const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
- referenceSet.unsafe_col(r));
- ++kernelEvaluations;
-
- size_t insertPosition;
- for (insertPosition = 0; insertPosition < indices.n_rows;
- ++insertPosition)
- if (eval > products(insertPosition, q))
- break;
-
- if (insertPosition < indices.n_rows)
- InsertNeighbor(indices, products, q, insertPosition, r, eval);
- }
- }
-
- Timer::Stop("computing_products");
-
- Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
- return;
- }
-
- // Single-tree implementation.
- if (single)
- {
- // Calculate number of pruned nodes.
- size_t numPrunes = 0;
-
- // Precalculate query products ( || q || for all q).
- arma::vec queryProducts(querySet.n_cols);
- for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
- queryProducts[queryIndex] = sqrt(metric.Kernel().Evaluate(
- querySet.unsafe_col(queryIndex), querySet.unsafe_col(queryIndex)));
- kernelEvaluations += querySet.n_cols;
-
- // Screw the CoverTreeTraverser, we'll implement it by hand.
- for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
- {
- // Use an array of priority queues?
- std::priority_queue<
- SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >,
- std::vector<SearchFrame<tree::CoverTree<IPMetric<
- kernel::GaussianKernel> > > >,
- SearchFrameCompare<tree::CoverTree<IPMetric<
- kernel::GaussianKernel> > > >
- frameQueue;
-
- // Add initial frame.
- SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >
- nextFrame;
- nextFrame.node = referenceTree;
- nextFrame.eval = metric.Kernel().Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceTree->Point()));
- Log::Assert(nextFrame.eval <= 1);
- ++kernelEvaluations;
-
- // The initial evaluation will be the best so far.
- indices(0, queryIndex) = referenceTree->Point();
- products(0, queryIndex) = nextFrame.eval;
-
- frameQueue.push(nextFrame);
-
- tree::CoverTree<IPMetric<kernel::GaussianKernel> >* referenceNode;
- double eval;
- double maxProduct;
-
- while (!frameQueue.empty())
- {
- // Get the information for this node.
- const SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >&
- frame = frameQueue.top();
-
- referenceNode = frame.node;
- eval = frame.eval;
-
- // Loop through the children, seeing if we can prune them; if not, add
- // them to the queue. The self-child is different -- it has the same
- // parent (and therefore the same kernel evaluation).
- if (referenceNode->NumChildren() > 0)
- {
- SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >
- childFrame;
-
- // We must handle the self-child differently, to avoid adding it to
- // the results twice.
- childFrame.node = &(referenceNode->Child(0));
- childFrame.eval = eval;
-
- // Alternate pruning rule.
- const double mdd = childFrame.node->FurthestDescendantDistance();
- if (eval >= (1 - std::pow(mdd, 2.0) / 2.0))
- maxProduct = 1;
- else
- maxProduct = eval * (1 - std::pow(mdd, 2.0) / 2.0) + sqrt(1 -
- std::pow(eval, 2.0)) * mdd * sqrt(1 - std::pow(mdd, 2.0) / 4.0);
-
- // Add self-child if we can't prune it.
- if (maxProduct > products(products.n_rows - 1, queryIndex))
- {
- // But only if it has children of its own.
- if (childFrame.node->NumChildren() > 0)
- frameQueue.push(childFrame);
- }
- else
- ++numPrunes;
-
- for (size_t i = 1; i < referenceNode->NumChildren(); ++i)
- {
- // Before we evaluate the child, let's see if it can possibly have
- // a better evaluation.
- const double mpdd = std::min(
- referenceNode->Child(i).ParentDistance() +
- referenceNode->Child(i).FurthestDescendantDistance(), 2.0);
- double maxChildEval = 1;
- if (eval < (1 - std::pow(mpdd, 2.0) / 2.0))
- maxChildEval = eval * (1 - std::pow(mpdd, 2.0) / 2.0) + sqrt(1 -
- std::pow(eval, 2.0)) * mpdd * sqrt(1 - std::pow(mpdd, 2.0)
- / 4.0);
-
- if (maxChildEval > products(products.n_rows - 1, queryIndex))
- {
- // Evaluate child.
- childFrame.node = &(referenceNode->Child(i));
- childFrame.eval = metric.Kernel().Evaluate(
- querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceNode->Child(i).Point()));
- ++kernelEvaluations;
-
- // Can we prune it? If we can, we can avoid putting it in the
- // queue (saves time).
- const double cmdd = childFrame.node->FurthestDescendantDistance();
- if (childFrame.eval > (1 - std::pow(cmdd, 2.0) / 2.0))
- maxProduct = 1;
- else
- maxProduct = childFrame.eval * (1 - std::pow(cmdd, 2.0) / 2.0)
- + sqrt(1 - std::pow(eval, 2.0)) * cmdd * sqrt(1 -
- std::pow(cmdd, 2.0) / 4.0);
-
- if (maxProduct > products(products.n_rows - 1, queryIndex))
- {
- // Good enough to recurse into. While we're at it, check the
- // actual evaluation and see if it's an improvement.
- if (childFrame.eval > products(products.n_rows - 1, queryIndex))
- {
- // This is a better result. Find out where to insert.
- size_t insertPosition = 0;
- for ( ; insertPosition < products.n_rows - 1;
- ++insertPosition)
- if (childFrame.eval > products(insertPosition, queryIndex))
- break;
-
- // Insert into the correct position; we are guaranteed that
- // insertPosition is valid.
- InsertNeighbor(indices, products, queryIndex, insertPosition,
- childFrame.node->Point(), childFrame.eval);
- }
-
- // Now add this to the queue (if it has any children which may
- // need to be recursed into).
- if (childFrame.node->NumChildren() > 0)
- frameQueue.push(childFrame);
- }
- else
- ++numPrunes;
- }
- else
- ++numPrunes;
- }
- }
-
- frameQueue.pop();
- }
- }
-
- Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
- Log::Info << "Kernel evaluations: " << kernelEvaluations << "."
- << std::endl;
- Log::Info << "Distance evaluations: " << distanceEvaluations << "."
- << std::endl;
-
- Timer::Stop("computing_products");
- return;
- }
-
- // Double-tree implementation.
- Log::Fatal << "Dual-tree search not implemented yet... oops..." << std::endl;
-
-}
-*/
-
-}; // namespace fastmks
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,521 @@
+/**
+ * @file fastmks_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the FastMKS class (fast max-kernel search).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_IMPL_HPP
+#define __MLPACK_METHODS_FASTMKS_FASTMKS_IMPL_HPP
+
+// In case it hasn't yet been included.
+#include "fastmks.hpp"
+
+#include "fastmks_rules.hpp"
+
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <queue>
+
+namespace mlpack {
+namespace fastmks {
+
+// Single dataset, no instantiated kernel.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+ const bool single,
+ const bool naive) :
+ referenceSet(referenceSet),
+ querySet(referenceSet),
+ referenceTree(NULL),
+ queryTree(NULL),
+ treeOwner(true),
+ single(single),
+ naive(naive)
+{
+ Timer::Start("tree_building");
+
+ if (!naive)
+ referenceTree = new TreeType(referenceSet);
+
+ Timer::Stop("tree_building");
+}
+
+// Two datasets, no instantiated kernel.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ const bool single,
+ const bool naive) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ referenceTree(NULL),
+ queryTree(NULL),
+ treeOwner(true),
+ single(single),
+ naive(naive)
+{
+ Timer::Start("tree_building");
+
+ // If necessary, the trees should be built.
+ if (!naive)
+ referenceTree = new TreeType(referenceSet);
+
+ if (!naive && !single)
+ queryTree = new TreeType(querySet);
+
+ Timer::Stop("tree_building");
+}
+
+// One dataset, instantiated kernel.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+ KernelType& kernel,
+ const bool single,
+ const bool naive) :
+ referenceSet(referenceSet),
+ querySet(referenceSet),
+ referenceTree(NULL),
+ queryTree(NULL),
+ treeOwner(true),
+ single(single),
+ naive(naive),
+ metric(kernel)
+{
+ Timer::Start("tree_building");
+
+ // If necessary, the reference tree should be built. There is no query tree.
+ if (!naive)
+ referenceTree = new TreeType(referenceSet, metric);
+
+ Timer::Stop("tree_building");
+}
+
+// Two datasets, instantiated kernel.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ KernelType& kernel,
+ const bool single,
+ const bool naive) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ referenceTree(NULL),
+ queryTree(NULL),
+ treeOwner(true),
+ single(single),
+ naive(naive),
+ metric(kernel)
+{
+ Timer::Start("tree_building");
+
+ // If necessary, the trees should be built.
+ if (!naive)
+ referenceTree = new TreeType(referenceSet, metric);
+
+ if (!naive && !single)
+ queryTree = new TreeType(querySet, metric);
+
+ Timer::Stop("tree_building");
+}
+
+// One dataset, pre-built tree.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+ TreeType* referenceTree,
+ const bool single,
+ const bool naive) :
+ referenceSet(referenceSet),
+ querySet(referenceSet),
+ referenceTree(referenceTree),
+ queryTree(NULL),
+ treeOwner(false),
+ single(single),
+ naive(naive),
+ metric(referenceTree->Metric())
+{
+ // Nothing to do.
+}
+
+// Two datasets, pre-built trees.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+ TreeType* referenceTree,
+ const arma::mat& querySet,
+ TreeType* queryTree,
+ const bool single,
+ const bool naive) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ referenceTree(referenceTree),
+ queryTree(queryTree),
+ treeOwner(false),
+ single(single),
+ naive(naive),
+ metric(referenceTree->Metric())
+{
+ // Nothing to do.
+}
+
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::~FastMKS()
+{
+ // If we created the trees, we must delete them.
+ if (treeOwner)
+ {
+ if (queryTree)
+ delete queryTree;
+ if (referenceTree)
+ delete referenceTree;
+ }
+}
+
+template<typename KernelType, typename TreeType>
+void FastMKS<KernelType, TreeType>::Search(const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& products)
+{
+ // No remapping will be necessary because we are using the cover tree.
+ indices.set_size(k, querySet.n_cols);
+ products.set_size(k, querySet.n_cols);
+ products.fill(-DBL_MAX);
+
+ Timer::Start("computing_products");
+
+ // Naive implementation.
+ if (naive)
+ {
+ // Simple double loop. Stupid, slow, but a good benchmark.
+ for (size_t q = 0; q < querySet.n_cols; ++q)
+ {
+ for (size_t r = 0; r < referenceSet.n_cols; ++r)
+ {
+ if ((&querySet == &referenceSet) && (q == r))
+ continue;
+
+ const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
+ referenceSet.unsafe_col(r));
+
+ size_t insertPosition;
+ for (insertPosition = 0; insertPosition < indices.n_rows;
+ ++insertPosition)
+ if (eval > products(insertPosition, q))
+ break;
+
+ if (insertPosition < indices.n_rows)
+ InsertNeighbor(indices, products, q, insertPosition, r, eval);
+ }
+ }
+
+ Timer::Stop("computing_products");
+
+ return;
+ }
+
+ // Single-tree implementation.
+ if (single)
+ {
+ // Create rules object (this will store the results). This constructor
+ // precalculates each self-kernel value.
+ typedef FastMKSRules<KernelType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
+
+ typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
+
+ // Save the number of pruned nodes.
+ const size_t numPrunes = traverser.NumPrunes();
+
+ Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+
+ Timer::Stop("computing_products");
+ return;
+ }
+
+ // Dual-tree implementation.
+ typedef FastMKSRules<KernelType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
+
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ if (queryTree)
+ traverser.Traverse(*queryTree, *referenceTree);
+ else
+ traverser.Traverse(*referenceTree, *referenceTree);
+
+ const size_t numPrunes = traverser.NumPrunes();
+
+ Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+
+ Timer::Stop("computing_products");
+ return;
+}
+
+/**
+ * Helper function to insert a point into the neighbors and distances matrices.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+template<typename KernelType, typename TreeType>
+void FastMKS<KernelType, TreeType>::InsertNeighbor(arma::Mat<size_t>& indices,
+ arma::mat& products,
+ const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance)
+{
+ // We only memmove() if there is actually a need to shift something.
+ if (pos < (products.n_rows - 1))
+ {
+ int len = (products.n_rows - 1) - pos;
+ memmove(products.colptr(queryIndex) + (pos + 1),
+ products.colptr(queryIndex) + pos,
+ sizeof(double) * len);
+ memmove(indices.colptr(queryIndex) + (pos + 1),
+ indices.colptr(queryIndex) + pos,
+ sizeof(size_t) * len);
+ }
+
+ // Now put the new information in the right index.
+ products(pos, queryIndex) = distance;
+ indices(pos, queryIndex) = neighbor;
+}
+
+// Specialized implementation for tighter bounds for Gaussian.
+/*
+template<>
+void FastMKS<kernel::GaussianKernel>::Search(const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& products)
+{
+ Log::Warn << "Alternate implementation!" << std::endl;
+
+ // Terrible copypasta. Bad bad bad.
+ // No remapping will be necessary.
+ indices.set_size(k, querySet.n_cols);
+ products.set_size(k, querySet.n_cols);
+ products.fill(-1.0);
+
+ Timer::Start("computing_products");
+
+ size_t kernelEvaluations = 0;
+
+ // Naive implementation.
+ if (naive)
+ {
+ // Simple double loop. Stupid, slow, but a good benchmark.
+ for (size_t q = 0; q < querySet.n_cols; ++q)
+ {
+ for (size_t r = 0; r < referenceSet.n_cols; ++r)
+ {
+ const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
+ referenceSet.unsafe_col(r));
+ ++kernelEvaluations;
+
+ size_t insertPosition;
+ for (insertPosition = 0; insertPosition < indices.n_rows;
+ ++insertPosition)
+ if (eval > products(insertPosition, q))
+ break;
+
+ if (insertPosition < indices.n_rows)
+ InsertNeighbor(indices, products, q, insertPosition, r, eval);
+ }
+ }
+
+ Timer::Stop("computing_products");
+
+ Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
+ return;
+ }
+
+ // Single-tree implementation.
+ if (single)
+ {
+ // Calculate number of pruned nodes.
+ size_t numPrunes = 0;
+
+ // Precalculate query products ( || q || for all q).
+ arma::vec queryProducts(querySet.n_cols);
+ for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
+ queryProducts[queryIndex] = sqrt(metric.Kernel().Evaluate(
+ querySet.unsafe_col(queryIndex), querySet.unsafe_col(queryIndex)));
+ kernelEvaluations += querySet.n_cols;
+
+ // Screw the CoverTreeTraverser, we'll implement it by hand.
+ for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
+ {
+ // Use an array of priority queues?
+ std::priority_queue<
+ SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >,
+ std::vector<SearchFrame<tree::CoverTree<IPMetric<
+ kernel::GaussianKernel> > > >,
+ SearchFrameCompare<tree::CoverTree<IPMetric<
+ kernel::GaussianKernel> > > >
+ frameQueue;
+
+ // Add initial frame.
+ SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >
+ nextFrame;
+ nextFrame.node = referenceTree;
+ nextFrame.eval = metric.Kernel().Evaluate(querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceTree->Point()));
+ Log::Assert(nextFrame.eval <= 1);
+ ++kernelEvaluations;
+
+ // The initial evaluation will be the best so far.
+ indices(0, queryIndex) = referenceTree->Point();
+ products(0, queryIndex) = nextFrame.eval;
+
+ frameQueue.push(nextFrame);
+
+ tree::CoverTree<IPMetric<kernel::GaussianKernel> >* referenceNode;
+ double eval;
+ double maxProduct;
+
+ while (!frameQueue.empty())
+ {
+ // Get the information for this node.
+ const SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >&
+ frame = frameQueue.top();
+
+ referenceNode = frame.node;
+ eval = frame.eval;
+
+ // Loop through the children, seeing if we can prune them; if not, add
+ // them to the queue. The self-child is different -- it has the same
+ // parent (and therefore the same kernel evaluation).
+ if (referenceNode->NumChildren() > 0)
+ {
+ SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >
+ childFrame;
+
+ // We must handle the self-child differently, to avoid adding it to
+ // the results twice.
+ childFrame.node = &(referenceNode->Child(0));
+ childFrame.eval = eval;
+
+ // Alternate pruning rule.
+ const double mdd = childFrame.node->FurthestDescendantDistance();
+ if (eval >= (1 - std::pow(mdd, 2.0) / 2.0))
+ maxProduct = 1;
+ else
+ maxProduct = eval * (1 - std::pow(mdd, 2.0) / 2.0) + sqrt(1 -
+ std::pow(eval, 2.0)) * mdd * sqrt(1 - std::pow(mdd, 2.0) / 4.0);
+
+ // Add self-child if we can't prune it.
+ if (maxProduct > products(products.n_rows - 1, queryIndex))
+ {
+ // But only if it has children of its own.
+ if (childFrame.node->NumChildren() > 0)
+ frameQueue.push(childFrame);
+ }
+ else
+ ++numPrunes;
+
+ for (size_t i = 1; i < referenceNode->NumChildren(); ++i)
+ {
+ // Before we evaluate the child, let's see if it can possibly have
+ // a better evaluation.
+ const double mpdd = std::min(
+ referenceNode->Child(i).ParentDistance() +
+ referenceNode->Child(i).FurthestDescendantDistance(), 2.0);
+ double maxChildEval = 1;
+ if (eval < (1 - std::pow(mpdd, 2.0) / 2.0))
+ maxChildEval = eval * (1 - std::pow(mpdd, 2.0) / 2.0) + sqrt(1 -
+ std::pow(eval, 2.0)) * mpdd * sqrt(1 - std::pow(mpdd, 2.0)
+ / 4.0);
+
+ if (maxChildEval > products(products.n_rows - 1, queryIndex))
+ {
+ // Evaluate child.
+ childFrame.node = &(referenceNode->Child(i));
+ childFrame.eval = metric.Kernel().Evaluate(
+ querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceNode->Child(i).Point()));
+ ++kernelEvaluations;
+
+ // Can we prune it? If we can, we can avoid putting it in the
+ // queue (saves time).
+ const double cmdd = childFrame.node->FurthestDescendantDistance();
+ if (childFrame.eval > (1 - std::pow(cmdd, 2.0) / 2.0))
+ maxProduct = 1;
+ else
+ maxProduct = childFrame.eval * (1 - std::pow(cmdd, 2.0) / 2.0)
+ + sqrt(1 - std::pow(eval, 2.0)) * cmdd * sqrt(1 -
+ std::pow(cmdd, 2.0) / 4.0);
+
+ if (maxProduct > products(products.n_rows - 1, queryIndex))
+ {
+ // Good enough to recurse into. While we're at it, check the
+ // actual evaluation and see if it's an improvement.
+ if (childFrame.eval > products(products.n_rows - 1, queryIndex))
+ {
+ // This is a better result. Find out where to insert.
+ size_t insertPosition = 0;
+ for ( ; insertPosition < products.n_rows - 1;
+ ++insertPosition)
+ if (childFrame.eval > products(insertPosition, queryIndex))
+ break;
+
+ // Insert into the correct position; we are guaranteed that
+ // insertPosition is valid.
+ InsertNeighbor(indices, products, queryIndex, insertPosition,
+ childFrame.node->Point(), childFrame.eval);
+ }
+
+ // Now add this to the queue (if it has any children which may
+ // need to be recursed into).
+ if (childFrame.node->NumChildren() > 0)
+ frameQueue.push(childFrame);
+ }
+ else
+ ++numPrunes;
+ }
+ else
+ ++numPrunes;
+ }
+ }
+
+ frameQueue.pop();
+ }
+ }
+
+ Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+ Log::Info << "Kernel evaluations: " << kernelEvaluations << "."
+ << std::endl;
+ Log::Info << "Distance evaluations: " << distanceEvaluations << "."
+ << std::endl;
+
+ Timer::Stop("computing_products");
+ return;
+ }
+
+ // Double-tree implementation.
+ Log::Fatal << "Dual-tree search not implemented yet... oops..." << std::endl;
+
+}
+*/
+
+}; // namespace fastmks
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,325 +0,0 @@
-/**
- * @file fastmks_main.cpp
- * @author Ryan Curtin
- *
- * Main executable for maximum inner product search.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
-#include <mlpack/core/kernels/polynomial_kernel.hpp>
-#include <mlpack/core/kernels/cosine_distance.hpp>
-#include <mlpack/core/kernels/gaussian_kernel.hpp>
-#include <mlpack/core/kernels/hyperbolic_tangent_kernel.hpp>
-#include <mlpack/core/kernels/triangular_kernel.hpp>
-#include <mlpack/core/kernels/epanechnikov_kernel.hpp>
-
-#include "fastmks.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::fastmks;
-using namespace mlpack::kernel;
-using namespace mlpack::tree;
-
-PROGRAM_INFO("FastMKS (Fast Max-Kernel Search)",
- "This program will find the k maximum kernel of a set of points, "
- "using a query set and a reference set (which can optionally be the same "
- "set). More specifically, for each point in the query set, the k points in"
- " the reference set with maximum kernel evaluations are found. The kernel "
- "function used is specified by --kernel."
- "\n\n"
- "For example, the following command will calculate, for each point in "
- "'query.csv', the five points in 'reference.csv' with maximum kernel "
- "evaluation using the linear kernel. The kernel evaluations are stored in "
- "'kernels.csv' and the indices are stored in 'indices.csv'."
- "\n\n"
- "$ fastmks --k 5 --reference_file reference.csv --query_file query.csv\n"
- " --indices_file indices.csv --products_file kernels.csv --kernel linear"
- "\n\n"
- "The output files are organized such that row i and column j in the indices"
- " output file corresponds to the index of the point in the reference set "
- "that has i'th largest kernel evaluation with the point in the query set "
- "with index j. Row i and column j in the products output file corresponds "
- "to the kernel evaluation between those two points."
- "\n\n"
- "This executable performs FastMKS using a cover tree. The base used to "
- "build the cover tree can be specified with the --base option.");
-
-// Define our input parameters.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
-PARAM_STRING("query_file", "File containing the query dataset.", "q", "");
-
-PARAM_INT_REQ("k", "Number of maximum inner products to find.", "k");
-
-PARAM_STRING("products_file", "File to save inner products into.", "p", "");
-PARAM_STRING("indices_file", "File to save indices of inner products into.",
- "i", "");
-
-PARAM_STRING("kernel", "Kernel type to use: 'linear', 'polynomial', 'cosine', "
- "'gaussian', 'epanechnikov', 'triangular', 'hyptan'.", "K", "linear");
-
-PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
-PARAM_FLAG("single", "If true, single-tree search is used (as opposed to "
- "dual-tree search.", "s");
-
-// Cover tree parameter.
-PARAM_DOUBLE("base", "Base to use during cover tree construction.", "b", 2.0);
-
-// Kernel parameters.
-PARAM_DOUBLE("degree", "Degree of polynomial kernel.", "d", 2.0);
-PARAM_DOUBLE("offset", "Offset of kernel (for polynomial and hyptan kernels).",
- "o", 0.0);
-PARAM_DOUBLE("bandwidth", "Bandwidth (for Gaussian, Epanechnikov, and "
- "triangular kernels).", "w", 1.0);
-PARAM_DOUBLE("scale", "Scale of kernel (for hyptan kernel).", "s", 1.0);
-
-//! Run FastMKS on a single dataset for the given kernel type.
-template<typename KernelType>
-void RunFastMKS(const arma::mat& referenceData,
- const bool single,
- const bool naive,
- const double base,
- const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& products,
- KernelType& kernel)
-{
- // Create the tree with the specified base.
- typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
- TreeType;
- IPMetric<KernelType> metric(kernel);
- TreeType tree(referenceData, metric, base);
-
- // Create FastMKS object.
- FastMKS<KernelType> fastmks(referenceData, &tree, (single && !naive), naive);
-
- // Now search with it.
- fastmks.Search(k, indices, products);
-}
-
-//! Run FastMKS for a given query and reference set using the given kernel type.
-template<typename KernelType>
-void RunFastMKS(const arma::mat& referenceData,
- const arma::mat& queryData,
- const bool single,
- const bool naive,
- const double base,
- const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& products,
- KernelType& kernel)
-{
- // Create the tree with the specified base.
- typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
- TreeType;
- IPMetric<KernelType> metric(kernel);
- TreeType referenceTree(referenceData, metric, base);
- TreeType queryTree(queryData, metric, base);
-
- // Create FastMKS object.
- FastMKS<KernelType> fastmks(referenceData, &referenceTree, queryData,
- &queryTree, (single && !naive), naive);
-
- // Now search with it.
- fastmks.Search(k, indices, products);
-}
-
-int main(int argc, char** argv)
-{
- CLI::ParseCommandLine(argc, argv);
-
- // Get reference dataset filename.
- const string referenceFile = CLI::GetParam<string>("reference_file");
-
- // The number of max kernel values to find.
- const size_t k = CLI::GetParam<int>("k");
-
- // Runtime parameters.
- const bool naive = CLI::HasParam("naive");
- const bool single = CLI::HasParam("single");
-
- // For cover tree construction.
- const double base = CLI::GetParam<double>("base");
-
- // Kernel parameters.
- const string kernelType = CLI::GetParam<string>("kernel");
- const double degree = CLI::GetParam<double>("degree");
- const double offset = CLI::GetParam<double>("offset");
- const double bandwidth = CLI::GetParam<double>("bandwidth");
- const double scale = CLI::GetParam<double>("scale");
-
- // The datasets. The query matrix may never be used.
- arma::mat referenceData;
- arma::mat queryData;
-
- data::Load(referenceFile, referenceData, true);
-
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
- // Sanity check on k value.
- if (k > referenceData.n_cols)
- {
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than or equal to the number of reference points (";
- Log::Fatal << referenceData.n_cols << ")." << endl;
- }
-
- // Check on kernel type.
- if ((kernelType != "linear") && (kernelType != "polynomial") &&
- (kernelType != "cosine") && (kernelType != "gaussian") &&
- (kernelType != "graph") && (kernelType != "approxGraph") &&
- (kernelType != "triangular") && (kernelType != "hyptan") &&
- (kernelType != "inv-mq") && (kernelType != "epanechnikov"))
- {
- Log::Fatal << "Invalid kernel type: '" << kernelType << "'; must be ";
- Log::Fatal << "'linear' or 'polynomial'." << endl;
- }
-
- // Load the query matrix, if we can.
- if (CLI::HasParam("query_file"))
- {
- const string queryFile = CLI::GetParam<string>("query_file");
- data::Load(queryFile, queryData, true);
-
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
- }
- else
- {
- Log::Info << "Using reference dataset as query dataset (--query_file not "
- << "specified)." << endl;
- }
-
- // Naive mode overrides single mode.
- if (naive && single)
- {
- Log::Warn << "--single ignored because --naive is present." << endl;
- }
-
- // Matrices for output storage.
- arma::Mat<size_t> indices;
- arma::mat products;
-
- // Construct FastMKS object.
- if (queryData.n_elem == 0)
- {
- if (kernelType == "linear")
- {
- LinearKernel lk;
- RunFastMKS<LinearKernel>(referenceData, single, naive, base, k, indices,
- products, lk);
- }
- else if (kernelType == "polynomial")
- {
-
- PolynomialKernel pk(degree, offset);
- RunFastMKS<PolynomialKernel>(referenceData, single, naive, base, k,
- indices, products, pk);
- }
- else if (kernelType == "cosine")
- {
- CosineDistance cd;
- RunFastMKS<CosineDistance>(referenceData, single, naive, base, k, indices,
- products, cd);
- }
- else if (kernelType == "gaussian")
- {
- GaussianKernel gk(bandwidth);
- RunFastMKS<GaussianKernel>(referenceData, single, naive, base, k, indices,
- products, gk);
- }
- else if (kernelType == "epanechnikov")
- {
- EpanechnikovKernel ek(bandwidth);
- RunFastMKS<EpanechnikovKernel>(referenceData, single, naive, base, k,
- indices, products, ek);
- }
- else if (kernelType == "triangular")
- {
- TriangularKernel tk(bandwidth);
- RunFastMKS<TriangularKernel>(referenceData, single, naive, base, k,
- indices, products, tk);
- }
- else if (kernelType == "hyptan")
- {
- HyperbolicTangentKernel htk(scale, offset);
- RunFastMKS<HyperbolicTangentKernel>(referenceData, single, naive, base, k,
- indices, products, htk);
- }
- }
- else
- {
- if (kernelType == "linear")
- {
- LinearKernel lk;
- RunFastMKS<LinearKernel>(referenceData, queryData, single, naive, base, k,
- indices, products, lk);
- }
- else if (kernelType == "polynomial")
- {
- PolynomialKernel pk(degree, offset);
- RunFastMKS<PolynomialKernel>(referenceData, queryData, single, naive,
- base, k, indices, products, pk);
- }
- else if (kernelType == "cosine")
- {
- CosineDistance cd;
- RunFastMKS<CosineDistance>(referenceData, queryData, single, naive, base,
- k, indices, products, cd);
- }
- else if (kernelType == "gaussian")
- {
- GaussianKernel gk(bandwidth);
- RunFastMKS<GaussianKernel>(referenceData, queryData, single, naive, base,
- k, indices, products, gk);
- }
- else if (kernelType == "epanechnikov")
- {
- EpanechnikovKernel ek(bandwidth);
- RunFastMKS<EpanechnikovKernel>(referenceData, queryData, single, naive,
- base, k, indices, products, ek);
- }
- else if (kernelType == "triangular")
- {
- TriangularKernel tk(bandwidth);
- RunFastMKS<TriangularKernel>(referenceData, queryData, single, naive,
- base, k, indices, products, tk);
- }
- else if (kernelType == "hyptan")
- {
- HyperbolicTangentKernel htk(scale, offset);
- RunFastMKS<HyperbolicTangentKernel>(referenceData, queryData, single,
- naive, base, k, indices, products, htk);
- }
- }
-
- // Save output, if we were asked to.
- if (CLI::HasParam("products_file"))
- {
- const string productsFile = CLI::GetParam<string>("products_file");
- data::Save(productsFile, products, false);
- }
-
- if (CLI::HasParam("indices_file"))
- {
- const string indicesFile = CLI::GetParam<string>("indices_file");
- data::Save(indicesFile, indices, false);
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,325 @@
+/**
+ * @file fastmks_main.cpp
+ * @author Ryan Curtin
+ *
+ * Main executable for maximum inner product search.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/kernels/linear_kernel.hpp>
+#include <mlpack/core/kernels/polynomial_kernel.hpp>
+#include <mlpack/core/kernels/cosine_distance.hpp>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <mlpack/core/kernels/hyperbolic_tangent_kernel.hpp>
+#include <mlpack/core/kernels/triangular_kernel.hpp>
+#include <mlpack/core/kernels/epanechnikov_kernel.hpp>
+
+#include "fastmks.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::fastmks;
+using namespace mlpack::kernel;
+using namespace mlpack::tree;
+
+PROGRAM_INFO("FastMKS (Fast Max-Kernel Search)",
+ "This program will find the k maximum kernel of a set of points, "
+ "using a query set and a reference set (which can optionally be the same "
+ "set). More specifically, for each point in the query set, the k points in"
+ " the reference set with maximum kernel evaluations are found. The kernel "
+ "function used is specified by --kernel."
+ "\n\n"
+ "For example, the following command will calculate, for each point in "
+ "'query.csv', the five points in 'reference.csv' with maximum kernel "
+ "evaluation using the linear kernel. The kernel evaluations are stored in "
+ "'kernels.csv' and the indices are stored in 'indices.csv'."
+ "\n\n"
+ "$ fastmks --k 5 --reference_file reference.csv --query_file query.csv\n"
+ " --indices_file indices.csv --products_file kernels.csv --kernel linear"
+ "\n\n"
+ "The output files are organized such that row i and column j in the indices"
+ " output file corresponds to the index of the point in the reference set "
+ "that has i'th largest kernel evaluation with the point in the query set "
+ "with index j. Row i and column j in the products output file corresponds "
+ "to the kernel evaluation between those two points."
+ "\n\n"
+ "This executable performs FastMKS using a cover tree. The base used to "
+ "build the cover tree can be specified with the --base option.");
+
+// Define our input parameters.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ "r");
+PARAM_STRING("query_file", "File containing the query dataset.", "q", "");
+
+PARAM_INT_REQ("k", "Number of maximum inner products to find.", "k");
+
+PARAM_STRING("products_file", "File to save inner products into.", "p", "");
+PARAM_STRING("indices_file", "File to save indices of inner products into.",
+ "i", "");
+
+PARAM_STRING("kernel", "Kernel type to use: 'linear', 'polynomial', 'cosine', "
+ "'gaussian', 'epanechnikov', 'triangular', 'hyptan'.", "K", "linear");
+
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
+PARAM_FLAG("single", "If true, single-tree search is used (as opposed to "
+ "dual-tree search.", "s");
+
+// Cover tree parameter.
+PARAM_DOUBLE("base", "Base to use during cover tree construction.", "b", 2.0);
+
+// Kernel parameters.
+PARAM_DOUBLE("degree", "Degree of polynomial kernel.", "d", 2.0);
+PARAM_DOUBLE("offset", "Offset of kernel (for polynomial and hyptan kernels).",
+ "o", 0.0);
+PARAM_DOUBLE("bandwidth", "Bandwidth (for Gaussian, Epanechnikov, and "
+ "triangular kernels).", "w", 1.0);
+PARAM_DOUBLE("scale", "Scale of kernel (for hyptan kernel).", "s", 1.0);
+
+//! Run FastMKS on a single dataset for the given kernel type.
+template<typename KernelType>
+void RunFastMKS(const arma::mat& referenceData,
+ const bool single,
+ const bool naive,
+ const double base,
+ const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& products,
+ KernelType& kernel)
+{
+ // Create the tree with the specified base.
+ typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
+ TreeType;
+ IPMetric<KernelType> metric(kernel);
+ TreeType tree(referenceData, metric, base);
+
+ // Create FastMKS object.
+ FastMKS<KernelType> fastmks(referenceData, &tree, (single && !naive), naive);
+
+ // Now search with it.
+ fastmks.Search(k, indices, products);
+}
+
+//! Run FastMKS for a given query and reference set using the given kernel type.
+template<typename KernelType>
+void RunFastMKS(const arma::mat& referenceData,
+ const arma::mat& queryData,
+ const bool single,
+ const bool naive,
+ const double base,
+ const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& products,
+ KernelType& kernel)
+{
+ // Create the tree with the specified base.
+ typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
+ TreeType;
+ IPMetric<KernelType> metric(kernel);
+ TreeType referenceTree(referenceData, metric, base);
+ TreeType queryTree(queryData, metric, base);
+
+ // Create FastMKS object.
+ FastMKS<KernelType> fastmks(referenceData, &referenceTree, queryData,
+ &queryTree, (single && !naive), naive);
+
+ // Now search with it.
+ fastmks.Search(k, indices, products);
+}
+
+int main(int argc, char** argv)
+{
+ CLI::ParseCommandLine(argc, argv);
+
+ // Get reference dataset filename.
+ const string referenceFile = CLI::GetParam<string>("reference_file");
+
+ // The number of max kernel values to find.
+ const size_t k = CLI::GetParam<int>("k");
+
+ // Runtime parameters.
+ const bool naive = CLI::HasParam("naive");
+ const bool single = CLI::HasParam("single");
+
+ // For cover tree construction.
+ const double base = CLI::GetParam<double>("base");
+
+ // Kernel parameters.
+ const string kernelType = CLI::GetParam<string>("kernel");
+ const double degree = CLI::GetParam<double>("degree");
+ const double offset = CLI::GetParam<double>("offset");
+ const double bandwidth = CLI::GetParam<double>("bandwidth");
+ const double scale = CLI::GetParam<double>("scale");
+
+ // The datasets. The query matrix may never be used.
+ arma::mat referenceData;
+ arma::mat queryData;
+
+ data::Load(referenceFile, referenceData, true);
+
+ Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+ << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+
+ // Sanity check on k value.
+ if (k > referenceData.n_cols)
+ {
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than or equal to the number of reference points (";
+ Log::Fatal << referenceData.n_cols << ")." << endl;
+ }
+
+ // Check on kernel type.
+ if ((kernelType != "linear") && (kernelType != "polynomial") &&
+ (kernelType != "cosine") && (kernelType != "gaussian") &&
+ (kernelType != "graph") && (kernelType != "approxGraph") &&
+ (kernelType != "triangular") && (kernelType != "hyptan") &&
+ (kernelType != "inv-mq") && (kernelType != "epanechnikov"))
+ {
+ Log::Fatal << "Invalid kernel type: '" << kernelType << "'; must be ";
+ Log::Fatal << "'linear' or 'polynomial'." << endl;
+ }
+
+ // Load the query matrix, if we can.
+ if (CLI::HasParam("query_file"))
+ {
+ const string queryFile = CLI::GetParam<string>("query_file");
+ data::Load(queryFile, queryData, true);
+
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+ }
+ else
+ {
+ Log::Info << "Using reference dataset as query dataset (--query_file not "
+ << "specified)." << endl;
+ }
+
+ // Naive mode overrides single mode.
+ if (naive && single)
+ {
+ Log::Warn << "--single ignored because --naive is present." << endl;
+ }
+
+ // Matrices for output storage.
+ arma::Mat<size_t> indices;
+ arma::mat products;
+
+ // Construct FastMKS object.
+ if (queryData.n_elem == 0)
+ {
+ if (kernelType == "linear")
+ {
+ LinearKernel lk;
+ RunFastMKS<LinearKernel>(referenceData, single, naive, base, k, indices,
+ products, lk);
+ }
+ else if (kernelType == "polynomial")
+ {
+
+ PolynomialKernel pk(degree, offset);
+ RunFastMKS<PolynomialKernel>(referenceData, single, naive, base, k,
+ indices, products, pk);
+ }
+ else if (kernelType == "cosine")
+ {
+ CosineDistance cd;
+ RunFastMKS<CosineDistance>(referenceData, single, naive, base, k, indices,
+ products, cd);
+ }
+ else if (kernelType == "gaussian")
+ {
+ GaussianKernel gk(bandwidth);
+ RunFastMKS<GaussianKernel>(referenceData, single, naive, base, k, indices,
+ products, gk);
+ }
+ else if (kernelType == "epanechnikov")
+ {
+ EpanechnikovKernel ek(bandwidth);
+ RunFastMKS<EpanechnikovKernel>(referenceData, single, naive, base, k,
+ indices, products, ek);
+ }
+ else if (kernelType == "triangular")
+ {
+ TriangularKernel tk(bandwidth);
+ RunFastMKS<TriangularKernel>(referenceData, single, naive, base, k,
+ indices, products, tk);
+ }
+ else if (kernelType == "hyptan")
+ {
+ HyperbolicTangentKernel htk(scale, offset);
+ RunFastMKS<HyperbolicTangentKernel>(referenceData, single, naive, base, k,
+ indices, products, htk);
+ }
+ }
+ else
+ {
+ if (kernelType == "linear")
+ {
+ LinearKernel lk;
+ RunFastMKS<LinearKernel>(referenceData, queryData, single, naive, base, k,
+ indices, products, lk);
+ }
+ else if (kernelType == "polynomial")
+ {
+ PolynomialKernel pk(degree, offset);
+ RunFastMKS<PolynomialKernel>(referenceData, queryData, single, naive,
+ base, k, indices, products, pk);
+ }
+ else if (kernelType == "cosine")
+ {
+ CosineDistance cd;
+ RunFastMKS<CosineDistance>(referenceData, queryData, single, naive, base,
+ k, indices, products, cd);
+ }
+ else if (kernelType == "gaussian")
+ {
+ GaussianKernel gk(bandwidth);
+ RunFastMKS<GaussianKernel>(referenceData, queryData, single, naive, base,
+ k, indices, products, gk);
+ }
+ else if (kernelType == "epanechnikov")
+ {
+ EpanechnikovKernel ek(bandwidth);
+ RunFastMKS<EpanechnikovKernel>(referenceData, queryData, single, naive,
+ base, k, indices, products, ek);
+ }
+ else if (kernelType == "triangular")
+ {
+ TriangularKernel tk(bandwidth);
+ RunFastMKS<TriangularKernel>(referenceData, queryData, single, naive,
+ base, k, indices, products, tk);
+ }
+ else if (kernelType == "hyptan")
+ {
+ HyperbolicTangentKernel htk(scale, offset);
+ RunFastMKS<HyperbolicTangentKernel>(referenceData, queryData, single,
+ naive, base, k, indices, products, htk);
+ }
+ }
+
+ // Save output, if we were asked to.
+ if (CLI::HasParam("products_file"))
+ {
+ const string productsFile = CLI::GetParam<string>("products_file");
+ data::Save(productsFile, products, false);
+ }
+
+ if (CLI::HasParam("indices_file"))
+ {
+ const string indicesFile = CLI::GetParam<string>("indices_file");
+ data::Save(indicesFile, indices, false);
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_rules.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,160 +0,0 @@
-/**
- * @file fastmks_rules.hpp
- * @author Ryan Curtin
- *
- * Rules for the single or dual tree traversal for fast max-kernel search.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_HPP
-#define __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
-
-namespace mlpack {
-namespace fastmks {
-
-/**
- * The base case and pruning rules for FastMKS (fast max-kernel search).
- */
-template<typename KernelType, typename TreeType>
-class FastMKSRules
-{
- public:
- FastMKSRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& indices,
- arma::mat& products,
- KernelType& kernel);
-
- //! Compute the base case (kernel value) between two points.
- double BaseCase(const size_t queryIndex, const size_t referenceIndex);
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate to be recursed into.
- */
- double Score(const size_t queryIndex, TreeType& referenceNode) const;
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult) const;
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * @param queryNode Candidate query node to be recursed into.
- * @param referenceNode Candidate reference node to be recursed into.
- */
- double Score(TreeType& queryNode, TreeType& referenceNode) const;
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryNode Candidate query node to be recursed into.
- * @param referenceNode Candidate reference node to be recursed into.
- * @param baseCaseResult Result of BaseCase(queryNode, referenceNode).
- */
- double Score(TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult) const;
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that a node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param oldScore Old score produced by Score() (or Rescore()).
- */
- double Rescore(const size_t queryIndex,
- TreeType& referenceNode,
- const double oldScore) const;
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that a node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryNode Candidate query node to be recursed into.
- * @param referenceNode Candidate reference node to be recursed into.
- * @param oldScore Old score produced by Score() (or Rescore()).
- */
- double Rescore(TreeType& queryNode,
- TreeType& referenceNode,
- const double oldScore) const;
-
- private:
- //! The reference dataset.
- const arma::mat& referenceSet;
- //! The query dataset.
- const arma::mat& querySet;
-
- //! The indices of the maximum kernel results.
- arma::Mat<size_t>& indices;
- //! The maximum kernels.
- arma::mat& products;
-
- //! Cached query set self-kernels (|| q || for each q).
- arma::vec queryKernels;
- //! Cached reference set self-kernels (|| r || for each r).
- arma::vec referenceKernels;
-
- //! The instantiated kernel.
- KernelType& kernel;
-
- //! Calculate the bound for a given query node.
- double CalculateBound(TreeType& queryNode) const;
-
- //! Utility function to insert neighbor into list of results.
- void InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance);
-};
-
-}; // namespace fastmks
-}; // namespace mlpack
-
-// Include implementation.
-#include "fastmks_rules_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_rules.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,160 @@
+/**
+ * @file fastmks_rules.hpp
+ * @author Ryan Curtin
+ *
+ * Rules for the single or dual tree traversal for fast max-kernel search.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_HPP
+#define __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+
+namespace mlpack {
+namespace fastmks {
+
+/**
+ * The base case and pruning rules for FastMKS (fast max-kernel search).
+ */
+template<typename KernelType, typename TreeType>
+class FastMKSRules
+{
+ public:
+ FastMKSRules(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& indices,
+ arma::mat& products,
+ KernelType& kernel);
+
+ //! Compute the base case (kernel value) between two points.
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate to be recursed into.
+ */
+ double Score(const size_t queryIndex, TreeType& referenceNode) const;
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to be recursed into.
+ * @param referenceNode Candidate reference node to be recursed into.
+ */
+ double Score(TreeType& queryNode, TreeType& referenceNode) const;
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to be recursed into.
+ * @param referenceNode Candidate reference node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(queryNode, referenceNode).
+ */
+ double Score(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that a node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that a node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryNode Candidate query node to be recursed into.
+ * @param referenceNode Candidate reference node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ private:
+ //! The reference dataset.
+ const arma::mat& referenceSet;
+ //! The query dataset.
+ const arma::mat& querySet;
+
+ //! The indices of the maximum kernel results.
+ arma::Mat<size_t>& indices;
+ //! The maximum kernels.
+ arma::mat& products;
+
+ //! Cached query set self-kernels (|| q || for each q).
+ arma::vec queryKernels;
+ //! Cached reference set self-kernels (|| r || for each r).
+ arma::vec referenceKernels;
+
+ //! The instantiated kernel.
+ KernelType& kernel;
+
+ //! Calculate the bound for a given query node.
+ double CalculateBound(TreeType& queryNode) const;
+
+ //! Utility function to insert neighbor into list of results.
+ void InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance);
+};
+
+}; // namespace fastmks
+}; // namespace mlpack
+
+// Include implementation.
+#include "fastmks_rules_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,299 +0,0 @@
-/**
- * @file fastmks_rules_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of FastMKSRules for cover tree search.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_IMPL_HPP
-#define __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "fastmks_rules.hpp"
-
-namespace mlpack {
-namespace fastmks {
-
-template<typename KernelType, typename TreeType>
-FastMKSRules<KernelType, TreeType>::FastMKSRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& indices,
- arma::mat& products,
- KernelType& kernel) :
- referenceSet(referenceSet),
- querySet(querySet),
- indices(indices),
- products(products),
- kernel(kernel)
-{
- // Precompute each self-kernel.
- queryKernels.set_size(querySet.n_cols);
- for (size_t i = 0; i < querySet.n_cols; ++i)
- queryKernels[i] = sqrt(kernel.Evaluate(querySet.unsafe_col(i),
- querySet.unsafe_col(i)));
-
- referenceKernels.set_size(referenceSet.n_cols);
- for (size_t i = 0; i < referenceSet.n_cols; ++i)
- referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.unsafe_col(i),
- referenceSet.unsafe_col(i)));
-}
-
-template<typename KernelType, typename TreeType>
-inline force_inline
-double FastMKSRules<KernelType, TreeType>::BaseCase(
- const size_t queryIndex,
- const size_t referenceIndex)
-{
-
- double kernelEval = kernel.Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceIndex));
-
- // If the reference and query sets are identical, we still need to compute the
- // base case (so that things can be bounded properly), but we won't add it to
- // the results.
- if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
- return kernelEval;
-
- // If this is a better candidate, insert it into the list.
- if (kernelEval < products(products.n_rows - 1, queryIndex))
- return kernelEval;
-
- size_t insertPosition = 0;
- for ( ; insertPosition < products.n_rows; ++insertPosition)
- if (kernelEval >= products(insertPosition, queryIndex))
- break;
-
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, kernelEval);
-
- return kernelEval;
-}
-
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(const size_t queryIndex,
- TreeType& referenceNode) const
-{
- // Calculate the maximum possible kernel value.
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const arma::vec refCentroid;
- referenceNode.Bound().Centroid(refCentroid);
-
- const double maxKernel = kernel.Evaluate(queryPoint, refCentroid) +
- referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
-
- // Compare with the current best.
- const double bestKernel = products(products.n_rows - 1, queryIndex);
-
- // We return the inverse of the maximum kernel so that larger kernels are
- // recursed into first.
- return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
-}
-
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(
- const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
- // We already have the base case result. Add the bound.
- const double maxKernel = baseCaseResult +
- referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
- const double bestKernel = products(products.n_rows - 1, queryIndex);
-
- // We return the inverse of the maximum kernel so that larger kernels are
- // recursed into first.
- return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
-}
-
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(TreeType& queryNode,
- TreeType& referenceNode) const
-{
- // Calculate the maximum possible kernel value.
- const arma::vec queryCentroid;
- const arma::vec refCentroid;
- queryNode.Bound().Centroid(queryCentroid);
- referenceNode.Bound().Centroid(refCentroid);
-
- const double refKernelTerm = queryNode.FurthestDescendantDistance() *
- referenceNode.Stat().SelfKernel();
- const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
- queryNode.Stat().SelfKernel();
-
- const double maxKernel = kernel.Evaluate(queryCentroid, refCentroid) +
- refKernelTerm + queryKernelTerm +
- (queryNode.FurthestDescendantDistance() *
- referenceNode.FurthestDescendantDistance());
-
- // The existing bound.
- queryNode.Stat().Bound() = CalculateBound(queryNode);
- const double bestKernel = queryNode.Stat().Bound();
-
- // We return the inverse of the maximum kernel so that larger kernels are
- // recursed into first.
- return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
-}
-
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(
- TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
- // We already have the base case, so we need to add the bounds.
- const double refKernelTerm = queryNode.FurthestDescendantDistance() *
- referenceNode.Stat().SelfKernel();
- const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
- queryNode.Stat().SelfKernel();
-
- const double maxKernel = baseCaseResult + refKernelTerm + queryKernelTerm +
- (queryNode.FurthestDescendantDistance() *
- referenceNode.FurthestDescendantDistance());
-
- // The existing bound.
- queryNode.Stat().Bound() = CalculateBound(queryNode);
- const double bestKernel = queryNode.Stat().Bound();
-
- // We return the inverse of the maximum kernel so that larger kernels are
- // recursed into first.
- return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
-}
-
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
- TreeType& /*referenceNode*/,
- const double oldScore) const
-{
- const double bestKernel = products(products.n_rows - 1, queryIndex);
-
- return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
-}
-
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
- TreeType& /*referenceNode*/,
- const double oldScore) const
-{
- queryNode.Stat().Bound() = CalculateBound(queryNode);
- const double bestKernel = queryNode.Stat().Bound();
-
- return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
-}
-
-/**
- * Calculate the bound for the given query node. This bound represents the
- * minimum value which a node combination must achieve to guarantee an
- * improvement in the results.
- *
- * @param queryNode Query node to calculate bound for.
- */
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::CalculateBound(TreeType& queryNode)
- const
-{
- // We have four possible bounds -- just like NeighborSearchRules, but they are
- // slightly different in this context.
- //
- // (1) min ( min_{all points p in queryNode} P_p[k],
- // min_{all children c in queryNode} B(c) );
- // (2) max_{all points p in queryNode} P_p[k] + (worst child distance + worst
- // descendant distance) sqrt(K(I_p[k], I_p[k]));
- // (3) max_{all children c in queryNode} B(c) + <-- not done yet. ignored.
- // (4) B(parent of queryNode);
- double worstPointKernel = DBL_MAX;
- double bestAdjustedPointKernel = -DBL_MAX;
-// double bestPointSelfKernel = -DBL_MAX;
- const double queryDescendantDistance = queryNode.FurthestDescendantDistance();
-
- // Loop over all points in this node to find the best and worst.
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- const size_t point = queryNode.Point(i);
- if (products(products.n_rows - 1, point) < worstPointKernel)
- worstPointKernel = products(products.n_rows - 1, point);
-
- if (products(products.n_rows - 1, point) == -DBL_MAX)
- continue; // Avoid underflow.
-
- const double candidateKernel = products(products.n_rows - 1, point) -
- (2 * queryDescendantDistance) *
- referenceKernels[indices(indices.n_rows - 1, point)];
-
- if (candidateKernel > bestAdjustedPointKernel)
- bestAdjustedPointKernel = candidateKernel;
- }
-
- // Loop over all the children in the node.
- double worstChildKernel = DBL_MAX;
-
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- if (queryNode.Child(i).Stat().Bound() < worstChildKernel)
- worstChildKernel = queryNode.Child(i).Stat().Bound();
- }
-
- // Now assemble bound (1).
- const double firstBound = (worstPointKernel < worstChildKernel) ?
- worstPointKernel : worstChildKernel;
-
- // Bound (2) is bestAdjustedPointKernel.
- const double fourthBound = (queryNode.Parent() == NULL) ? -DBL_MAX :
- queryNode.Parent()->Stat().Bound();
-
- // Pick the best of these bounds.
- const double interA = (firstBound > bestAdjustedPointKernel) ? firstBound :
- bestAdjustedPointKernel;
-// const double interA = 0.0;
- const double interB = fourthBound;
-
- return (interA > interB) ? interA : interB;
-}
-
-/**
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename MetricType, typename TreeType>
-void FastMKSRules<MetricType, TreeType>::InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
-{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (products.n_rows - 1))
- {
- int len = (products.n_rows - 1) - pos;
- memmove(products.colptr(queryIndex) + (pos + 1),
- products.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(indices.colptr(queryIndex) + (pos + 1),
- indices.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- products(pos, queryIndex) = distance;
- indices(pos, queryIndex) = neighbor;
-}
-
-}; // namespace fastmks
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,299 @@
+/**
+ * @file fastmks_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of FastMKSRules for cover tree search.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_IMPL_HPP
+#define __MLPACK_METHODS_FASTMKS_FASTMKS_RULES_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "fastmks_rules.hpp"
+
+namespace mlpack {
+namespace fastmks {
+
+template<typename KernelType, typename TreeType>
+FastMKSRules<KernelType, TreeType>::FastMKSRules(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& indices,
+ arma::mat& products,
+ KernelType& kernel) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ indices(indices),
+ products(products),
+ kernel(kernel)
+{
+ // Precompute each self-kernel.
+ queryKernels.set_size(querySet.n_cols);
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ queryKernels[i] = sqrt(kernel.Evaluate(querySet.unsafe_col(i),
+ querySet.unsafe_col(i)));
+
+ referenceKernels.set_size(referenceSet.n_cols);
+ for (size_t i = 0; i < referenceSet.n_cols; ++i)
+ referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.unsafe_col(i),
+ referenceSet.unsafe_col(i)));
+}
+
+template<typename KernelType, typename TreeType>
+inline force_inline
+double FastMKSRules<KernelType, TreeType>::BaseCase(
+ const size_t queryIndex,
+ const size_t referenceIndex)
+{
+
+ double kernelEval = kernel.Evaluate(querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceIndex));
+
+ // If the reference and query sets are identical, we still need to compute the
+ // base case (so that things can be bounded properly), but we won't add it to
+ // the results.
+ if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
+ return kernelEval;
+
+ // If this is a better candidate, insert it into the list.
+ if (kernelEval < products(products.n_rows - 1, queryIndex))
+ return kernelEval;
+
+ size_t insertPosition = 0;
+ for ( ; insertPosition < products.n_rows; ++insertPosition)
+ if (kernelEval >= products(insertPosition, queryIndex))
+ break;
+
+ InsertNeighbor(queryIndex, insertPosition, referenceIndex, kernelEval);
+
+ return kernelEval;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(const size_t queryIndex,
+ TreeType& referenceNode) const
+{
+ // Calculate the maximum possible kernel value.
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ const arma::vec refCentroid;
+ referenceNode.Bound().Centroid(refCentroid);
+
+ const double maxKernel = kernel.Evaluate(queryPoint, refCentroid) +
+ referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
+
+ // Compare with the current best.
+ const double bestKernel = products(products.n_rows - 1, queryIndex);
+
+ // We return the inverse of the maximum kernel so that larger kernels are
+ // recursed into first.
+ return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult) const
+{
+ // We already have the base case result. Add the bound.
+ const double maxKernel = baseCaseResult +
+ referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
+ const double bestKernel = products(products.n_rows - 1, queryIndex);
+
+ // We return the inverse of the maximum kernel so that larger kernels are
+ // recursed into first.
+ return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(TreeType& queryNode,
+ TreeType& referenceNode) const
+{
+ // Calculate the maximum possible kernel value.
+ const arma::vec queryCentroid;
+ const arma::vec refCentroid;
+ queryNode.Bound().Centroid(queryCentroid);
+ referenceNode.Bound().Centroid(refCentroid);
+
+ const double refKernelTerm = queryNode.FurthestDescendantDistance() *
+ referenceNode.Stat().SelfKernel();
+ const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
+ queryNode.Stat().SelfKernel();
+
+ const double maxKernel = kernel.Evaluate(queryCentroid, refCentroid) +
+ refKernelTerm + queryKernelTerm +
+ (queryNode.FurthestDescendantDistance() *
+ referenceNode.FurthestDescendantDistance());
+
+ // The existing bound.
+ queryNode.Stat().Bound() = CalculateBound(queryNode);
+ const double bestKernel = queryNode.Stat().Bound();
+
+ // We return the inverse of the maximum kernel so that larger kernels are
+ // recursed into first.
+ return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(
+ TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const
+{
+ // We already have the base case, so we need to add the bounds.
+ const double refKernelTerm = queryNode.FurthestDescendantDistance() *
+ referenceNode.Stat().SelfKernel();
+ const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
+ queryNode.Stat().SelfKernel();
+
+ const double maxKernel = baseCaseResult + refKernelTerm + queryKernelTerm +
+ (queryNode.FurthestDescendantDistance() *
+ referenceNode.FurthestDescendantDistance());
+
+ // The existing bound.
+ queryNode.Stat().Bound() = CalculateBound(queryNode);
+ const double bestKernel = queryNode.Stat().Bound();
+
+ // We return the inverse of the maximum kernel so that larger kernels are
+ // recursed into first.
+ return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
+ TreeType& /*referenceNode*/,
+ const double oldScore) const
+{
+ const double bestKernel = products(products.n_rows - 1, queryIndex);
+
+ return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
+ TreeType& /*referenceNode*/,
+ const double oldScore) const
+{
+ queryNode.Stat().Bound() = CalculateBound(queryNode);
+ const double bestKernel = queryNode.Stat().Bound();
+
+ return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
+}
+
+/**
+ * Calculate the bound for the given query node. This bound represents the
+ * minimum value which a node combination must achieve to guarantee an
+ * improvement in the results.
+ *
+ * @param queryNode Query node to calculate bound for.
+ */
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::CalculateBound(TreeType& queryNode)
+ const
+{
+ // We have four possible bounds -- just like NeighborSearchRules, but they are
+ // slightly different in this context.
+ //
+ // (1) min ( min_{all points p in queryNode} P_p[k],
+ // min_{all children c in queryNode} B(c) );
+ // (2) max_{all points p in queryNode} P_p[k] + (worst child distance + worst
+ // descendant distance) sqrt(K(I_p[k], I_p[k]));
+ // (3) max_{all children c in queryNode} B(c) + <-- not done yet. ignored.
+ // (4) B(parent of queryNode);
+ double worstPointKernel = DBL_MAX;
+ double bestAdjustedPointKernel = -DBL_MAX;
+// double bestPointSelfKernel = -DBL_MAX;
+ const double queryDescendantDistance = queryNode.FurthestDescendantDistance();
+
+ // Loop over all points in this node to find the best and worst.
+ for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+ {
+ const size_t point = queryNode.Point(i);
+ if (products(products.n_rows - 1, point) < worstPointKernel)
+ worstPointKernel = products(products.n_rows - 1, point);
+
+ if (products(products.n_rows - 1, point) == -DBL_MAX)
+ continue; // Avoid underflow.
+
+ const double candidateKernel = products(products.n_rows - 1, point) -
+ (2 * queryDescendantDistance) *
+ referenceKernels[indices(indices.n_rows - 1, point)];
+
+ if (candidateKernel > bestAdjustedPointKernel)
+ bestAdjustedPointKernel = candidateKernel;
+ }
+
+ // Loop over all the children in the node.
+ double worstChildKernel = DBL_MAX;
+
+ for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+ {
+ if (queryNode.Child(i).Stat().Bound() < worstChildKernel)
+ worstChildKernel = queryNode.Child(i).Stat().Bound();
+ }
+
+ // Now assemble bound (1).
+ const double firstBound = (worstPointKernel < worstChildKernel) ?
+ worstPointKernel : worstChildKernel;
+
+ // Bound (2) is bestAdjustedPointKernel.
+ const double fourthBound = (queryNode.Parent() == NULL) ? -DBL_MAX :
+ queryNode.Parent()->Stat().Bound();
+
+ // Pick the best of these bounds.
+ const double interA = (firstBound > bestAdjustedPointKernel) ? firstBound :
+ bestAdjustedPointKernel;
+// const double interA = 0.0;
+ const double interB = fourthBound;
+
+ return (interA > interB) ? interA : interB;
+}
+
+/**
+ * Helper function to insert a point into the neighbors and distances matrices.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+template<typename MetricType, typename TreeType>
+void FastMKSRules<MetricType, TreeType>::InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance)
+{
+ // We only memmove() if there is actually a need to shift something.
+ if (pos < (products.n_rows - 1))
+ {
+ int len = (products.n_rows - 1) - pos;
+ memmove(products.colptr(queryIndex) + (pos + 1),
+ products.colptr(queryIndex) + pos,
+ sizeof(double) * len);
+ memmove(indices.colptr(queryIndex) + (pos + 1),
+ indices.colptr(queryIndex) + pos,
+ sizeof(size_t) * len);
+ }
+
+ // Now put the new information in the right index.
+ products(pos, queryIndex) = distance;
+ indices(pos, queryIndex) = neighbor;
+}
+
+}; // namespace fastmks
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_stat.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_stat.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_stat.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,104 +0,0 @@
-/**
- * @file fastmks_stat.hpp
- * @author Ryan Curtin
- *
- * The statistic used in trees with FastMKS.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_STAT_HPP
-#define __MLPACK_METHODS_FASTMKS_FASTMKS_STAT_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/tree/tree_traits.hpp>
-
-namespace mlpack {
-namespace fastmks {
-
-/**
- * The statistic used in trees with FastMKS. This stores both the bound and the
- * self-kernels for each node in the tree.
- */
-class FastMKSStat
-{
- public:
- /**
- * Default initialization.
- */
- FastMKSStat() : bound(-DBL_MAX), selfKernel(0.0) { }
-
- /**
- * Initialize this statistic for the given tree node. The TreeType's metric
- * better be IPMetric with some kernel type (that is, Metric().Kernel() must
- * exist).
- *
- * @param node Node that this statistic is built for.
- */
- template<typename TreeType>
- FastMKSStat(const TreeType& node) :
- bound(-DBL_MAX)
- {
- // Do we have to calculate the centroid?
- if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
- {
- // If this type of tree has self-children, then maybe the evaluation is
- // already done. These statistics are built bottom-up, so the child stat
- // should already be done.
- if ((tree::TreeTraits<TreeType>::HasSelfChildren) &&
- (node.NumChildren() > 0) &&
- (node.Point(0) == node.Child(0).Point(0)))
- {
- selfKernel = node.Child(0).Stat().SelfKernel();
- }
- else
- {
- selfKernel = sqrt(node.Metric().Kernel().Evaluate(
- node.Dataset().unsafe_col(node.Point(0)),
- node.Dataset().unsafe_col(node.Point(0))));
- }
- }
- else
- {
- // Calculate the centroid.
- arma::vec centroid;
- node.Centroid(centroid);
-
- selfKernel = sqrt(node.Metric().Kernel().Evaluate(centroid, centroid));
- }
- }
-
- //! Get the self-kernel.
- double SelfKernel() const { return selfKernel; }
- //! Modify the self-kernel.
- double& SelfKernel() { return selfKernel; }
-
- //! Get the bound.
- double Bound() const { return bound; }
- //! Modify the bound.
- double& Bound() { return bound; }
-
- private:
- //! The bound for pruning.
- double bound;
-
- //! The self-kernel evaluation: sqrt(K(centroid, centroid)).
- double selfKernel;
-};
-
-}; // namespace fastmks
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_stat.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/fastmks_stat.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_stat.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/fastmks_stat.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,104 @@
+/**
+ * @file fastmks_stat.hpp
+ * @author Ryan Curtin
+ *
+ * The statistic used in trees with FastMKS.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_STAT_HPP
+#define __MLPACK_METHODS_FASTMKS_FASTMKS_STAT_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/tree_traits.hpp>
+
+namespace mlpack {
+namespace fastmks {
+
+/**
+ * The statistic used in trees with FastMKS. This stores both the bound and the
+ * self-kernels for each node in the tree.
+ */
+class FastMKSStat
+{
+ public:
+ /**
+ * Default initialization.
+ */
+ FastMKSStat() : bound(-DBL_MAX), selfKernel(0.0) { }
+
+ /**
+ * Initialize this statistic for the given tree node. The TreeType's metric
+ * better be IPMetric with some kernel type (that is, Metric().Kernel() must
+ * exist).
+ *
+ * @param node Node that this statistic is built for.
+ */
+ template<typename TreeType>
+ FastMKSStat(const TreeType& node) :
+ bound(-DBL_MAX)
+ {
+ // Do we have to calculate the centroid?
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+ {
+ // If this type of tree has self-children, then maybe the evaluation is
+ // already done. These statistics are built bottom-up, so the child stat
+ // should already be done.
+ if ((tree::TreeTraits<TreeType>::HasSelfChildren) &&
+ (node.NumChildren() > 0) &&
+ (node.Point(0) == node.Child(0).Point(0)))
+ {
+ selfKernel = node.Child(0).Stat().SelfKernel();
+ }
+ else
+ {
+ selfKernel = sqrt(node.Metric().Kernel().Evaluate(
+ node.Dataset().unsafe_col(node.Point(0)),
+ node.Dataset().unsafe_col(node.Point(0))));
+ }
+ }
+ else
+ {
+ // Calculate the centroid.
+ arma::vec centroid;
+ node.Centroid(centroid);
+
+ selfKernel = sqrt(node.Metric().Kernel().Evaluate(centroid, centroid));
+ }
+ }
+
+ //! Get the self-kernel.
+ double SelfKernel() const { return selfKernel; }
+ //! Modify the self-kernel.
+ double& SelfKernel() { return selfKernel; }
+
+ //! Get the bound.
+ double Bound() const { return bound; }
+ //! Modify the bound.
+ double& Bound() { return bound; }
+
+ private:
+ //! The bound for pruning.
+ double bound;
+
+ //! The self-kernel evaluation: sqrt(K(centroid, centroid)).
+ double selfKernel;
+};
+
+}; // namespace fastmks
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/ip_metric.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,66 +0,0 @@
-/**
- * @file ip_metric.hpp
- * @author Ryan Curtin
- *
- * Inner product induced metric. If given a kernel function, this gives the
- * complementary metric.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_FASTMKS_IP_METRIC_HPP
-#define __MLPACK_METHODS_FASTMKS_IP_METRIC_HPP
-
-namespace mlpack {
-namespace fastmks /** Fast maximum kernel search. */ {
-
-template<typename KernelType>
-class IPMetric
-{
- public:
- //! Create the IPMetric without an instantiated kernel.
- IPMetric();
-
- //! Create the IPMetric with an instantiated kernel.
- IPMetric(KernelType& kernel);
-
- //! Destroy the IPMetric object.
- ~IPMetric();
-
- /**
- * Evaluate the metric.
- */
- template<typename Vec1Type, typename Vec2Type>
- double Evaluate(const Vec1Type& a, const Vec2Type& b);
-
- //! Get the kernel.
- const KernelType& Kernel() const { return kernel; }
- //! Modify the kernel.
- KernelType& Kernel() { return kernel; }
-
- private:
- //! The locally stored kernel, if it is necessary.
- KernelType* localKernel;
- //! The reference to the kernel that is being used.
- KernelType& kernel;
-};
-
-}; // namespace fastmks
-}; // namespace mlpack
-
-// Include implementation.
-#include "ip_metric_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/ip_metric.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,66 @@
+/**
+ * @file ip_metric.hpp
+ * @author Ryan Curtin
+ *
+ * Inner product induced metric. If given a kernel function, this gives the
+ * complementary metric.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_FASTMKS_IP_METRIC_HPP
+#define __MLPACK_METHODS_FASTMKS_IP_METRIC_HPP
+
+namespace mlpack {
+namespace fastmks /** Fast maximum kernel search. */ {
+
+template<typename KernelType>
+class IPMetric
+{
+ public:
+ //! Create the IPMetric without an instantiated kernel.
+ IPMetric();
+
+ //! Create the IPMetric with an instantiated kernel.
+ IPMetric(KernelType& kernel);
+
+ //! Destroy the IPMetric object.
+ ~IPMetric();
+
+ /**
+ * Evaluate the metric.
+ */
+ template<typename Vec1Type, typename Vec2Type>
+ double Evaluate(const Vec1Type& a, const Vec2Type& b);
+
+ //! Get the kernel.
+ const KernelType& Kernel() const { return kernel; }
+ //! Modify the kernel.
+ KernelType& Kernel() { return kernel; }
+
+ private:
+ //! The locally stored kernel, if it is necessary.
+ KernelType* localKernel;
+ //! The reference to the kernel that is being used.
+ KernelType& kernel;
+};
+
+}; // namespace fastmks
+}; // namespace mlpack
+
+// Include implementation.
+#include "ip_metric_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/ip_metric_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,84 +0,0 @@
-/**
- * @file ip_metric_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of the IPMetric.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_FASTMKS_IP_METRIC_IMPL_HPP
-#define __MLPACK_METHODS_FASTMKS_IP_METRIC_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "ip_metric_impl.hpp"
-
-#include <mlpack/core/metrics/lmetric.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
-
-namespace mlpack {
-namespace fastmks {
-
-// Constructor with no instantiated kernel.
-template<typename KernelType>
-IPMetric<KernelType>::IPMetric() :
- localKernel(new KernelType()),
- kernel(*localKernel)
-{
- // Nothing to do.
-}
-
-// Constructor with instantiated kernel.
-template<typename KernelType>
-IPMetric<KernelType>::IPMetric(KernelType& kernel) :
- localKernel(NULL),
- kernel(kernel)
-{
- // Nothing to do.
-}
-
-// Destructor for the IPMetric.
-template<typename KernelType>
-IPMetric<KernelType>::~IPMetric()
-{
- if (localKernel != NULL)
- delete localKernel;
-}
-
-template<typename KernelType>
-template<typename Vec1Type, typename Vec2Type>
-inline double IPMetric<KernelType>::Evaluate(const Vec1Type& a,
- const Vec2Type& b)
-{
- // This is the metric induced by the kernel function.
- // Maybe we can do better by caching some of this?
- return sqrt(kernel.Evaluate(a, a) + kernel.Evaluate(b, b) -
- 2 * kernel.Evaluate(a, b));
-}
-
-// A specialization for the linear kernel, which actually just turns out to be
-// the Euclidean distance.
-template<>
-template<typename Vec1Type, typename Vec2Type>
-inline double IPMetric<kernel::LinearKernel>::Evaluate(const Vec1Type& a,
- const Vec2Type& b)
-{
- return metric::LMetric<2, true>::Evaluate(a, b);
-}
-
-}; // namespace fastmks
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/fastmks/ip_metric_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/fastmks/ip_metric_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,84 @@
+/**
+ * @file ip_metric_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the IPMetric.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_FASTMKS_IP_METRIC_IMPL_HPP
+#define __MLPACK_METHODS_FASTMKS_IP_METRIC_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "ip_metric.hpp"
+
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/kernels/linear_kernel.hpp>
+
+namespace mlpack {
+namespace fastmks {
+
+// Constructor with no instantiated kernel.
+template<typename KernelType>
+IPMetric<KernelType>::IPMetric() :
+ localKernel(new KernelType()),
+ kernel(*localKernel)
+{
+ // Nothing to do.
+}
+
+// Constructor with instantiated kernel.
+template<typename KernelType>
+IPMetric<KernelType>::IPMetric(KernelType& kernel) :
+ localKernel(NULL),
+ kernel(kernel)
+{
+ // Nothing to do.
+}
+
+// Destructor for the IPMetric.
+template<typename KernelType>
+IPMetric<KernelType>::~IPMetric()
+{
+ if (localKernel != NULL)
+ delete localKernel;
+}
+
+template<typename KernelType>
+template<typename Vec1Type, typename Vec2Type>
+inline double IPMetric<KernelType>::Evaluate(const Vec1Type& a,
+ const Vec2Type& b)
+{
+ // This is the metric induced by the kernel function.
+ // Maybe we can do better by caching some of this?
+ return sqrt(kernel.Evaluate(a, a) + kernel.Evaluate(b, b) -
+ 2 * kernel.Evaluate(a, b));
+}
+
+// A specialization for the linear kernel, which actually just turns out to be
+// the Euclidean distance.
+template<>
+template<typename Vec1Type, typename Vec2Type>
+inline double IPMetric<kernel::LinearKernel>::Evaluate(const Vec1Type& a,
+ const Vec2Type& b)
+{
+ return metric::LMetric<2, true>::Evaluate(a, b);
+}
+
+}; // namespace fastmks
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/em_fit.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,175 +0,0 @@
-/**
- * @file em_fit.hpp
- * @author Ryan Curtin
- *
- * Utility class to fit a GMM using the EM algorithm. Used by
- * GMM::Estimate<>().
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP
-#define __MLPACK_METHODS_GMM_EM_FIT_HPP
-
-#include <mlpack/core.hpp>
-
-// Default clustering mechanism.
-#include <mlpack/methods/kmeans/kmeans.hpp>
-
-namespace mlpack {
-namespace gmm {
-
-/**
- * This class contains methods which can fit a GMM to observations using the EM
- * algorithm. It requires an initial clustering mechanism, which is by default
- * the KMeans algorithm. The clustering mechanism must implement the following
- * method:
- *
- * - void Cluster(const arma::mat& observations,
- * const size_t clusters,
- * arma::Col<size_t>& assignments);
- *
- * This method should create 'clusters' clusters, and return the assignment of
- * each point to a cluster.
- */
-template<typename InitialClusteringType = kmeans::KMeans<> >
-class EMFit
-{
- public:
- /**
- * Construct the EMFit object, optionally passing an InitialClusteringType
- * object (just in case it needs to store state). Setting the maximum number
- * of iterations to 0 means that the EM algorithm will iterate until
- * convergence (with the given tolerance).
- *
- * The parameter forcePositive controls whether or not the covariance matrices
- * are checked for positive definiteness at each iteration. This could be a
- * time-consuming task, so, if you know your data is well-behaved, you can set
- * it to false and save some runtime.
- *
- * @param maxIterations Maximum number of iterations for EM.
- * @param tolerance Log-likelihood tolerance required for convergence.
- * @param forcePositive Check for positive-definiteness of each covariance
- * matrix at each iteration.
- * @param clusterer Object which will perform the initial clustering.
- */
- EMFit(const size_t maxIterations = 300,
- const double tolerance = 1e-10,
- const bool forcePositive = true,
- InitialClusteringType clusterer = InitialClusteringType());
-
- /**
- * Fit the observations to a Gaussian mixture model (GMM) using the EM
- * algorithm. The size of the vectors (indicating the number of components)
- * must already be set.
- *
- * @param observations List of observations to train on.
- * @param means Vector to store trained means in.
- * @param covariances Vector to store trained covariances in.
- * @param weights Vector to store a priori weights in.
- */
- void Estimate(const arma::mat& observations,
- std::vector<arma::vec>& means,
- std::vector<arma::mat>& covariances,
- arma::vec& weights);
-
- /**
- * Fit the observations to a Gaussian mixture model (GMM) using the EM
- * algorithm, taking into account the probabilities of each point being from
- * this mixture. The size of the vectors (indicating the number of
- * components) must already be set.
- *
- * @param observations List of observations to train on.
- * @param probabilities Probability of each point being from this model.
- * @param means Vector to store trained means in.
- * @param covariances Vector to store trained covariances in.
- * @param weights Vector to store a priori weights in.
- */
- void Estimate(const arma::mat& observations,
- const arma::vec& probabilities,
- std::vector<arma::vec>& means,
- std::vector<arma::mat>& covariances,
- arma::vec& weights);
-
- //! Get the clusterer.
- const InitialClusteringType& Clusterer() const { return clusterer; }
- //! Modify the clusterer.
- InitialClusteringType& Clusterer() { return clusterer; }
-
- //! Get the maximum number of iterations of the EM algorithm.
- size_t MaxIterations() const { return maxIterations; }
- //! Modify the maximum number of iterations of the EM algorithm.
- size_t& MaxIterations() { return maxIterations; }
-
- //! Get the tolerance for the convergence of the EM algorithm.
- double Tolerance() const { return tolerance; }
- //! Modify the tolerance for the convergence of the EM algorithm.
- double& Tolerance() { return tolerance; }
-
- //! Get whether or not the covariance matrices are forced to be positive
- //! definite.
- bool ForcePositive() const { return forcePositive; }
- //! Modify whether or not the covariance matrices are forced to be positive
- //! definite.
- bool& ForcePositive() { return forcePositive; }
-
- private:
- /**
- * Run the clusterer, and then turn the cluster assignments into Gaussians.
- * This is a helper function for both overloads of Estimate(). The vectors
- * must be already set to the number of clusters.
- *
- * @param observations List of observations.
- * @param means Vector to store means in.
- * @param covariances Vector to store covariances in.
- * @param weights Vector to store a priori weights in.
- */
- void InitialClustering(const arma::mat& observations,
- std::vector<arma::vec>& means,
- std::vector<arma::mat>& covariances,
- arma::vec& weights);
-
- /**
- * Calculate the log-likelihood of a model. Yes, this is reimplemented in the
- * GMM code. Intuition suggests that the log-likelihood is not the best way
- * to determine if the EM algorithm has converged.
- *
- * @param data Data matrix.
- * @param means Vector of means.
- * @param covariances Vector of covariance matrices.
- * @param weights Vector of a priori weights.
- */
- double LogLikelihood(const arma::mat& data,
- const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covariances,
- const arma::vec& weights) const;
-
- //! Maximum iterations of EM algorithm.
- size_t maxIterations;
- //! Tolerance for convergence of EM.
- double tolerance;
- //! Whether or not to force positive definiteness of covariance matrices.
- bool forcePositive;
- //! Object which will perform the clustering.
- InitialClusteringType clusterer;
-};
-
-}; // namespace gmm
-}; // namespace mlpack
-
-// Include implementation.
-#include "em_fit_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/em_fit.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,175 @@
+/**
+ * @file em_fit.hpp
+ * @author Ryan Curtin
+ *
+ * Utility class to fit a GMM using the EM algorithm. Used by
+ * GMM::Estimate<>().
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP
+#define __MLPACK_METHODS_GMM_EM_FIT_HPP
+
+#include <mlpack/core.hpp>
+
+// Default clustering mechanism.
+#include <mlpack/methods/kmeans/kmeans.hpp>
+
+namespace mlpack {
+namespace gmm {
+
+/**
+ * This class contains methods which can fit a GMM to observations using the EM
+ * algorithm. It requires an initial clustering mechanism, which is by default
+ * the KMeans algorithm. The clustering mechanism must implement the following
+ * method:
+ *
+ * - void Cluster(const arma::mat& observations,
+ * const size_t clusters,
+ * arma::Col<size_t>& assignments);
+ *
+ * This method should create 'clusters' clusters, and return the assignment of
+ * each point to a cluster.
+ */
+template<typename InitialClusteringType = kmeans::KMeans<> >
+class EMFit
+{
+ public:
+ /**
+ * Construct the EMFit object, optionally passing an InitialClusteringType
+ * object (just in case it needs to store state). Setting the maximum number
+ * of iterations to 0 means that the EM algorithm will iterate until
+ * convergence (with the given tolerance).
+ *
+ * The parameter forcePositive controls whether or not the covariance matrices
+ * are checked for positive definiteness at each iteration. This could be a
+ * time-consuming task, so, if you know your data is well-behaved, you can set
+ * it to false and save some runtime.
+ *
+ * @param maxIterations Maximum number of iterations for EM.
+ * @param tolerance Log-likelihood tolerance required for convergence.
+ * @param forcePositive Check for positive-definiteness of each covariance
+ * matrix at each iteration.
+ * @param clusterer Object which will perform the initial clustering.
+ */
+ EMFit(const size_t maxIterations = 300,
+ const double tolerance = 1e-10,
+ const bool forcePositive = true,
+ InitialClusteringType clusterer = InitialClusteringType());
+
+ /**
+ * Fit the observations to a Gaussian mixture model (GMM) using the EM
+ * algorithm. The size of the vectors (indicating the number of components)
+ * must already be set.
+ *
+ * @param observations List of observations to train on.
+ * @param means Vector to store trained means in.
+ * @param covariances Vector to store trained covariances in.
+ * @param weights Vector to store a priori weights in.
+ */
+ void Estimate(const arma::mat& observations,
+ std::vector<arma::vec>& means,
+ std::vector<arma::mat>& covariances,
+ arma::vec& weights);
+
+ /**
+ * Fit the observations to a Gaussian mixture model (GMM) using the EM
+ * algorithm, taking into account the probabilities of each point being from
+ * this mixture. The size of the vectors (indicating the number of
+ * components) must already be set.
+ *
+ * @param observations List of observations to train on.
+ * @param probabilities Probability of each point being from this model.
+ * @param means Vector to store trained means in.
+ * @param covariances Vector to store trained covariances in.
+ * @param weights Vector to store a priori weights in.
+ */
+ void Estimate(const arma::mat& observations,
+ const arma::vec& probabilities,
+ std::vector<arma::vec>& means,
+ std::vector<arma::mat>& covariances,
+ arma::vec& weights);
+
+ //! Get the clusterer.
+ const InitialClusteringType& Clusterer() const { return clusterer; }
+ //! Modify the clusterer.
+ InitialClusteringType& Clusterer() { return clusterer; }
+
+ //! Get the maximum number of iterations of the EM algorithm.
+ size_t MaxIterations() const { return maxIterations; }
+ //! Modify the maximum number of iterations of the EM algorithm.
+ size_t& MaxIterations() { return maxIterations; }
+
+ //! Get the tolerance for the convergence of the EM algorithm.
+ double Tolerance() const { return tolerance; }
+ //! Modify the tolerance for the convergence of the EM algorithm.
+ double& Tolerance() { return tolerance; }
+
+ //! Get whether or not the covariance matrices are forced to be positive
+ //! definite.
+ bool ForcePositive() const { return forcePositive; }
+ //! Modify whether or not the covariance matrices are forced to be positive
+ //! definite.
+ bool& ForcePositive() { return forcePositive; }
+
+ private:
+ /**
+ * Run the clusterer, and then turn the cluster assignments into Gaussians.
+ * This is a helper function for both overloads of Estimate(). The vectors
+ * must be already set to the number of clusters.
+ *
+ * @param observations List of observations.
+ * @param means Vector to store means in.
+ * @param covariances Vector to store covariances in.
+ * @param weights Vector to store a priori weights in.
+ */
+ void InitialClustering(const arma::mat& observations,
+ std::vector<arma::vec>& means,
+ std::vector<arma::mat>& covariances,
+ arma::vec& weights);
+
+ /**
+ * Calculate the log-likelihood of a model. Yes, this is reimplemented in the
+ * GMM code. Intuition suggests that the log-likelihood is not the best way
+ * to determine if the EM algorithm has converged.
+ *
+ * @param data Data matrix.
+ * @param means Vector of means.
+ * @param covariances Vector of covariance matrices.
+ * @param weights Vector of a priori weights.
+ */
+ double LogLikelihood(const arma::mat& data,
+ const std::vector<arma::vec>& means,
+ const std::vector<arma::mat>& covariances,
+ const arma::vec& weights) const;
+
+ //! Maximum iterations of EM algorithm.
+ size_t maxIterations;
+ //! Tolerance for convergence of EM.
+ double tolerance;
+ //! Whether or not to force positive definiteness of covariance matrices.
+ bool forcePositive;
+ //! Object which will perform the clustering.
+ InitialClusteringType clusterer;
+};
+
+}; // namespace gmm
+}; // namespace mlpack
+
+// Include implementation.
+#include "em_fit_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/em_fit_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,341 +0,0 @@
-/**
- * @file em_fit_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of EM algorithm for fitting GMMs.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_GMM_EM_FIT_IMPL_HPP
-#define __MLPACK_METHODS_GMM_EM_FIT_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "em_fit.hpp"
-
-// Definition of phi().
-#include "phi.hpp"
-
-namespace mlpack {
-namespace gmm {
-
-//! Constructor.
-template<typename InitialClusteringType>
-EMFit<InitialClusteringType>::EMFit(const size_t maxIterations,
- const double tolerance,
- const bool forcePositive,
- InitialClusteringType clusterer) :
- maxIterations(maxIterations),
- tolerance(tolerance),
- forcePositive(forcePositive),
- clusterer(clusterer)
-{ /* Nothing to do. */ }
-
-template<typename InitialClusteringType>
-void EMFit<InitialClusteringType>::Estimate(const arma::mat& observations,
- std::vector<arma::vec>& means,
- std::vector<arma::mat>& covariances,
- arma::vec& weights)
-{
- InitialClustering(observations, means, covariances, weights);
-
- double l = LogLikelihood(observations, means, covariances, weights);
-
- Log::Debug << "EMFit::Estimate(): initial clustering log-likelihood: "
- << l << std::endl;
-
- double lOld = -DBL_MAX;
- arma::mat condProb(observations.n_cols, means.size());
-
- // Iterate to update the model until no more improvement is found.
- size_t iteration = 1;
- while (std::abs(l - lOld) > tolerance && iteration != maxIterations)
- {
- Log::Info << "EMFit::Estimate(): iteration " << iteration << ", "
- << "log-likelihood " << l << "." << std::endl;
-
- // Calculate the conditional probabilities of choosing a particular
- // Gaussian given the observations and the present theta value.
- for (size_t i = 0; i < means.size(); i++)
- {
- // Store conditional probabilities into condProb vector for each
- // Gaussian. First we make an alias of the condProb vector.
- arma::vec condProbAlias = condProb.unsafe_col(i);
- phi(observations, means[i], covariances[i], condProbAlias);
- condProbAlias *= weights[i];
- }
-
- // Normalize row-wise.
- for (size_t i = 0; i < condProb.n_rows; i++)
- {
- // Avoid dividing by zero; if the probability for everything is 0, we
- // don't want to make it NaN.
- const double probSum = accu(condProb.row(i));
- if (probSum != 0.0)
- condProb.row(i) /= probSum;
- }
-
- // Store the sum of the probability of each state over all the observations.
- arma::vec probRowSums = trans(arma::sum(condProb, 0 /* columnwise */));
-
- // Calculate the new value of the means using the updated conditional
- // probabilities.
- for (size_t i = 0; i < means.size(); i++)
- {
- // Don't update if there's no probability of the Gaussian having points.
- if (probRowSums[i] != 0)
- means[i] = (observations * condProb.col(i)) / probRowSums[i];
-
- // Calculate the new value of the covariances using the updated
- // conditional probabilities and the updated means.
- arma::mat tmp = observations - (means[i] *
- arma::ones<arma::rowvec>(observations.n_cols));
- arma::mat tmpB = tmp % (arma::ones<arma::vec>(observations.n_rows) *
- trans(condProb.col(i)));
-
- // Don't update if there's no probability of the Gaussian having points.
- if (probRowSums[i] != 0.0)
- covariances[i] = (tmp * trans(tmpB)) / probRowSums[i];
-
- // Ensure positive-definiteness. TODO: make this more efficient.
- if (forcePositive && det(covariances[i]) <= 1e-50)
- {
- Log::Debug << "Covariance matrix " << i << " is not positive definite. "
- << "Adding perturbation." << std::endl;
-
- double perturbation = 1e-30;
- while (det(covariances[i]) <= 1e-50)
- {
- covariances[i].diag() += perturbation;
- perturbation *= 10; // Slow, but we don't want to add too much.
- }
- }
- }
-
- // Calculate the new values for omega using the updated conditional
- // probabilities.
- weights = probRowSums / observations.n_cols;
-
- // Update values of l; calculate new log-likelihood.
- lOld = l;
- l = LogLikelihood(observations, means, covariances, weights);
-
- iteration++;
- }
-}
-
-template<typename InitialClusteringType>
-void EMFit<InitialClusteringType>::Estimate(const arma::mat& observations,
- const arma::vec& probabilities,
- std::vector<arma::vec>& means,
- std::vector<arma::mat>& covariances,
- arma::vec& weights)
-{
- InitialClustering(observations, means, covariances, weights);
-
- double l = LogLikelihood(observations, means, covariances, weights);
-
- Log::Debug << "EMFit::Estimate(): initial clustering log-likelihood: "
- << l << std::endl;
-
- double lOld = -DBL_MAX;
- arma::mat condProb(observations.n_cols, means.size());
-
- // Iterate to update the model until no more improvement is found.
- size_t iteration = 1;
- while (std::abs(l - lOld) > tolerance && iteration != maxIterations)
- {
- // Calculate the conditional probabilities of choosing a particular
- // Gaussian given the observations and the present theta value.
- for (size_t i = 0; i < means.size(); i++)
- {
- // Store conditional probabilities into condProb vector for each
- // Gaussian. First we make an alias of the condProb vector.
- arma::vec condProbAlias = condProb.unsafe_col(i);
- phi(observations, means[i], covariances[i], condProbAlias);
- condProbAlias *= weights[i];
- }
-
- // Normalize row-wise.
- for (size_t i = 0; i < condProb.n_rows; i++)
- {
- // Avoid dividing by zero; if the probability for everything is 0, we
- // don't want to make it NaN.
- const double probSum = accu(condProb.row(i));
- if (probSum != 0.0)
- condProb.row(i) /= probSum;
- }
-
- // This will store the sum of probabilities of each state over all the
- // observations.
- arma::vec probRowSums(means.size());
-
- // Calculate the new value of the means using the updated conditional
- // probabilities.
- for (size_t i = 0; i < means.size(); i++)
- {
- // Calculate the sum of probabilities of points, which is the
- // conditional probability of each point being from Gaussian i
- // multiplied by the probability of the point being from this mixture
- // model.
- probRowSums[i] = accu(condProb.col(i) % probabilities);
-
- means[i] = (observations * (condProb.col(i) % probabilities)) /
- probRowSums[i];
-
- // Calculate the new value of the covariances using the updated
- // conditional probabilities and the updated means.
- arma::mat tmp = observations - (means[i] *
- arma::ones<arma::rowvec>(observations.n_cols));
- arma::mat tmpB = tmp % (arma::ones<arma::vec>(observations.n_rows) *
- trans(condProb.col(i) % probabilities));
-
- covariances[i] = (tmp * trans(tmpB)) / probRowSums[i];
-
- // Ensure positive-definiteness. TODO: make this more efficient.
- if (forcePositive && det(covariances[i]) <= 1e-50)
- {
- Log::Debug << "Covariance matrix " << i << " is not positive definite. "
- << "Adding perturbation." << std::endl;
-
- double perturbation = 1e-30;
- while (det(covariances[i]) <= 1e-50)
- {
- covariances[i].diag() += perturbation;
- perturbation *= 10; // Slow, but we don't want to add too much.
- }
- }
- }
-
- // Calculate the new values for omega using the updated conditional
- // probabilities.
- weights = probRowSums / accu(probabilities);
-
- // Update values of l; calculate new log-likelihood.
- lOld = l;
- l = LogLikelihood(observations, means, covariances, weights);
-
- iteration++;
- }
-}
-
-template<typename InitialClusteringType>
-void EMFit<InitialClusteringType>::InitialClustering(
- const arma::mat& observations,
- std::vector<arma::vec>& means,
- std::vector<arma::mat>& covariances,
- arma::vec& weights)
-{
- // Assignments from clustering.
- arma::Col<size_t> assignments;
-
- // Run clustering algorithm.
- clusterer.Cluster(observations, means.size(), assignments);
-
- // Now calculate the means, covariances, and weights.
- weights.zeros();
- for (size_t i = 0; i < means.size(); ++i)
- {
- means[i].zeros();
- covariances[i].zeros();
- }
-
- // From the assignments, generate our means, covariances, and weights.
- for (size_t i = 0; i < observations.n_cols; ++i)
- {
- const size_t cluster = assignments[i];
-
- // Add this to the relevant mean.
- means[cluster] += observations.col(i);
-
- // Add this to the relevant covariance.
-// covariances[cluster] += observations.col(i) * trans(observations.col(i));
-
- // Now add one to the weights (we will normalize).
- weights[cluster]++;
- }
-
- // Now normalize the mean and covariance.
- for (size_t i = 0; i < means.size(); ++i)
- {
-// covariances[i] -= means[i] * trans(means[i]);
-
- means[i] /= (weights[i] > 1) ? weights[i] : 1;
-// covariances[i] /= (weights[i] > 1) ? weights[i] : 1;
- }
-
- for (size_t i = 0; i < observations.n_cols; ++i)
- {
- const size_t cluster = assignments[i];
- const arma::vec normObs = observations.col(i) - means[cluster];
- covariances[cluster] += normObs * normObs.t();
- }
-
- for (size_t i = 0; i < means.size(); ++i)
- {
- covariances[i] /= (weights[i] > 1) ? weights[i] : 1;
-
- // Ensure positive-definiteness. TODO: make this more efficient.
- if (forcePositive && det(covariances[i]) <= 1e-50)
- {
- Log::Debug << "Covariance matrix " << i << " is not positive definite. "
- << "Adding perturbation." << std::endl;
-
- double perturbation = 1e-50;
- while (det(covariances[i]) <= 1e-50)
- {
- covariances[i].diag() += perturbation;
- perturbation *= 10; // Slow, but we don't want to add too much.
- }
- }
- }
-
- // Finally, normalize weights.
- weights /= accu(weights);
-}
-
-template<typename InitialClusteringType>
-double EMFit<InitialClusteringType>::LogLikelihood(
- const arma::mat& observations,
- const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covariances,
- const arma::vec& weights) const
-{
- double logLikelihood = 0;
-
- arma::vec phis;
- arma::mat likelihoods(means.size(), observations.n_cols);
- for (size_t i = 0; i < means.size(); ++i)
- {
- phi(observations, means[i], covariances[i], phis);
- likelihoods.row(i) = weights(i) * trans(phis);
- }
-
- // Now sum over every point.
- for (size_t j = 0; j < observations.n_cols; ++j)
- {
- if (accu(likelihoods.col(j)) == 0)
- Log::Info << "Likelihood of point " << j << " is 0! It is probably an "
- << "outlier." << std::endl;
- logLikelihood += log(accu(likelihoods.col(j)));
- }
-
- return logLikelihood;
-}
-
-}; // namespace gmm
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/em_fit_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/em_fit_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,341 @@
+/**
+ * @file em_fit_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of EM algorithm for fitting GMMs.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_GMM_EM_FIT_IMPL_HPP
+#define __MLPACK_METHODS_GMM_EM_FIT_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "em_fit.hpp"
+
+// Definition of phi().
+#include "phi.hpp"
+
+namespace mlpack {
+namespace gmm {
+
+//! Constructor.
+template<typename InitialClusteringType>
+EMFit<InitialClusteringType>::EMFit(const size_t maxIterations,
+ const double tolerance,
+ const bool forcePositive,
+ InitialClusteringType clusterer) :
+ maxIterations(maxIterations),
+ tolerance(tolerance),
+ forcePositive(forcePositive),
+ clusterer(clusterer)
+{ /* Nothing to do. */ }
+
+template<typename InitialClusteringType>
+void EMFit<InitialClusteringType>::Estimate(const arma::mat& observations,
+ std::vector<arma::vec>& means,
+ std::vector<arma::mat>& covariances,
+ arma::vec& weights)
+{
+ InitialClustering(observations, means, covariances, weights);
+
+ double l = LogLikelihood(observations, means, covariances, weights);
+
+ Log::Debug << "EMFit::Estimate(): initial clustering log-likelihood: "
+ << l << std::endl;
+
+ double lOld = -DBL_MAX;
+ arma::mat condProb(observations.n_cols, means.size());
+
+ // Iterate to update the model until no more improvement is found.
+ size_t iteration = 1;
+ while (std::abs(l - lOld) > tolerance && iteration != maxIterations)
+ {
+ Log::Info << "EMFit::Estimate(): iteration " << iteration << ", "
+ << "log-likelihood " << l << "." << std::endl;
+
+ // Calculate the conditional probabilities of choosing a particular
+ // Gaussian given the observations and the present theta value.
+ for (size_t i = 0; i < means.size(); i++)
+ {
+ // Store conditional probabilities into condProb vector for each
+ // Gaussian. First we make an alias of the condProb vector.
+ arma::vec condProbAlias = condProb.unsafe_col(i);
+ phi(observations, means[i], covariances[i], condProbAlias);
+ condProbAlias *= weights[i];
+ }
+
+ // Normalize row-wise.
+ for (size_t i = 0; i < condProb.n_rows; i++)
+ {
+ // Avoid dividing by zero; if the probability for everything is 0, we
+ // don't want to make it NaN.
+ const double probSum = accu(condProb.row(i));
+ if (probSum != 0.0)
+ condProb.row(i) /= probSum;
+ }
+
+ // Store the sum of the probability of each state over all the observations.
+ arma::vec probRowSums = trans(arma::sum(condProb, 0 /* columnwise */));
+
+ // Calculate the new value of the means using the updated conditional
+ // probabilities.
+ for (size_t i = 0; i < means.size(); i++)
+ {
+ // Don't update if there's no probability of the Gaussian having points.
+ if (probRowSums[i] != 0)
+ means[i] = (observations * condProb.col(i)) / probRowSums[i];
+
+ // Calculate the new value of the covariances using the updated
+ // conditional probabilities and the updated means.
+ arma::mat tmp = observations - (means[i] *
+ arma::ones<arma::rowvec>(observations.n_cols));
+ arma::mat tmpB = tmp % (arma::ones<arma::vec>(observations.n_rows) *
+ trans(condProb.col(i)));
+
+ // Don't update if there's no probability of the Gaussian having points.
+ if (probRowSums[i] != 0.0)
+ covariances[i] = (tmp * trans(tmpB)) / probRowSums[i];
+
+ // Ensure positive-definiteness. TODO: make this more efficient.
+ if (forcePositive && det(covariances[i]) <= 1e-50)
+ {
+ Log::Debug << "Covariance matrix " << i << " is not positive definite. "
+ << "Adding perturbation." << std::endl;
+
+ double perturbation = 1e-30;
+ while (det(covariances[i]) <= 1e-50)
+ {
+ covariances[i].diag() += perturbation;
+ perturbation *= 10; // Slow, but we don't want to add too much.
+ }
+ }
+ }
+
+ // Calculate the new values for omega using the updated conditional
+ // probabilities.
+ weights = probRowSums / observations.n_cols;
+
+ // Update values of l; calculate new log-likelihood.
+ lOld = l;
+ l = LogLikelihood(observations, means, covariances, weights);
+
+ iteration++;
+ }
+}
+
+template<typename InitialClusteringType>
+void EMFit<InitialClusteringType>::Estimate(const arma::mat& observations,
+ const arma::vec& probabilities,
+ std::vector<arma::vec>& means,
+ std::vector<arma::mat>& covariances,
+ arma::vec& weights)
+{
+ InitialClustering(observations, means, covariances, weights);
+
+ double l = LogLikelihood(observations, means, covariances, weights);
+
+ Log::Debug << "EMFit::Estimate(): initial clustering log-likelihood: "
+ << l << std::endl;
+
+ double lOld = -DBL_MAX;
+ arma::mat condProb(observations.n_cols, means.size());
+
+ // Iterate to update the model until no more improvement is found.
+ size_t iteration = 1;
+ while (std::abs(l - lOld) > tolerance && iteration != maxIterations)
+ {
+ // Calculate the conditional probabilities of choosing a particular
+ // Gaussian given the observations and the present theta value.
+ for (size_t i = 0; i < means.size(); i++)
+ {
+ // Store conditional probabilities into condProb vector for each
+ // Gaussian. First we make an alias of the condProb vector.
+ arma::vec condProbAlias = condProb.unsafe_col(i);
+ phi(observations, means[i], covariances[i], condProbAlias);
+ condProbAlias *= weights[i];
+ }
+
+ // Normalize row-wise.
+ for (size_t i = 0; i < condProb.n_rows; i++)
+ {
+ // Avoid dividing by zero; if the probability for everything is 0, we
+ // don't want to make it NaN.
+ const double probSum = accu(condProb.row(i));
+ if (probSum != 0.0)
+ condProb.row(i) /= probSum;
+ }
+
+ // This will store the sum of probabilities of each state over all the
+ // observations.
+ arma::vec probRowSums(means.size());
+
+ // Calculate the new value of the means using the updated conditional
+ // probabilities.
+ for (size_t i = 0; i < means.size(); i++)
+ {
+ // Calculate the sum of probabilities of points, which is the
+ // conditional probability of each point being from Gaussian i
+ // multiplied by the probability of the point being from this mixture
+ // model.
+ probRowSums[i] = accu(condProb.col(i) % probabilities);
+
+ means[i] = (observations * (condProb.col(i) % probabilities)) /
+ probRowSums[i];
+
+ // Calculate the new value of the covariances using the updated
+ // conditional probabilities and the updated means.
+ arma::mat tmp = observations - (means[i] *
+ arma::ones<arma::rowvec>(observations.n_cols));
+ arma::mat tmpB = tmp % (arma::ones<arma::vec>(observations.n_rows) *
+ trans(condProb.col(i) % probabilities));
+
+ covariances[i] = (tmp * trans(tmpB)) / probRowSums[i];
+
+ // Ensure positive-definiteness. TODO: make this more efficient.
+ if (forcePositive && det(covariances[i]) <= 1e-50)
+ {
+ Log::Debug << "Covariance matrix " << i << " is not positive definite. "
+ << "Adding perturbation." << std::endl;
+
+ double perturbation = 1e-30;
+ while (det(covariances[i]) <= 1e-50)
+ {
+ covariances[i].diag() += perturbation;
+ perturbation *= 10; // Slow, but we don't want to add too much.
+ }
+ }
+ }
+
+ // Calculate the new values for omega using the updated conditional
+ // probabilities.
+ weights = probRowSums / accu(probabilities);
+
+ // Update values of l; calculate new log-likelihood.
+ lOld = l;
+ l = LogLikelihood(observations, means, covariances, weights);
+
+ iteration++;
+ }
+}
+
+template<typename InitialClusteringType>
+void EMFit<InitialClusteringType>::InitialClustering(
+ const arma::mat& observations,
+ std::vector<arma::vec>& means,
+ std::vector<arma::mat>& covariances,
+ arma::vec& weights)
+{
+ // Assignments from clustering.
+ arma::Col<size_t> assignments;
+
+ // Run clustering algorithm.
+ clusterer.Cluster(observations, means.size(), assignments);
+
+ // Now calculate the means, covariances, and weights.
+ weights.zeros();
+ for (size_t i = 0; i < means.size(); ++i)
+ {
+ means[i].zeros();
+ covariances[i].zeros();
+ }
+
+ // From the assignments, generate our means, covariances, and weights.
+ for (size_t i = 0; i < observations.n_cols; ++i)
+ {
+ const size_t cluster = assignments[i];
+
+ // Add this to the relevant mean.
+ means[cluster] += observations.col(i);
+
+ // Add this to the relevant covariance.
+// covariances[cluster] += observations.col(i) * trans(observations.col(i));
+
+ // Now add one to the weights (we will normalize).
+ weights[cluster]++;
+ }
+
+ // Now normalize the mean and covariance.
+ for (size_t i = 0; i < means.size(); ++i)
+ {
+// covariances[i] -= means[i] * trans(means[i]);
+
+ means[i] /= (weights[i] > 1) ? weights[i] : 1;
+// covariances[i] /= (weights[i] > 1) ? weights[i] : 1;
+ }
+
+ for (size_t i = 0; i < observations.n_cols; ++i)
+ {
+ const size_t cluster = assignments[i];
+ const arma::vec normObs = observations.col(i) - means[cluster];
+ covariances[cluster] += normObs * normObs.t();
+ }
+
+ for (size_t i = 0; i < means.size(); ++i)
+ {
+ covariances[i] /= (weights[i] > 1) ? weights[i] : 1;
+
+ // Ensure positive-definiteness. TODO: make this more efficient.
+ if (forcePositive && det(covariances[i]) <= 1e-50)
+ {
+ Log::Debug << "Covariance matrix " << i << " is not positive definite. "
+ << "Adding perturbation." << std::endl;
+
+ double perturbation = 1e-50;
+ while (det(covariances[i]) <= 1e-50)
+ {
+ covariances[i].diag() += perturbation;
+ perturbation *= 10; // Slow, but we don't want to add too much.
+ }
+ }
+ }
+
+ // Finally, normalize weights.
+ weights /= accu(weights);
+}
+
+template<typename InitialClusteringType>
+double EMFit<InitialClusteringType>::LogLikelihood(
+ const arma::mat& observations,
+ const std::vector<arma::vec>& means,
+ const std::vector<arma::mat>& covariances,
+ const arma::vec& weights) const
+{
+ double logLikelihood = 0;
+
+ arma::vec phis;
+ arma::mat likelihoods(means.size(), observations.n_cols);
+ for (size_t i = 0; i < means.size(); ++i)
+ {
+ phi(observations, means[i], covariances[i], phis);
+ likelihoods.row(i) = weights(i) * trans(phis);
+ }
+
+ // Now sum over every point.
+ for (size_t j = 0; j < observations.n_cols; ++j)
+ {
+ if (accu(likelihoods.col(j)) == 0)
+ Log::Info << "Likelihood of point " << j << " is 0! It is probably an "
+ << "outlier." << std::endl;
+ logLikelihood += log(accu(likelihoods.col(j)));
+ }
+
+ return logLikelihood;
+}
+
+}; // namespace gmm
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/gmm.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,377 +0,0 @@
-/**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file gmm.hpp
- *
- * Defines a Gaussian Mixture model and
- * estimates the parameters of the model
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_MOG_MOG_EM_HPP
-#define __MLPACK_METHODS_MOG_MOG_EM_HPP
-
-#include <mlpack/core.hpp>
-
-// This is the default fitting method class.
-#include "em_fit.hpp"
-
-namespace mlpack {
-namespace gmm /** Gaussian Mixture Models. */ {
-
-/**
- * A Gaussian Mixture Model (GMM). This class uses maximum likelihood loss
- * functions to estimate the parameters of the GMM on a given dataset via the
- * given fitting mechanism, defined by the FittingType template parameter. The
- * GMM can be trained using normal data, or data with probabilities of being
- * from this GMM (see GMM::Estimate() for more information).
- *
- * The FittingType template class must provide a way for the GMM to train on
- * data. It must provide the following two functions:
- *
- * @code
- * void Estimate(const arma::mat& observations,
- * std::vector<arma::vec>& means,
- * std::vector<arma::mat>& covariances,
- * arma::vec& weights);
- *
- * void Estimate(const arma::mat& observations,
- * const arma::vec& probabilities,
- * std::vector<arma::vec>& means,
- * std::vector<arma::mat>& covariances,
- * arma::vec& weights);
- * @endcode
- *
- * These functions should produce a trained GMM from the given observations and
- * probabilities. These may modify the size of the model (by increasing the
- * size of the mean and covariance vectors as well as the weight vectors), but
- * the method should expect that these vectors are already set to the size of
- * the GMM as specified in the constructor.
- *
- * For a sample implementation, see the EMFit class; this class uses the EM
- * algorithm to train a GMM, and is the default fitting type.
- *
- * The GMM, once trained, can be used to generate random points from the
- * distribution and estimate the probability of points being from the
- * distribution. The parameters of the GMM can be obtained through the
- * accessors and mutators.
- *
- * Example use:
- *
- * @code
- * // Set up a mixture of 5 gaussians in a 4-dimensional space (uses the default
- * // EM fitting mechanism).
- * GMM<> g(5, 4);
- *
- * // Train the GMM given the data observations.
- * g.Estimate(data);
- *
- * // Get the probability of 'observation' being observed from this GMM.
- * double probability = g.Probability(observation);
- *
- * // Get a random observation from the GMM.
- * arma::vec observation = g.Random();
- * @endcode
- */
-template<typename FittingType = EMFit<> >
-class GMM
-{
- private:
- //! The number of Gaussians in the model.
- size_t gaussians;
- //! The dimensionality of the model.
- size_t dimensionality;
- //! Vector of means; one for each Gaussian.
- std::vector<arma::vec> means;
- //! Vector of covariances; one for each Gaussian.
- std::vector<arma::mat> covariances;
- //! Vector of a priori weights for each Gaussian.
- arma::vec weights;
-
- public:
- /**
- * Create an empty Gaussian Mixture Model, with zero gaussians.
- */
- GMM() :
- gaussians(0),
- dimensionality(0),
- localFitter(FittingType()),
- fitter(localFitter)
- {
- // Warn the user. They probably don't want to do this. If this constructor
- // is being used (because it is required by some template classes), the user
- // should know that it is potentially dangerous.
- Log::Debug << "GMM::GMM(): no parameters given; Estimate() may fail "
- << "unless parameters are set." << std::endl;
- }
-
- /**
- * Create a GMM with the given number of Gaussians, each of which have the
- * specified dimensionality.
- *
- * @param gaussians Number of Gaussians in this GMM.
- * @param dimensionality Dimensionality of each Gaussian.
- */
- GMM(const size_t gaussians, const size_t dimensionality) :
- gaussians(gaussians),
- dimensionality(dimensionality),
- means(gaussians, arma::vec(dimensionality)),
- covariances(gaussians, arma::mat(dimensionality, dimensionality)),
- weights(gaussians),
- localFitter(FittingType()),
- fitter(localFitter) { /* Nothing to do. */ }
-
- /**
- * Create a GMM with the given number of Gaussians, each of which have the
- * specified dimensionality. Also, pass in an initialized FittingType class;
- * this is useful in cases where the FittingType class needs to store some
- * state.
- *
- * @param gaussians Number of Gaussians in this GMM.
- * @param dimensionality Dimensionality of each Gaussian.
- * @param fitter Initialized fitting mechanism.
- */
- GMM(const size_t gaussians,
- const size_t dimensionality,
- FittingType& fitter) :
- gaussians(gaussians),
- dimensionality(dimensionality),
- means(gaussians, arma::vec(dimensionality)),
- covariances(gaussians, arma::mat(dimensionality, dimensionality)),
- weights(gaussians),
- fitter(fitter) { /* Nothing to do. */ }
-
- /**
- * Create a GMM with the given means, covariances, and weights.
- *
- * @param means Means of the model.
- * @param covariances Covariances of the model.
- * @param weights Weights of the model.
- */
- GMM(const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covariances,
- const arma::vec& weights) :
- gaussians(means.size()),
- dimensionality((!means.empty()) ? means[0].n_elem : 0),
- means(means),
- covariances(covariances),
- weights(weights),
- localFitter(FittingType()),
- fitter(localFitter) { /* Nothing to do. */ }
-
- /**
- * Create a GMM with the given means, covariances, and weights, and use the
- * given initialized FittingType class. This is useful in cases where the
- * FittingType class needs to store some state.
- *
- * @param means Means of the model.
- * @param covariances Covariances of the model.
- * @param weights Weights of the model.
- */
- GMM(const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covariances,
- const arma::vec& weights,
- FittingType& fitter) :
- gaussians(means.size()),
- dimensionality((!means.empty()) ? means[0].n_elem : 0),
- means(means),
- covariances(covariances),
- weights(weights),
- fitter(fitter) { /* Nothing to do. */ }
-
- /**
- * Copy constructor for GMMs which use different fitting types.
- */
- template<typename OtherFittingType>
- GMM(const GMM<OtherFittingType>& other);
-
- /**
- * Copy constructor for GMMs using the same fitting type. This also copies
- * the fitter.
- */
- GMM(const GMM& other);
-
- /**
- * Copy operator for GMMs which use different fitting types.
- */
- template<typename OtherFittingType>
- GMM& operator=(const GMM<OtherFittingType>& other);
-
- /**
- * Copy operator for GMMs which use the same fitting type. This also copies
- * the fitter.
- */
- GMM& operator=(const GMM& other);
-
- /**
- * Load a GMM from an XML file. The format of the XML file should be the same
- * as is generated by the Save() method.
- *
- * @param filename Name of XML file containing model to be loaded.
- */
- void Load(const std::string& filename);
-
- /**
- * Save a GMM to an XML file.
- *
- * @param filename Name of XML file to write to.
- */
- void Save(const std::string& filename) const;
-
- //! Return the number of gaussians in the model.
- size_t Gaussians() const { return gaussians; }
- //! Modify the number of gaussians in the model. Careful! You will have to
- //! resize the means, covariances, and weights yourself.
- size_t& Gaussians() { return gaussians; }
-
- //! Return the dimensionality of the model.
- size_t Dimensionality() const { return dimensionality; }
- //! Modify the dimensionality of the model. Careful! You will have to update
- //! each mean and covariance matrix yourself.
- size_t& Dimensionality() { return dimensionality; }
-
- //! Return a const reference to the vector of means (mu).
- const std::vector<arma::vec>& Means() const { return means; }
- //! Return a reference to the vector of means (mu).
- std::vector<arma::vec>& Means() { return means; }
-
- //! Return a const reference to the vector of covariance matrices (sigma).
- const std::vector<arma::mat>& Covariances() const { return covariances; }
- //! Return a reference to the vector of covariance matrices (sigma).
- std::vector<arma::mat>& Covariances() { return covariances; }
-
- //! Return a const reference to the a priori weights of each Gaussian.
- const arma::vec& Weights() const { return weights; }
- //! Return a reference to the a priori weights of each Gaussian.
- arma::vec& Weights() { return weights; }
-
- //! Return a const reference to the fitting type.
- const FittingType& Fitter() const { return fitter; }
- //! Return a reference to the fitting type.
- FittingType& Fitter() { return fitter; }
-
- /**
- * Return the probability that the given observation came from this
- * distribution.
- *
- * @param observation Observation to evaluate the probability of.
- */
- double Probability(const arma::vec& observation) const;
-
- /**
- * Return the probability that the given observation came from the given
- * Gaussian component in this distribution.
- *
- * @param observation Observation to evaluate the probability of.
- * @param component Index of the component of the GMM to be considered.
- */
- double Probability(const arma::vec& observation,
- const size_t component) const;
-
- /**
- * Return a randomly generated observation according to the probability
- * distribution defined by this object.
- *
- * @return Random observation from this GMM.
- */
- arma::vec Random() const;
-
- /**
- * Estimate the probability distribution directly from the given observations,
- * using the given algorithm in the FittingType class to fit the data.
- *
- * The fitting will be performed 'trials' times; from these trials, the model
- * with the greatest log-likelihood will be selected. By default, only one
- * trial is performed. The log-likelihood of the best fitting is returned.
- *
- * @tparam FittingType The type of fitting method which should be used
- * (EMFit<> is suggested).
- * @param observations Observations of the model.
- * @param trials Number of trials to perform; the model in these trials with
- * the greatest log-likelihood will be selected.
- * @return The log-likelihood of the best fit.
- */
- double Estimate(const arma::mat& observations,
- const size_t trials = 1);
-
- /**
- * Estimate the probability distribution directly from the given observations,
- * taking into account the probability of each observation actually being from
- * this distribution, and using the given algorithm in the FittingType class
- * to fit the data.
- *
- * The fitting will be performed 'trials' times; from these trials, the model
- * with the greatest log-likelihood will be selected. By default, only one
- * trial is performed. The log-likelihood of the best fitting is returned.
- *
- * @param observations Observations of the model.
- * @param probabilities Probability of each observation being from this
- * distribution.
- * @param trials Number of trials to perform; the model in these trials with
- * the greatest log-likelihood will be selected.
- * @return The log-likelihood of the best fit.
- */
- double Estimate(const arma::mat& observations,
- const arma::vec& probabilities,
- const size_t trials = 1);
-
- /**
- * Classify the given observations as being from an individual component in
- * this GMM. The resultant classifications are stored in the 'labels' object,
- * and each label will be between 0 and (Gaussians() - 1). Supposing that a
- * point was classified with label 2, and that our GMM object was called
- * 'gmm', one could access the relevant Gaussian distribution as follows:
- *
- * @code
- * arma::vec mean = gmm.Means()[2];
- * arma::mat covariance = gmm.Covariances()[2];
- * double priorWeight = gmm.Weights()[2];
- * @endcode
- *
- * @param observations List of observations to classify.
- * @param labels Object which will be filled with labels.
- */
- void Classify(const arma::mat& observations,
- arma::Col<size_t>& labels) const;
-
- private:
- /**
- * This function computes the loglikelihood of the given model. This function
- * is used by GMM::Estimate().
- *
- * @param dataPoints Observations to calculate the likelihood for.
- * @param means Means of the given mixture model.
- * @param covars Covariances of the given mixture model.
- * @param weights Weights of the given mixture model.
- */
- double LogLikelihood(const arma::mat& dataPoints,
- const std::vector<arma::vec>& means,
- const std::vector<arma::mat>& covars,
- const arma::vec& weights) const;
-
- //! Locally-stored fitting object; in case the user did not pass one.
- FittingType localFitter;
-
- //! Reference to the fitting object we should use.
- FittingType& fitter;
-};
-
-}; // namespace gmm
-}; // namespace mlpack
-
-// Include implementation.
-#include "gmm_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/gmm.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,377 @@
+/**
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @file gmm.hpp
+ *
+ * Defines a Gaussian Mixture model and
+ * estimates the parameters of the model
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_MOG_MOG_EM_HPP
+#define __MLPACK_METHODS_MOG_MOG_EM_HPP
+
+#include <mlpack/core.hpp>
+
+// This is the default fitting method class.
+#include "em_fit.hpp"
+
+namespace mlpack {
+namespace gmm /** Gaussian Mixture Models. */ {
+
+/**
+ * A Gaussian Mixture Model (GMM). This class uses maximum likelihood loss
+ * functions to estimate the parameters of the GMM on a given dataset via the
+ * given fitting mechanism, defined by the FittingType template parameter. The
+ * GMM can be trained using normal data, or data with probabilities of being
+ * from this GMM (see GMM::Estimate() for more information).
+ *
+ * The FittingType template class must provide a way for the GMM to train on
+ * data. It must provide the following two functions:
+ *
+ * @code
+ * void Estimate(const arma::mat& observations,
+ * std::vector<arma::vec>& means,
+ * std::vector<arma::mat>& covariances,
+ * arma::vec& weights);
+ *
+ * void Estimate(const arma::mat& observations,
+ * const arma::vec& probabilities,
+ * std::vector<arma::vec>& means,
+ * std::vector<arma::mat>& covariances,
+ * arma::vec& weights);
+ * @endcode
+ *
+ * These functions should produce a trained GMM from the given observations and
+ * probabilities. These may modify the size of the model (by increasing the
+ * size of the mean and covariance vectors as well as the weight vectors), but
+ * the method should expect that these vectors are already set to the size of
+ * the GMM as specified in the constructor.
+ *
+ * For a sample implementation, see the EMFit class; this class uses the EM
+ * algorithm to train a GMM, and is the default fitting type.
+ *
+ * The GMM, once trained, can be used to generate random points from the
+ * distribution and estimate the probability of points being from the
+ * distribution. The parameters of the GMM can be obtained through the
+ * accessors and mutators.
+ *
+ * Example use:
+ *
+ * @code
+ * // Set up a mixture of 5 gaussians in a 4-dimensional space (uses the default
+ * // EM fitting mechanism).
+ * GMM<> g(5, 4);
+ *
+ * // Train the GMM given the data observations.
+ * g.Estimate(data);
+ *
+ * // Get the probability of 'observation' being observed from this GMM.
+ * double probability = g.Probability(observation);
+ *
+ * // Get a random observation from the GMM.
+ * arma::vec observation = g.Random();
+ * @endcode
+ */
+template<typename FittingType = EMFit<> >
+class GMM
+{
+ private:
+ //! The number of Gaussians in the model.
+ size_t gaussians;
+ //! The dimensionality of the model.
+ size_t dimensionality;
+ //! Vector of means; one for each Gaussian.
+ std::vector<arma::vec> means;
+ //! Vector of covariances; one for each Gaussian.
+ std::vector<arma::mat> covariances;
+ //! Vector of a priori weights for each Gaussian.
+ arma::vec weights;
+
+ public:
+ /**
+ * Create an empty Gaussian Mixture Model, with zero gaussians.
+ */
+ GMM() :
+ gaussians(0),
+ dimensionality(0),
+ localFitter(FittingType()),
+ fitter(localFitter)
+ {
+ // Warn the user. They probably don't want to do this. If this constructor
+ // is being used (because it is required by some template classes), the user
+ // should know that it is potentially dangerous.
+ Log::Debug << "GMM::GMM(): no parameters given; Estimate() may fail "
+ << "unless parameters are set." << std::endl;
+ }
+
+ /**
+ * Create a GMM with the given number of Gaussians, each of which have the
+ * specified dimensionality.
+ *
+ * @param gaussians Number of Gaussians in this GMM.
+ * @param dimensionality Dimensionality of each Gaussian.
+ */
+ GMM(const size_t gaussians, const size_t dimensionality) :
+ gaussians(gaussians),
+ dimensionality(dimensionality),
+ means(gaussians, arma::vec(dimensionality)),
+ covariances(gaussians, arma::mat(dimensionality, dimensionality)),
+ weights(gaussians),
+ localFitter(FittingType()),
+ fitter(localFitter) { /* Nothing to do. */ }
+
+ /**
+ * Create a GMM with the given number of Gaussians, each of which have the
+ * specified dimensionality. Also, pass in an initialized FittingType class;
+ * this is useful in cases where the FittingType class needs to store some
+ * state.
+ *
+ * @param gaussians Number of Gaussians in this GMM.
+ * @param dimensionality Dimensionality of each Gaussian.
+ * @param fitter Initialized fitting mechanism.
+ */
+ GMM(const size_t gaussians,
+ const size_t dimensionality,
+ FittingType& fitter) :
+ gaussians(gaussians),
+ dimensionality(dimensionality),
+ means(gaussians, arma::vec(dimensionality)),
+ covariances(gaussians, arma::mat(dimensionality, dimensionality)),
+ weights(gaussians),
+ fitter(fitter) { /* Nothing to do. */ }
+
+ /**
+ * Create a GMM with the given means, covariances, and weights.
+ *
+ * @param means Means of the model.
+ * @param covariances Covariances of the model.
+ * @param weights Weights of the model.
+ */
+ GMM(const std::vector<arma::vec>& means,
+ const std::vector<arma::mat>& covariances,
+ const arma::vec& weights) :
+ gaussians(means.size()),
+ dimensionality((!means.empty()) ? means[0].n_elem : 0),
+ means(means),
+ covariances(covariances),
+ weights(weights),
+ localFitter(FittingType()),
+ fitter(localFitter) { /* Nothing to do. */ }
+
+ /**
+ * Create a GMM with the given means, covariances, and weights, and use the
+ * given initialized FittingType class. This is useful in cases where the
+ * FittingType class needs to store some state.
+ *
+ * @param means Means of the model.
+ * @param covariances Covariances of the model.
+ * @param weights Weights of the model.
+ */
+ GMM(const std::vector<arma::vec>& means,
+ const std::vector<arma::mat>& covariances,
+ const arma::vec& weights,
+ FittingType& fitter) :
+ gaussians(means.size()),
+ dimensionality((!means.empty()) ? means[0].n_elem : 0),
+ means(means),
+ covariances(covariances),
+ weights(weights),
+ fitter(fitter) { /* Nothing to do. */ }
+
+ /**
+ * Copy constructor for GMMs which use different fitting types.
+ */
+ template<typename OtherFittingType>
+ GMM(const GMM<OtherFittingType>& other);
+
+ /**
+ * Copy constructor for GMMs using the same fitting type. This also copies
+ * the fitter.
+ */
+ GMM(const GMM& other);
+
+ /**
+ * Copy operator for GMMs which use different fitting types.
+ */
+ template<typename OtherFittingType>
+ GMM& operator=(const GMM<OtherFittingType>& other);
+
+ /**
+ * Copy operator for GMMs which use the same fitting type. This also copies
+ * the fitter.
+ */
+ GMM& operator=(const GMM& other);
+
+ /**
+ * Load a GMM from an XML file. The format of the XML file should be the same
+ * as is generated by the Save() method.
+ *
+ * @param filename Name of XML file containing model to be loaded.
+ */
+ void Load(const std::string& filename);
+
+ /**
+ * Save a GMM to an XML file.
+ *
+ * @param filename Name of XML file to write to.
+ */
+ void Save(const std::string& filename) const;
+
+ //! Return the number of gaussians in the model.
+ size_t Gaussians() const { return gaussians; }
+ //! Modify the number of gaussians in the model. Careful! You will have to
+ //! resize the means, covariances, and weights yourself.
+ size_t& Gaussians() { return gaussians; }
+
+ //! Return the dimensionality of the model.
+ size_t Dimensionality() const { return dimensionality; }
+ //! Modify the dimensionality of the model. Careful! You will have to update
+ //! each mean and covariance matrix yourself.
+ size_t& Dimensionality() { return dimensionality; }
+
+ //! Return a const reference to the vector of means (mu).
+ const std::vector<arma::vec>& Means() const { return means; }
+ //! Return a reference to the vector of means (mu).
+ std::vector<arma::vec>& Means() { return means; }
+
+ //! Return a const reference to the vector of covariance matrices (sigma).
+ const std::vector<arma::mat>& Covariances() const { return covariances; }
+ //! Return a reference to the vector of covariance matrices (sigma).
+ std::vector<arma::mat>& Covariances() { return covariances; }
+
+ //! Return a const reference to the a priori weights of each Gaussian.
+ const arma::vec& Weights() const { return weights; }
+ //! Return a reference to the a priori weights of each Gaussian.
+ arma::vec& Weights() { return weights; }
+
+ //! Return a const reference to the fitting type.
+ const FittingType& Fitter() const { return fitter; }
+ //! Return a reference to the fitting type.
+ FittingType& Fitter() { return fitter; }
+
+ /**
+ * Return the probability that the given observation came from this
+ * distribution.
+ *
+ * @param observation Observation to evaluate the probability of.
+ */
+ double Probability(const arma::vec& observation) const;
+
+ /**
+ * Return the probability that the given observation came from the given
+ * Gaussian component in this distribution.
+ *
+ * @param observation Observation to evaluate the probability of.
+ * @param component Index of the component of the GMM to be considered.
+ */
+ double Probability(const arma::vec& observation,
+ const size_t component) const;
+
+ /**
+ * Return a randomly generated observation according to the probability
+ * distribution defined by this object.
+ *
+ * @return Random observation from this GMM.
+ */
+ arma::vec Random() const;
+
+ /**
+ * Estimate the probability distribution directly from the given observations,
+ * using the given algorithm in the FittingType class to fit the data.
+ *
+ * The fitting will be performed 'trials' times; from these trials, the model
+ * with the greatest log-likelihood will be selected. By default, only one
+ * trial is performed. The log-likelihood of the best fitting is returned.
+ *
+ * @tparam FittingType The type of fitting method which should be used
+ * (EMFit<> is suggested).
+ * @param observations Observations of the model.
+ * @param trials Number of trials to perform; the model in these trials with
+ * the greatest log-likelihood will be selected.
+ * @return The log-likelihood of the best fit.
+ */
+ double Estimate(const arma::mat& observations,
+ const size_t trials = 1);
+
+ /**
+ * Estimate the probability distribution directly from the given observations,
+ * taking into account the probability of each observation actually being from
+ * this distribution, and using the given algorithm in the FittingType class
+ * to fit the data.
+ *
+ * The fitting will be performed 'trials' times; from these trials, the model
+ * with the greatest log-likelihood will be selected. By default, only one
+ * trial is performed. The log-likelihood of the best fitting is returned.
+ *
+ * @param observations Observations of the model.
+ * @param probabilities Probability of each observation being from this
+ * distribution.
+ * @param trials Number of trials to perform; the model in these trials with
+ * the greatest log-likelihood will be selected.
+ * @return The log-likelihood of the best fit.
+ */
+ double Estimate(const arma::mat& observations,
+ const arma::vec& probabilities,
+ const size_t trials = 1);
+
+ /**
+ * Classify the given observations as being from an individual component in
+ * this GMM. The resultant classifications are stored in the 'labels' object,
+ * and each label will be between 0 and (Gaussians() - 1). Supposing that a
+ * point was classified with label 2, and that our GMM object was called
+ * 'gmm', one could access the relevant Gaussian distribution as follows:
+ *
+ * @code
+ * arma::vec mean = gmm.Means()[2];
+ * arma::mat covariance = gmm.Covariances()[2];
+ * double priorWeight = gmm.Weights()[2];
+ * @endcode
+ *
+ * @param observations List of observations to classify.
+ * @param labels Object which will be filled with labels.
+ */
+ void Classify(const arma::mat& observations,
+ arma::Col<size_t>& labels) const;
+
+ private:
+ /**
+ * This function computes the loglikelihood of the given model. This function
+ * is used by GMM::Estimate().
+ *
+ * @param dataPoints Observations to calculate the likelihood for.
+ * @param means Means of the given mixture model.
+ * @param covars Covariances of the given mixture model.
+ * @param weights Weights of the given mixture model.
+ */
+ double LogLikelihood(const arma::mat& dataPoints,
+ const std::vector<arma::vec>& means,
+ const std::vector<arma::mat>& covars,
+ const arma::vec& weights) const;
+
+ //! Locally-stored fitting object; in case the user did not pass one.
+ FittingType localFitter;
+
+ //! Reference to the fitting object we should use.
+ FittingType& fitter;
+};
+
+}; // namespace gmm
+}; // namespace mlpack
+
+// Include implementation.
+#include "gmm_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/gmm_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,395 +0,0 @@
-/**
- * @file gmm_impl.hpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @author Ryan Curtin
- *
- * Implementation of template-based GMM methods.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_GMM_GMM_IMPL_HPP
-#define __MLPACK_METHODS_GMM_GMM_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "gmm.hpp"
-
-#include <mlpack/core/util/save_restore_utility.hpp>
-
-namespace mlpack {
-namespace gmm {
-
-// Copy constructor.
-template<typename FittingType>
-template<typename OtherFittingType>
-GMM<FittingType>::GMM(const GMM<OtherFittingType>& other) :
- gaussians(other.Gaussians()),
- dimensionality(other.Dimensionality()),
- means(other.Means()),
- covariances(other.Covariances()),
- weights(other.Weights()),
- localFitter(FittingType()),
- fitter(localFitter) { /* Nothing to do. */ }
-
-// Copy constructor for when the other GMM uses the same fitting type.
-template<typename FittingType>
-GMM<FittingType>::GMM(const GMM<FittingType>& other) :
- gaussians(other.Gaussians()),
- dimensionality(other.Dimensionality()),
- means(other.Means()),
- covariances(other.Covariances()),
- weights(other.Weights()),
- localFitter(other.Fitter()),
- fitter(localFitter) { /* Nothing to do. */ }
-
-template<typename FittingType>
-template<typename OtherFittingType>
-GMM<FittingType>& GMM<FittingType>::operator=(
- const GMM<OtherFittingType>& other)
-{
- gaussians = other.Gaussians();
- dimensionality = other.Dimensionality();
- means = other.Means();
- covariances = other.Covariances();
- weights = other.Weights();
-
- return *this;
-}
-
-template<typename FittingType>
-GMM<FittingType>& GMM<FittingType>::operator=(const GMM<FittingType>& other)
-{
- gaussians = other.Gaussians();
- dimensionality = other.Dimensionality();
- means = other.Means();
- covariances = other.Covariances();
- weights = other.Weights();
- localFitter = other.Fitter();
-
- return *this;
-}
-
-// Load a GMM from file.
-template<typename FittingType>
-void GMM<FittingType>::Load(const std::string& filename)
-{
- util::SaveRestoreUtility load;
-
- if (!load.ReadFile(filename))
- Log::Fatal << "GMM::Load(): could not read file '" << filename << "'!\n";
-
- load.LoadParameter(gaussians, "gaussians");
- load.LoadParameter(dimensionality, "dimensionality");
- load.LoadParameter(weights, "weights");
-
- // We need to do a little error checking here.
- if (weights.n_elem != gaussians)
- {
- Log::Fatal << "GMM::Load('" << filename << "'): file reports " << gaussians
- << " gaussians but weights vector only contains " << weights.n_elem
- << " elements!" << std::endl;
- }
-
- means.resize(gaussians);
- covariances.resize(gaussians);
-
- for (size_t i = 0; i < gaussians; ++i)
- {
- std::stringstream o;
- o << i;
- std::string meanName = "mean" + o.str();
- std::string covName = "covariance" + o.str();
-
- load.LoadParameter(means[i], meanName);
- load.LoadParameter(covariances[i], covName);
- }
-}
-
-// Save a GMM to a file.
-template<typename FittingType>
-void GMM<FittingType>::Save(const std::string& filename) const
-{
- util::SaveRestoreUtility save;
- save.SaveParameter(gaussians, "gaussians");
- save.SaveParameter(dimensionality, "dimensionality");
- save.SaveParameter(weights, "weights");
- for (size_t i = 0; i < gaussians; ++i)
- {
- // Generate names for the XML nodes.
- std::stringstream o;
- o << i;
- std::string meanName = "mean" + o.str();
- std::string covName = "covariance" + o.str();
-
- // Now save them.
- save.SaveParameter(means[i], meanName);
- save.SaveParameter(covariances[i], covName);
- }
-
- if (!save.WriteFile(filename))
- Log::Warn << "GMM::Save(): error saving to '" << filename << "'.\n";
-}
-
-/**
- * Return the probability of the given observation being from this GMM.
- */
-template<typename FittingType>
-double GMM<FittingType>::Probability(const arma::vec& observation) const
-{
- // Sum the probability for each Gaussian in our mixture (and we have to
- // multiply by the prior for each Gaussian too).
- double sum = 0;
- for (size_t i = 0; i < gaussians; i++)
- sum += weights[i] * phi(observation, means[i], covariances[i]);
-
- return sum;
-}
-
-/**
- * Return the probability of the given observation being from the given
- * component in the mixture.
- */
-template<typename FittingType>
-double GMM<FittingType>::Probability(const arma::vec& observation,
- const size_t component) const
-{
- // We are only considering one Gaussian component -- so we only need to call
- // phi() once. We do consider the prior probability!
- return weights[component] *
- phi(observation, means[component], covariances[component]);
-}
-
-/**
- * Return a randomly generated observation according to the probability
- * distribution defined by this object.
- */
-template<typename FittingType>
-arma::vec GMM<FittingType>::Random() const
-{
- // Determine which Gaussian it will be coming from.
- double gaussRand = math::Random();
- size_t gaussian;
-
- double sumProb = 0;
- for (size_t g = 0; g < gaussians; g++)
- {
- sumProb += weights(g);
- if (gaussRand <= sumProb)
- {
- gaussian = g;
- break;
- }
- }
-
- return trans(chol(covariances[gaussian])) *
- arma::randn<arma::vec>(dimensionality) + means[gaussian];
-}
-
-/**
- * Fit the GMM to the given observations.
- */
-template<typename FittingType>
-double GMM<FittingType>::Estimate(const arma::mat& observations,
- const size_t trials)
-{
- double bestLikelihood; // This will be reported later.
-
- // We don't need to store temporary models if we are only doing one trial.
- if (trials == 1)
- {
- // Train the model. The user will have been warned earlier if the GMM was
- // initialized with no parameters (0 gaussians, dimensionality of 0).
- fitter.Estimate(observations, means, covariances, weights);
-
- bestLikelihood = LogLikelihood(observations, means, covariances, weights);
- }
- else
- {
- if (trials == 0)
- return -DBL_MAX; // It's what they asked for...
-
- // We need to keep temporary copies. We'll do the first training into the
- // actual model position, so that if it's the best we don't need to copy it.
- fitter.Estimate(observations, means, covariances, weights);
-
- bestLikelihood = LogLikelihood(observations, means, covariances, weights);
-
- Log::Info << "GMM::Estimate(): Log-likelihood of trial 0 is "
- << bestLikelihood << "." << std::endl;
-
- // Now the temporary model.
- std::vector<arma::vec> meansTrial(gaussians, arma::vec(dimensionality));
- std::vector<arma::mat> covariancesTrial(gaussians,
- arma::mat(dimensionality, dimensionality));
- arma::vec weightsTrial(gaussians);
-
- for (size_t trial = 1; trial < trials; ++trial)
- {
- fitter.Estimate(observations, meansTrial, covariancesTrial, weightsTrial);
-
- // Check to see if the log-likelihood of this one is better.
- double newLikelihood = LogLikelihood(observations, meansTrial,
- covariancesTrial, weightsTrial);
-
- Log::Info << "GMM::Estimate(): Log-likelihood of trial " << trial
- << " is " << newLikelihood << "." << std::endl;
-
- if (newLikelihood > bestLikelihood)
- {
- // Save new likelihood and copy new model.
- bestLikelihood = newLikelihood;
-
- means = meansTrial;
- covariances = covariancesTrial;
- weights = weightsTrial;
- }
- }
- }
-
- // Report final log-likelihood and return it.
- Log::Info << "GMM::Estimate(): log-likelihood of trained GMM is "
- << bestLikelihood << "." << std::endl;
- return bestLikelihood;
-}
-
-/**
- * Fit the GMM to the given observations, each of which has a certain
- * probability of being from this distribution.
- */
-template<typename FittingType>
-double GMM<FittingType>::Estimate(const arma::mat& observations,
- const arma::vec& probabilities,
- const size_t trials)
-{
- double bestLikelihood; // This will be reported later.
-
- // We don't need to store temporary models if we are only doing one trial.
- if (trials == 1)
- {
- // Train the model. The user will have been warned earlier if the GMM was
- // initialized with no parameters (0 gaussians, dimensionality of 0).
- fitter.Estimate(observations, probabilities, means, covariances, weights);
-
- bestLikelihood = LogLikelihood(observations, means, covariances, weights);
- }
- else
- {
- if (trials == 0)
- return -DBL_MAX; // It's what they asked for...
-
- // We need to keep temporary copies. We'll do the first training into the
- // actual model position, so that if it's the best we don't need to copy it.
- fitter.Estimate(observations, probabilities, means, covariances, weights);
-
- bestLikelihood = LogLikelihood(observations, means, covariances, weights);
-
- Log::Debug << "GMM::Estimate(): Log-likelihood of trial 0 is "
- << bestLikelihood << "." << std::endl;
-
- // Now the temporary model.
- std::vector<arma::vec> meansTrial(gaussians, arma::vec(dimensionality));
- std::vector<arma::mat> covariancesTrial(gaussians,
- arma::mat(dimensionality, dimensionality));
- arma::vec weightsTrial(gaussians);
-
- for (size_t trial = 1; trial < trials; ++trial)
- {
- fitter.Estimate(observations, meansTrial, covariancesTrial, weightsTrial);
-
- // Check to see if the log-likelihood of this one is better.
- double newLikelihood = LogLikelihood(observations, meansTrial,
- covariancesTrial, weightsTrial);
-
- Log::Debug << "GMM::Estimate(): Log-likelihood of trial " << trial
- << " is " << newLikelihood << "." << std::endl;
-
- if (newLikelihood > bestLikelihood)
- {
- // Save new likelihood and copy new model.
- bestLikelihood = newLikelihood;
-
- means = meansTrial;
- covariances = covariancesTrial;
- weights = weightsTrial;
- }
- }
- }
-
- // Report final log-likelihood and return it.
- Log::Info << "GMM::Estimate(): log-likelihood of trained GMM is "
- << bestLikelihood << "." << std::endl;
- return bestLikelihood;
-}
-
-/**
- * Classify the given observations as being from an individual component in this
- * GMM.
- */
-template<typename FittingType>
-void GMM<FittingType>::Classify(const arma::mat& observations,
- arma::Col<size_t>& labels) const
-{
- // This is not the best way to do this!
-
- // We should not have to fill this with values, because each one should be
- // overwritten.
- labels.set_size(observations.n_cols);
- for (size_t i = 0; i < observations.n_cols; ++i)
- {
- // Find maximum probability component.
- double probability = 0;
- for (size_t j = 0; j < gaussians; ++j)
- {
- double newProb = Probability(observations.unsafe_col(i), j);
- if (newProb >= probability)
- {
- probability = newProb;
- labels[i] = j;
- }
- }
- }
-}
-
-/**
- * Get the log-likelihood of this data's fit to the model.
- */
-template<typename FittingType>
-double GMM<FittingType>::LogLikelihood(
- const arma::mat& data,
- const std::vector<arma::vec>& meansL,
- const std::vector<arma::mat>& covariancesL,
- const arma::vec& weightsL) const
-{
- double loglikelihood = 0;
-
- arma::vec phis;
- arma::mat likelihoods(gaussians, data.n_cols);
- for (size_t i = 0; i < gaussians; i++)
- {
- phi(data, meansL[i], covariancesL[i], phis);
- likelihoods.row(i) = weightsL(i) * trans(phis);
- }
-
- // Now sum over every point.
- for (size_t j = 0; j < data.n_cols; j++)
- loglikelihood += log(accu(likelihoods.col(j)));
-
- return loglikelihood;
-}
-
-}; // namespace gmm
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/gmm_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,395 @@
+/**
+ * @file gmm_impl.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Ryan Curtin
+ *
+ * Implementation of template-based GMM methods.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_GMM_GMM_IMPL_HPP
+#define __MLPACK_METHODS_GMM_GMM_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "gmm.hpp"
+
+#include <mlpack/core/util/save_restore_utility.hpp>
+
+namespace mlpack {
+namespace gmm {
+
+// Copy constructor.
+template<typename FittingType>
+template<typename OtherFittingType>
+GMM<FittingType>::GMM(const GMM<OtherFittingType>& other) :
+ gaussians(other.Gaussians()),
+ dimensionality(other.Dimensionality()),
+ means(other.Means()),
+ covariances(other.Covariances()),
+ weights(other.Weights()),
+ localFitter(FittingType()),
+ fitter(localFitter) { /* Nothing to do. */ }
+
+// Copy constructor for when the other GMM uses the same fitting type.
+template<typename FittingType>
+GMM<FittingType>::GMM(const GMM<FittingType>& other) :
+ gaussians(other.Gaussians()),
+ dimensionality(other.Dimensionality()),
+ means(other.Means()),
+ covariances(other.Covariances()),
+ weights(other.Weights()),
+ localFitter(other.Fitter()),
+ fitter(localFitter) { /* Nothing to do. */ }
+
+template<typename FittingType>
+template<typename OtherFittingType>
+GMM<FittingType>& GMM<FittingType>::operator=(
+ const GMM<OtherFittingType>& other)
+{
+ gaussians = other.Gaussians();
+ dimensionality = other.Dimensionality();
+ means = other.Means();
+ covariances = other.Covariances();
+ weights = other.Weights();
+
+ return *this;
+}
+
+template<typename FittingType>
+GMM<FittingType>& GMM<FittingType>::operator=(const GMM<FittingType>& other)
+{
+ gaussians = other.Gaussians();
+ dimensionality = other.Dimensionality();
+ means = other.Means();
+ covariances = other.Covariances();
+ weights = other.Weights();
+ localFitter = other.Fitter();
+
+ return *this;
+}
+
+// Load a GMM from file.
+template<typename FittingType>
+void GMM<FittingType>::Load(const std::string& filename)
+{
+ util::SaveRestoreUtility load;
+
+ if (!load.ReadFile(filename))
+ Log::Fatal << "GMM::Load(): could not read file '" << filename << "'!\n";
+
+ load.LoadParameter(gaussians, "gaussians");
+ load.LoadParameter(dimensionality, "dimensionality");
+ load.LoadParameter(weights, "weights");
+
+ // We need to do a little error checking here.
+ if (weights.n_elem != gaussians)
+ {
+ Log::Fatal << "GMM::Load('" << filename << "'): file reports " << gaussians
+ << " gaussians but weights vector only contains " << weights.n_elem
+ << " elements!" << std::endl;
+ }
+
+ means.resize(gaussians);
+ covariances.resize(gaussians);
+
+ for (size_t i = 0; i < gaussians; ++i)
+ {
+ std::stringstream o;
+ o << i;
+ std::string meanName = "mean" + o.str();
+ std::string covName = "covariance" + o.str();
+
+ load.LoadParameter(means[i], meanName);
+ load.LoadParameter(covariances[i], covName);
+ }
+}
+
+// Save a GMM to a file.
+template<typename FittingType>
+void GMM<FittingType>::Save(const std::string& filename) const
+{
+ util::SaveRestoreUtility save;
+ save.SaveParameter(gaussians, "gaussians");
+ save.SaveParameter(dimensionality, "dimensionality");
+ save.SaveParameter(weights, "weights");
+ for (size_t i = 0; i < gaussians; ++i)
+ {
+ // Generate names for the XML nodes.
+ std::stringstream o;
+ o << i;
+ std::string meanName = "mean" + o.str();
+ std::string covName = "covariance" + o.str();
+
+ // Now save them.
+ save.SaveParameter(means[i], meanName);
+ save.SaveParameter(covariances[i], covName);
+ }
+
+ if (!save.WriteFile(filename))
+ Log::Warn << "GMM::Save(): error saving to '" << filename << "'.\n";
+}
+
+/**
+ * Return the probability of the given observation being from this GMM.
+ */
+template<typename FittingType>
+double GMM<FittingType>::Probability(const arma::vec& observation) const
+{
+ // Sum the probability for each Gaussian in our mixture (and we have to
+ // multiply by the prior for each Gaussian too).
+ double sum = 0;
+ for (size_t i = 0; i < gaussians; i++)
+ sum += weights[i] * phi(observation, means[i], covariances[i]);
+
+ return sum;
+}
+
+/**
+ * Return the probability of the given observation being from the given
+ * component in the mixture.
+ */
+template<typename FittingType>
+double GMM<FittingType>::Probability(const arma::vec& observation,
+ const size_t component) const
+{
+ // We are only considering one Gaussian component -- so we only need to call
+ // phi() once. We do consider the prior probability!
+ return weights[component] *
+ phi(observation, means[component], covariances[component]);
+}
+
+/**
+ * Return a randomly generated observation according to the probability
+ * distribution defined by this object.
+ */
+template<typename FittingType>
+arma::vec GMM<FittingType>::Random() const
+{
+  // Determine which Gaussian it will be coming from.
+  double gaussRand = math::Random();
+  size_t gaussian = 0; // Default to component 0 in case rounding error keeps the loop from selecting one.
+
+  double sumProb = 0;
+  for (size_t g = 0; g < gaussians; g++)
+  {
+    sumProb += weights(g);
+    if (gaussRand <= sumProb)
+    {
+      gaussian = g;
+      break;
+    }
+  }
+
+  return trans(chol(covariances[gaussian])) *
+      arma::randn<arma::vec>(dimensionality) + means[gaussian];
+}
+
+/**
+ * Fit the GMM to the given observations.
+ */
+template<typename FittingType>
+double GMM<FittingType>::Estimate(const arma::mat& observations,
+ const size_t trials)
+{
+ double bestLikelihood; // This will be reported later.
+
+ // We don't need to store temporary models if we are only doing one trial.
+ if (trials == 1)
+ {
+ // Train the model. The user will have been warned earlier if the GMM was
+ // initialized with no parameters (0 gaussians, dimensionality of 0).
+ fitter.Estimate(observations, means, covariances, weights);
+
+ bestLikelihood = LogLikelihood(observations, means, covariances, weights);
+ }
+ else
+ {
+ if (trials == 0)
+ return -DBL_MAX; // It's what they asked for...
+
+ // We need to keep temporary copies. We'll do the first training into the
+ // actual model position, so that if it's the best we don't need to copy it.
+ fitter.Estimate(observations, means, covariances, weights);
+
+ bestLikelihood = LogLikelihood(observations, means, covariances, weights);
+
+ Log::Info << "GMM::Estimate(): Log-likelihood of trial 0 is "
+ << bestLikelihood << "." << std::endl;
+
+ // Now the temporary model.
+ std::vector<arma::vec> meansTrial(gaussians, arma::vec(dimensionality));
+ std::vector<arma::mat> covariancesTrial(gaussians,
+ arma::mat(dimensionality, dimensionality));
+ arma::vec weightsTrial(gaussians);
+
+ for (size_t trial = 1; trial < trials; ++trial)
+ {
+ fitter.Estimate(observations, meansTrial, covariancesTrial, weightsTrial);
+
+ // Check to see if the log-likelihood of this one is better.
+ double newLikelihood = LogLikelihood(observations, meansTrial,
+ covariancesTrial, weightsTrial);
+
+ Log::Info << "GMM::Estimate(): Log-likelihood of trial " << trial
+ << " is " << newLikelihood << "." << std::endl;
+
+ if (newLikelihood > bestLikelihood)
+ {
+ // Save new likelihood and copy new model.
+ bestLikelihood = newLikelihood;
+
+ means = meansTrial;
+ covariances = covariancesTrial;
+ weights = weightsTrial;
+ }
+ }
+ }
+
+ // Report final log-likelihood and return it.
+ Log::Info << "GMM::Estimate(): log-likelihood of trained GMM is "
+ << bestLikelihood << "." << std::endl;
+ return bestLikelihood;
+}
+
+/**
+ * Fit the GMM to the given observations, each of which has a certain
+ * probability of being from this distribution.
+ */
+template<typename FittingType>
+double GMM<FittingType>::Estimate(const arma::mat& observations,
+                                  const arma::vec& probabilities,
+                                  const size_t trials)
+{
+  double bestLikelihood; // This will be reported later.
+
+  // We don't need to store temporary models if we are only doing one trial.
+  if (trials == 1)
+  {
+    // Train the model.  The user will have been warned earlier if the GMM was
+    // initialized with no parameters (0 gaussians, dimensionality of 0).
+    fitter.Estimate(observations, probabilities, means, covariances, weights);
+
+    bestLikelihood = LogLikelihood(observations, means, covariances, weights);
+  }
+  else
+  {
+    if (trials == 0)
+      return -DBL_MAX; // It's what they asked for...
+
+    // We need to keep temporary copies.  We'll do the first training into the
+    // actual model position, so that if it's the best we don't need to copy it.
+    fitter.Estimate(observations, probabilities, means, covariances, weights);
+
+    bestLikelihood = LogLikelihood(observations, means, covariances, weights);
+
+    Log::Info << "GMM::Estimate(): Log-likelihood of trial 0 is "
+        << bestLikelihood << "." << std::endl;
+
+    // Now the temporary model.
+    std::vector<arma::vec> meansTrial(gaussians, arma::vec(dimensionality));
+    std::vector<arma::mat> covariancesTrial(gaussians,
+        arma::mat(dimensionality, dimensionality));
+    arma::vec weightsTrial(gaussians);
+
+    for (size_t trial = 1; trial < trials; ++trial)
+    {
+      fitter.Estimate(observations, probabilities, meansTrial, covariancesTrial, weightsTrial); // Bugfix: pass probabilities so later trials use the weighted data too.
+
+      // Check to see if the log-likelihood of this one is better.
+      double newLikelihood = LogLikelihood(observations, meansTrial,
+          covariancesTrial, weightsTrial);
+
+      Log::Info << "GMM::Estimate(): Log-likelihood of trial " << trial
+          << " is " << newLikelihood << "." << std::endl;
+
+      if (newLikelihood > bestLikelihood)
+      {
+        // Save new likelihood and copy new model.
+        bestLikelihood = newLikelihood;
+
+        means = meansTrial;
+        covariances = covariancesTrial;
+        weights = weightsTrial;
+      }
+    }
+  }
+
+  // Report final log-likelihood and return it.
+  Log::Info << "GMM::Estimate(): log-likelihood of trained GMM is "
+      << bestLikelihood << "." << std::endl;
+  return bestLikelihood;
+}
+
+/**
+ * Classify the given observations as being from an individual component in this
+ * GMM.
+ */
+template<typename FittingType>
+void GMM<FittingType>::Classify(const arma::mat& observations,
+ arma::Col<size_t>& labels) const
+{
+ // This is not the best way to do this!
+
+ // We should not have to fill this with values, because each one should be
+ // overwritten.
+ labels.set_size(observations.n_cols);
+ for (size_t i = 0; i < observations.n_cols; ++i)
+ {
+ // Find maximum probability component.
+ double probability = 0;
+ for (size_t j = 0; j < gaussians; ++j)
+ {
+ double newProb = Probability(observations.unsafe_col(i), j);
+ if (newProb >= probability)
+ {
+ probability = newProb;
+ labels[i] = j;
+ }
+ }
+ }
+}
+
+/**
+ * Get the log-likelihood of this data's fit to the model.
+ */
+template<typename FittingType>
+double GMM<FittingType>::LogLikelihood(
+ const arma::mat& data,
+ const std::vector<arma::vec>& meansL,
+ const std::vector<arma::mat>& covariancesL,
+ const arma::vec& weightsL) const
+{
+ double loglikelihood = 0;
+
+ arma::vec phis;
+ arma::mat likelihoods(gaussians, data.n_cols);
+ for (size_t i = 0; i < gaussians; i++)
+ {
+ phi(data, meansL[i], covariancesL[i], phis);
+ likelihoods.row(i) = weightsL(i) * trans(phis);
+ }
+
+ // Now sum over every point.
+ for (size_t j = 0; j < data.n_cols; j++)
+ loglikelihood += log(accu(likelihoods.col(j)));
+
+ return loglikelihood;
+}
+
+}; // namespace gmm
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/gmm_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,170 +0,0 @@
-/**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file gmm_main.cpp
- *
- * This program trains a mixture of Gaussians on a given data matrix.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "gmm.hpp"
-
-#include <mlpack/methods/kmeans/refined_start.hpp>
-
-using namespace mlpack;
-using namespace mlpack::gmm;
-using namespace mlpack::util;
-using namespace mlpack::kmeans;
-
-PROGRAM_INFO("Gaussian Mixture Model (GMM) Training",
- "This program takes a parametric estimate of a Gaussian mixture model (GMM)"
- " using the EM algorithm to find the maximum likelihood estimate. The "
- "model is saved to an XML file, which contains information about each "
- "Gaussian."
- "\n\n"
- "If GMM training fails with an error indicating that a covariance matrix "
- "could not be inverted, be sure that the 'no_force_positive' flag was not "
- "specified. Alternately, adding a small amount of Gaussian noise to the "
- "entire dataset may help prevent Gaussians with zero variance in a "
- "particular dimension, which is usually the cause of non-invertible "
- "covariance matrices."
- "\n\n"
- "The 'no_force_positive' flag, if set, will avoid the checks after each "
- "iteration of the EM algorithm which ensure that the covariance matrices "
- "are positive definite. Specifying the flag can cause faster runtime, "
- "but may also cause non-positive definite covariance matrices, which will "
- "cause the program to crash.");
-
-PARAM_STRING_REQ("input_file", "File containing the data on which the model "
- "will be fit.", "i");
-PARAM_INT("gaussians", "Number of Gaussians in the GMM.", "g", 1);
-PARAM_STRING("output_file", "The file to write the trained GMM parameters into "
- "(as XML).", "o", "gmm.xml");
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-PARAM_INT("trials", "Number of trials to perform in training GMM.", "t", 10);
-
-// Parameters for EM algorithm.
-PARAM_DOUBLE("tolerance", "Tolerance for convergence of EM.", "T", 1e-10);
-PARAM_FLAG("no_force_positive", "Do not force the covariance matrices to be "
- "positive definite.", "P");
-PARAM_INT("max_iterations", "Maximum number of iterations of EM algorithm "
- "(passing 0 will run until convergence).", "n", 250);
-
-// Parameters for dataset modification.
-PARAM_DOUBLE("noise", "Variance of zero-mean Gaussian noise to add to data.",
- "N", 0);
-
-// Parameters for k-means initialization.
-PARAM_FLAG("refined_start", "During the initialization, use refined initial "
- "positions for k-means clustering (Bradley and Fayyad, 1998).", "r");
-PARAM_INT("samplings", "If using --refined_start, specify the number of "
- "samplings used for initial points.", "S", 100);
-PARAM_DOUBLE("percentage", "If using --refined_start, specify the percentage of"
- " the dataset used for each sampling (should be between 0.0 and 1.0).",
- "p", 0.02);
-
-int main(int argc, char* argv[])
-{
- CLI::ParseCommandLine(argc, argv);
-
- // Check parameters and load data.
- if (CLI::GetParam<int>("seed") != 0)
- math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- math::RandomSeed((size_t) std::time(NULL));
-
- arma::mat dataPoints;
- data::Load(CLI::GetParam<std::string>("input_file").c_str(), dataPoints,
- true);
-
- const int gaussians = CLI::GetParam<int>("gaussians");
- if (gaussians <= 0)
- {
- Log::Fatal << "Invalid number of Gaussians (" << gaussians << "); must "
- "be greater than or equal to 1." << std::endl;
- }
-
- // Do we need to add noise to the dataset?
- if (CLI::HasParam("noise"))
- {
- Timer::Start("noise_addition");
- const double noise = CLI::GetParam<double>("noise");
- dataPoints += noise * arma::randn(dataPoints.n_rows, dataPoints.n_cols);
- Log::Info << "Added zero-mean Gaussian noise with variance " << noise
- << " to dataset." << std::endl;
- Timer::Stop("noise_addition");
- }
-
- // Gather parameters for EMFit object.
- const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
- const double tolerance = CLI::GetParam<double>("tolerance");
- const bool forcePositive = !CLI::HasParam("no_force_positive");
-
- // This gets a bit weird because we need different types depending on whether
- // --refined_start is specified.
- double likelihood;
- if (CLI::HasParam("refined_start"))
- {
- const int samplings = CLI::GetParam<int>("samplings");
- const double percentage = CLI::GetParam<double>("percentage");
-
- if (samplings <= 0)
- Log::Fatal << "Number of samplings (" << samplings << ") must be greater"
- << " than 0!" << std::endl;
-
- if (percentage <= 0.0 || percentage > 1.0)
- Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
- << "greater than 0.0 and less than or equal to 1.0!" << std::endl;
-
- typedef KMeans<metric::SquaredEuclideanDistance, RefinedStart> KMeansType;
-
- // These are default parameters.
- KMeansType k(1000, 1.0, metric::SquaredEuclideanDistance(),
- RefinedStart(samplings, percentage));
-
- EMFit<KMeansType> em(maxIterations, tolerance, forcePositive, k);
-
- GMM<EMFit<KMeansType> > gmm(size_t(gaussians), dataPoints.n_rows, em);
-
- // Compute the parameters of the model using the EM algorithm.
- Timer::Start("em");
- likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
- Timer::Stop("em");
-
- // Save results.
- gmm.Save(CLI::GetParam<std::string>("output_file"));
- }
- else
- {
- EMFit<> em(maxIterations, tolerance, forcePositive);
-
- // Calculate mixture of Gaussians.
- GMM<> gmm(size_t(gaussians), dataPoints.n_rows, em);
-
- // Compute the parameters of the model using the EM algorithm.
- Timer::Start("em");
- likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
- Timer::Stop("em");
-
- // Save results.
- gmm.Save(CLI::GetParam<std::string>("output_file"));
- }
-
- Log::Info << "Log-likelihood of estimate: " << likelihood << ".\n";
-
-
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/gmm_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/gmm_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,170 @@
+/**
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @file gmm_main.cpp
+ *
+ * This program trains a mixture of Gaussians on a given data matrix.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "gmm.hpp"
+
+#include <mlpack/methods/kmeans/refined_start.hpp>
+
+using namespace mlpack;
+using namespace mlpack::gmm;
+using namespace mlpack::util;
+using namespace mlpack::kmeans;
+
+PROGRAM_INFO("Gaussian Mixture Model (GMM) Training",
+ "This program takes a parametric estimate of a Gaussian mixture model (GMM)"
+ " using the EM algorithm to find the maximum likelihood estimate. The "
+ "model is saved to an XML file, which contains information about each "
+ "Gaussian."
+ "\n\n"
+ "If GMM training fails with an error indicating that a covariance matrix "
+ "could not be inverted, be sure that the 'no_force_positive' flag was not "
+ "specified. Alternately, adding a small amount of Gaussian noise to the "
+ "entire dataset may help prevent Gaussians with zero variance in a "
+ "particular dimension, which is usually the cause of non-invertible "
+ "covariance matrices."
+ "\n\n"
+ "The 'no_force_positive' flag, if set, will avoid the checks after each "
+ "iteration of the EM algorithm which ensure that the covariance matrices "
+ "are positive definite. Specifying the flag can cause faster runtime, "
+ "but may also cause non-positive definite covariance matrices, which will "
+ "cause the program to crash.");
+
+PARAM_STRING_REQ("input_file", "File containing the data on which the model "
+ "will be fit.", "i");
+PARAM_INT("gaussians", "Number of Gaussians in the GMM.", "g", 1);
+PARAM_STRING("output_file", "The file to write the trained GMM parameters into "
+ "(as XML).", "o", "gmm.xml");
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+PARAM_INT("trials", "Number of trials to perform in training GMM.", "t", 10);
+
+// Parameters for EM algorithm.
+PARAM_DOUBLE("tolerance", "Tolerance for convergence of EM.", "T", 1e-10);
+PARAM_FLAG("no_force_positive", "Do not force the covariance matrices to be "
+ "positive definite.", "P");
+PARAM_INT("max_iterations", "Maximum number of iterations of EM algorithm "
+ "(passing 0 will run until convergence).", "n", 250);
+
+// Parameters for dataset modification.
+PARAM_DOUBLE("noise", "Variance of zero-mean Gaussian noise to add to data.",
+ "N", 0);
+
+// Parameters for k-means initialization.
+PARAM_FLAG("refined_start", "During the initialization, use refined initial "
+ "positions for k-means clustering (Bradley and Fayyad, 1998).", "r");
+PARAM_INT("samplings", "If using --refined_start, specify the number of "
+ "samplings used for initial points.", "S", 100);
+PARAM_DOUBLE("percentage", "If using --refined_start, specify the percentage of"
+ " the dataset used for each sampling (should be between 0.0 and 1.0).",
+ "p", 0.02);
+
+int main(int argc, char* argv[])
+{
+ CLI::ParseCommandLine(argc, argv);
+
+ // Check parameters and load data.
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) std::time(NULL));
+
+ arma::mat dataPoints;
+ data::Load(CLI::GetParam<std::string>("input_file").c_str(), dataPoints,
+ true);
+
+ const int gaussians = CLI::GetParam<int>("gaussians");
+ if (gaussians <= 0)
+ {
+ Log::Fatal << "Invalid number of Gaussians (" << gaussians << "); must "
+ "be greater than or equal to 1." << std::endl;
+ }
+
+ // Do we need to add noise to the dataset?
+ if (CLI::HasParam("noise"))
+ {
+ Timer::Start("noise_addition");
+ const double noise = CLI::GetParam<double>("noise");
+ dataPoints += noise * arma::randn(dataPoints.n_rows, dataPoints.n_cols);
+ Log::Info << "Added zero-mean Gaussian noise with variance " << noise
+ << " to dataset." << std::endl;
+ Timer::Stop("noise_addition");
+ }
+
+ // Gather parameters for EMFit object.
+ const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
+ const double tolerance = CLI::GetParam<double>("tolerance");
+ const bool forcePositive = !CLI::HasParam("no_force_positive");
+
+ // This gets a bit weird because we need different types depending on whether
+ // --refined_start is specified.
+ double likelihood;
+ if (CLI::HasParam("refined_start"))
+ {
+ const int samplings = CLI::GetParam<int>("samplings");
+ const double percentage = CLI::GetParam<double>("percentage");
+
+ if (samplings <= 0)
+ Log::Fatal << "Number of samplings (" << samplings << ") must be greater"
+ << " than 0!" << std::endl;
+
+ if (percentage <= 0.0 || percentage > 1.0)
+ Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
+ << "greater than 0.0 and less than or equal to 1.0!" << std::endl;
+
+ typedef KMeans<metric::SquaredEuclideanDistance, RefinedStart> KMeansType;
+
+ // These are default parameters.
+ KMeansType k(1000, 1.0, metric::SquaredEuclideanDistance(),
+ RefinedStart(samplings, percentage));
+
+ EMFit<KMeansType> em(maxIterations, tolerance, forcePositive, k);
+
+ GMM<EMFit<KMeansType> > gmm(size_t(gaussians), dataPoints.n_rows, em);
+
+ // Compute the parameters of the model using the EM algorithm.
+ Timer::Start("em");
+ likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+ Timer::Stop("em");
+
+ // Save results.
+ gmm.Save(CLI::GetParam<std::string>("output_file"));
+ }
+ else
+ {
+ EMFit<> em(maxIterations, tolerance, forcePositive);
+
+ // Calculate mixture of Gaussians.
+ GMM<> gmm(size_t(gaussians), dataPoints.n_rows, em);
+
+ // Compute the parameters of the model using the EM algorithm.
+ Timer::Start("em");
+ likelihood = gmm.Estimate(dataPoints, CLI::GetParam<int>("trials"));
+ Timer::Stop("em");
+
+ // Save results.
+ gmm.Save(CLI::GetParam<std::string>("output_file"));
+ }
+
+ Log::Info << "Log-likelihood of estimate: " << likelihood << ".\n";
+
+
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/phi.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/phi.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/phi.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,162 +0,0 @@
-/**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file phi.hpp
- *
- * This file computes the Gaussian probability
- * density function
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_MOG_PHI_HPP
-#define __MLPACK_METHODS_MOG_PHI_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace gmm {
-
-/**
- * Calculates the univariate Gaussian probability density function.
- *
- * Example use:
- * @code
- * double x, mean, var;
- * ....
- * double f = phi(x, mean, var);
- * @endcode
- *
- * @param x Observation.
- * @param mean Mean of univariate Gaussian.
- * @param var Variance of univariate Gaussian.
- * @return Probability of x being observed from the given univariate Gaussian.
- */
-inline double phi(const double x, const double mean, const double var)
-{
- return exp(-1.0 * ((x - mean) * (x - mean) / (2 * var)))
- / sqrt(2 * M_PI * var);
-}
-
-/**
- * Calculates the multivariate Gaussian probability density function.
- *
- * Example use:
- * @code
- * extern arma::vec x, mean;
- * extern arma::mat cov;
- * ....
- * double f = phi(x, mean, cov);
- * @endcode
- *
- * @param x Observation.
- * @param mean Mean of multivariate Gaussian.
- * @param cov Covariance of multivariate Gaussian.
- * @return Probability of x being observed from the given multivariate Gaussian.
- */
-inline double phi(const arma::vec& x,
- const arma::vec& mean,
- const arma::mat& cov)
-{
- arma::vec diff = mean - x;
-
- // Parentheses required for Armadillo 3.0.0 bug.
- arma::vec exponent = -0.5 * (trans(diff) * inv(cov) * diff);
-
- // TODO: What if det(cov) < 0?
- return pow(2 * M_PI, (double) x.n_elem / -2.0) * pow(det(cov), -0.5) *
- exp(exponent[0]);
-}
-
-/**
- * Calculates the multivariate Gaussian probability density function and also
- * the gradients with respect to the mean and the variance.
- *
- * Example use:
- * @code
- * extern arma::vec x, mean, g_mean, g_cov;
- * std::vector<arma::mat> d_cov; // the dSigma
- * ....
- * double f = phi(x, mean, cov, d_cov, &g_mean, &g_cov);
- * @endcode
- */
-inline double phi(const arma::vec& x,
- const arma::vec& mean,
- const arma::mat& cov,
- const std::vector<arma::mat>& d_cov,
- arma::vec& g_mean,
- arma::vec& g_cov)
-{
- // We don't call out to another version of the function to avoid inverting the
- // covariance matrix more than once.
- arma::mat cinv = inv(cov);
-
- arma::vec diff = mean - x;
- // Parentheses required for Armadillo 3.0.0 bug.
- arma::vec exponent = -0.5 * (trans(diff) * inv(cov) * diff);
-
- long double f = pow(2 * M_PI, (double) x.n_elem / 2) * pow(det(cov), -0.5)
- * exp(exponent[0]);
-
- // Calculate the g_mean values; this is a (1 x dim) vector.
- arma::vec invDiff = cinv * diff;
- g_mean = f * invDiff;
-
- // Calculate the g_cov values; this is a (1 x (dim * (dim + 1) / 2)) vector.
- for (size_t i = 0; i < d_cov.size(); i++)
- {
- arma::mat inv_d = cinv * d_cov[i];
-
- g_cov[i] = f * dot(d_cov[i] * invDiff, invDiff) +
- accu(inv_d.diag()) / 2;
- }
-
- return f;
-}
-
-/**
- * Calculates the multivariate Gaussian probability density function for each
- * data point (column) in the given matrix, with respect to the given mean and
- * variance.
- *
- * @param x List of observations.
- * @param mean Mean of multivariate Gaussian.
- * @param cov Covariance of multivariate Gaussian.
- * @param probabilities Output probabilities for each input observation.
- */
-inline void phi(const arma::mat& x,
- const arma::vec& mean,
- const arma::mat& cov,
- arma::vec& probabilities)
-{
- // Column i of 'diffs' is the difference between x.col(i) and the mean.
- arma::mat diffs = x - (mean * arma::ones<arma::rowvec>(x.n_cols));
-
- // Now, we only want to calculate the diagonal elements of (diffs' * cov^-1 *
- // diffs). We just don't need any of the other elements. We can calculate
- // the right hand part of the equation (instead of the left side) so that
- // later we are referencing columns, not rows -- that is faster.
- arma::mat rhs = -0.5 * inv(cov) * diffs;
- arma::vec exponents(diffs.n_cols); // We will now fill this.
- for (size_t i = 0; i < diffs.n_cols; i++)
- exponents(i) = exp(accu(diffs.unsafe_col(i) % rhs.unsafe_col(i)));
-
- probabilities = pow(2 * M_PI, (double) mean.n_elem / -2.0) *
- pow(det(cov), -0.5) * exponents;
-}
-
-}; // namespace gmm
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/phi.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/gmm/phi.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/phi.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/gmm/phi.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,162 @@
+/**
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @file phi.hpp
+ *
+ * This file computes the Gaussian probability
+ * density function
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_MOG_PHI_HPP
+#define __MLPACK_METHODS_MOG_PHI_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace gmm {
+
+/**
+ * Calculates the univariate Gaussian probability density function.
+ *
+ * Example use:
+ * @code
+ * double x, mean, var;
+ * ....
+ * double f = phi(x, mean, var);
+ * @endcode
+ *
+ * @param x Observation.
+ * @param mean Mean of univariate Gaussian.
+ * @param var Variance of univariate Gaussian.
+ * @return Probability of x being observed from the given univariate Gaussian.
+ */
+inline double phi(const double x, const double mean, const double var)
+{
+ return exp(-1.0 * ((x - mean) * (x - mean) / (2 * var)))
+ / sqrt(2 * M_PI * var);
+}
+
+/**
+ * Calculates the multivariate Gaussian probability density function.
+ *
+ * Example use:
+ * @code
+ * extern arma::vec x, mean;
+ * extern arma::mat cov;
+ * ....
+ * double f = phi(x, mean, cov);
+ * @endcode
+ *
+ * @param x Observation.
+ * @param mean Mean of multivariate Gaussian.
+ * @param cov Covariance of multivariate Gaussian.
+ * @return Probability of x being observed from the given multivariate Gaussian.
+ */
+inline double phi(const arma::vec& x,
+ const arma::vec& mean,
+ const arma::mat& cov)
+{
+ arma::vec diff = mean - x;
+
+ // Parentheses required for Armadillo 3.0.0 bug.
+ arma::vec exponent = -0.5 * (trans(diff) * inv(cov) * diff);
+
+ // TODO: What if det(cov) < 0?
+ return pow(2 * M_PI, (double) x.n_elem / -2.0) * pow(det(cov), -0.5) *
+ exp(exponent[0]);
+}
+
+/**
+ * Calculates the multivariate Gaussian probability density function and also
+ * the gradients with respect to the mean and the variance.
+ *
+ * Example use:
+ * @code
+ * extern arma::vec x, mean, g_mean, g_cov;
+ * std::vector<arma::mat> d_cov; // the dSigma
+ * ....
+ * double f = phi(x, mean, cov, d_cov, g_mean, g_cov);
+ * @endcode
+ */
+inline double phi(const arma::vec& x,
+ const arma::vec& mean,
+ const arma::mat& cov,
+ const std::vector<arma::mat>& d_cov,
+ arma::vec& g_mean,
+ arma::vec& g_cov)
+{
+ // We don't call out to another version of the function to avoid inverting the
+ // covariance matrix more than once.
+ arma::mat cinv = inv(cov);
+
+ arma::vec diff = mean - x;
+ // Parentheses required for Armadillo 3.0.0 bug.
+ arma::vec exponent = -0.5 * (trans(diff) * inv(cov) * diff);
+
+ long double f = pow(2 * M_PI, (double) x.n_elem / 2) * pow(det(cov), -0.5)
+ * exp(exponent[0]);
+
+ // Calculate the g_mean values; this is a (1 x dim) vector.
+ arma::vec invDiff = cinv * diff;
+ g_mean = f * invDiff;
+
+ // Calculate the g_cov values; this is a (1 x (dim * (dim + 1) / 2)) vector.
+ for (size_t i = 0; i < d_cov.size(); i++)
+ {
+ arma::mat inv_d = cinv * d_cov[i];
+
+ g_cov[i] = f * dot(d_cov[i] * invDiff, invDiff) +
+ accu(inv_d.diag()) / 2;
+ }
+
+ return f;
+}
+
+/**
+ * Calculates the multivariate Gaussian probability density function for each
+ * data point (column) in the given matrix, with respect to the given mean and
+ * variance.
+ *
+ * @param x List of observations.
+ * @param mean Mean of multivariate Gaussian.
+ * @param cov Covariance of multivariate Gaussian.
+ * @param probabilities Output probabilities for each input observation.
+ */
+inline void phi(const arma::mat& x,
+ const arma::vec& mean,
+ const arma::mat& cov,
+ arma::vec& probabilities)
+{
+ // Column i of 'diffs' is the difference between x.col(i) and the mean.
+ arma::mat diffs = x - (mean * arma::ones<arma::rowvec>(x.n_cols));
+
+ // Now, we only want to calculate the diagonal elements of (diffs' * cov^-1 *
+ // diffs). We just don't need any of the other elements. We can calculate
+ // the right hand part of the equation (instead of the left side) so that
+ // later we are referencing columns, not rows -- that is faster.
+ arma::mat rhs = -0.5 * inv(cov) * diffs;
+ arma::vec exponents(diffs.n_cols); // We will now fill this.
+ for (size_t i = 0; i < diffs.n_cols; i++)
+ exponents(i) = exp(accu(diffs.unsafe_col(i) % rhs.unsafe_col(i)));
+
+ probabilities = pow(2 * M_PI, (double) mean.n_elem / -2.0) *
+ pow(det(cov), -0.5) * exponents;
+}
+
+}; // namespace gmm
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,339 +0,0 @@
-/**
- * @file hmm.hpp
- * @author Ryan Curtin
- * @author Tran Quoc Long
- *
- * Definition of HMM class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_HMM_HMM_HPP
-#define __MLPACK_METHODS_HMM_HMM_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace hmm /** Hidden Markov Models. */ {
-
-/**
- * A class that represents a Hidden Markov Model with an arbitrary type of
- * emission distribution. This HMM class supports training (supervised and
- * unsupervised), prediction of state sequences via the Viterbi algorithm,
- * estimation of state probabilities, generation of random sequences, and
- * calculation of the log-likelihood of a given sequence.
- *
- * The template parameter, Distribution, specifies the distribution which the
- * emissions follow. The class should implement the following functions:
- *
- * @code
- * class Distribution
- * {
- * public:
- * // The type of observation used by this distribution.
- * typedef something DataType;
- *
- * // Return the probability of the given observation.
- * double Probability(const DataType& observation) const;
- *
- * // Estimate the distribution based on the given observations.
- * void Estimate(const std::vector<DataType>& observations);
- *
- * // Estimate the distribution based on the given observations, given also
- * // the probability of each observation coming from this distribution.
- * void Estimate(const std::vector<DataType>& observations,
- * const std::vector<double>& probabilities);
- * };
- * @endcode
- *
- * See the mlpack::distribution::DiscreteDistribution class for an example. One
- * would use the DiscreteDistribution class when the observations are
- * non-negative integers. Other distributions could be Gaussians, a mixture of
- * Gaussians (GMM), or any other probability distribution implementing the
- * four Distribution functions.
- *
- * Usage of the HMM class generally involves either training an HMM or loading
- * an already-known HMM and taking probability measurements of sequences.
- * Example code for supervised training of a Gaussian HMM (that is, where the
- * emission output distribution is a single Gaussian for each hidden state) is
- * given below.
- *
- * @code
- * extern arma::mat observations; // Each column is an observation.
- * extern arma::Col<size_t> states; // Hidden states for each observation.
- * // Create an untrained HMM with 5 hidden states and default (N(0, 1))
- * // Gaussian distributions with the dimensionality of the dataset.
- * HMM<GaussianDistribution> hmm(5, GaussianDistribution(observations.n_rows));
- *
- * // Train the HMM (the labels could be omitted to perform unsupervised
- * // training).
- * hmm.Train(observations, states);
- * @endcode
- *
- * Once initialized, the HMM can evaluate the probability of a certain sequence
- * (with LogLikelihood()), predict the most likely sequence of hidden states
- * (with Predict()), generate a sequence (with Generate()), or estimate the
- * probabilities of each state for a sequence of observations (with Estimate()).
- *
- * @tparam Distribution Type of emission distribution for this HMM.
- */
-template<typename Distribution = distribution::DiscreteDistribution>
-class HMM
-{
- public:
- /**
- * Create the Hidden Markov Model with the given number of hidden states and
- * the given default distribution for emissions. The dimensionality of the
- * observations is taken from the emissions variable, so it is important that
- * the given default emission distribution is set with the correct
- * dimensionality. Alternately, set the dimensionality with Dimensionality().
- * Optionally, the tolerance for convergence of the Baum-Welch algorithm can
- * be set.
- *
- * @param states Number of states.
- * @param emissions Default distribution for emissions.
- * @param tolerance Tolerance for convergence of training algorithm
- * (Baum-Welch).
- */
- HMM(const size_t states,
- const Distribution emissions,
- const double tolerance = 1e-5);
-
- /**
- * Create the Hidden Markov Model with the given transition matrix and the
- * given emission distributions. The dimensionality of the observations of
- * the HMM are taken from the given emission distributions. Alternately, the
- * dimensionality can be set with Dimensionality().
- *
- * The transition matrix should be such that T(i, j) is the probability of
- * transition to state i from state j. The columns of the matrix should sum
- * to 1.
- *
- * The emission matrix should be such that E(i, j) is the probability of
- * emission i while in state j. The columns of the matrix should sum to 1.
- *
- * Optionally, the tolerance for convergence of the Baum-Welch algorithm can
- * be set.
- *
- * @param transition Transition matrix.
- * @param emission Emission distributions.
- * @param tolerance Tolerance for convergence of training algorithm
- * (Baum-Welch).
- */
- HMM(const arma::mat& transition,
- const std::vector<Distribution>& emission,
- const double tolerance = 1e-5);
-
- /**
- * Train the model using the Baum-Welch algorithm, with only the given
- * unlabeled observations. Instead of giving a guess transition and emission
- * matrix here, do that in the constructor. Each matrix in the vector of data
- * sequences holds an individual data sequence; each point in each individual
- * data sequence should be a column in the matrix. The number of rows in each
- * matrix should be equal to the dimensionality of the HMM (which is set in
- * the constructor).
- *
- * It is preferable to use the other overload of Train(), with labeled data.
- * That will produce much better results. However, if labeled data is
- * unavailable, this will work. In addition, it is possible to use Train()
- * with labeled data first, and then continue to train the model using this
- * overload of Train() with unlabeled data.
- *
- * The tolerance of the Baum-Welch algorithm can be set either in the
- * constructor or with the Tolerance() method. When the change in
- * log-likelihood of the model between iterations is less than the tolerance,
- * the Baum-Welch algorithm terminates.
- *
- * @note
- * Train() can be called multiple times with different sequences; each time it
- * is called, it uses the current parameters of the HMM as a starting point
- * for training.
- * @endnote
- *
- * @param dataSeq Vector of observation sequences.
- */
- void Train(const std::vector<arma::mat>& dataSeq);
-
- /**
- * Train the model using the given labeled observations; the transition and
- * emission matrices are directly estimated. Each matrix in the vector of
- * data sequences corresponds to a vector in the vector of state sequences.
- * Each point in each individual data sequence should be a column in the
- * matrix, and its state should be the corresponding element in the state
- * sequence vector. For instance, dataSeq[0].col(3) corresponds to the fourth
- * observation in the first data sequence, and its state is stateSeq[0][3].
- * The number of rows in each matrix should be equal to the dimensionality of
- * the HMM (which is set in the constructor).
- *
- * @note
- * Train() can be called multiple times with different sequences; each time it
- * is called, it uses the current parameters of the HMM as a starting point
- * for training.
- * @endnote
- *
- * @param dataSeq Vector of observation sequences.
- * @param stateSeq Vector of state sequences, corresponding to each
- * observation.
- */
- void Train(const std::vector<arma::mat>& dataSeq,
- const std::vector<arma::Col<size_t> >& stateSeq);
-
- /**
- * Estimate the probabilities of each hidden state at each time step for each
- * given data observation, using the Forward-Backward algorithm. Each matrix
- * which is returned has columns equal to the number of data observations, and
- * rows equal to the number of hidden states in the model. The log-likelihood
- * of the most probable sequence is returned.
- *
- * @param dataSeq Sequence of observations.
- * @param stateProb Matrix in which the probabilities of each state at each
- * time interval will be stored.
- * @param forwardProb Matrix in which the forward probabilities of each state
- * at each time interval will be stored.
- * @param backwardProb Matrix in which the backward probabilities of each
- * state at each time interval will be stored.
- * @param scales Vector in which the scaling factors at each time interval
- * will be stored.
- * @return Log-likelihood of most likely state sequence.
- */
- double Estimate(const arma::mat& dataSeq,
- arma::mat& stateProb,
- arma::mat& forwardProb,
- arma::mat& backwardProb,
- arma::vec& scales) const;
-
- /**
- * Estimate the probabilities of each hidden state at each time step of each
- * given data observation, using the Forward-Backward algorithm. The returned
- * matrix of state probabilities has columns equal to the number of data
- * observations, and rows equal to the number of hidden states in the model.
- * The log-likelihood of the most probable sequence is returned.
- *
- * @param dataSeq Sequence of observations.
- * @param stateProb Probabilities of each state at each time interval.
- * @return Log-likelihood of most likely state sequence.
- */
- double Estimate(const arma::mat& dataSeq,
- arma::mat& stateProb) const;
-
- /**
- * Generate a random data sequence of the given length. The data sequence is
- * stored in the dataSequence parameter, and the state sequence is stored in
- * the stateSequence parameter. Each column of dataSequence represents a
- * random observation.
- *
- * @param length Length of random sequence to generate.
- * @param dataSequence Vector to store data in.
- * @param stateSequence Vector to store states in.
- * @param startState Hidden state to start sequence in (default 0).
- */
- void Generate(const size_t length,
- arma::mat& dataSequence,
- arma::Col<size_t>& stateSequence,
- const size_t startState = 0) const;
-
- /**
- * Compute the most probable hidden state sequence for the given data
- * sequence, using the Viterbi algorithm, returning the log-likelihood of the
- * most likely state sequence.
- *
- * @param dataSeq Sequence of observations.
- * @param stateSeq Vector in which the most probable state sequence will be
- * stored.
- * @return Log-likelihood of most probable state sequence.
- */
- double Predict(const arma::mat& dataSeq,
- arma::Col<size_t>& stateSeq) const;
-
- /**
- * Compute the log-likelihood of the given data sequence.
- *
- * @param dataSeq Data sequence to evaluate the likelihood of.
- * @return Log-likelihood of the given sequence.
- */
- double LogLikelihood(const arma::mat& dataSeq) const;
-
- //! Return the transition matrix.
- const arma::mat& Transition() const { return transition; }
- //! Return a modifiable transition matrix reference.
- arma::mat& Transition() { return transition; }
-
- //! Return the emission distributions.
- const std::vector<Distribution>& Emission() const { return emission; }
- //! Return a modifiable emission probability matrix reference.
- std::vector<Distribution>& Emission() { return emission; }
-
- //! Get the dimensionality of observations.
- size_t Dimensionality() const { return dimensionality; }
- //! Set the dimensionality of observations.
- size_t& Dimensionality() { return dimensionality; }
-
- //! Get the tolerance of the Baum-Welch algorithm.
- double Tolerance() const { return tolerance; }
- //! Modify the tolerance of the Baum-Welch algorithm.
- double& Tolerance() { return tolerance; }
-
- private:
- // Helper functions.
-
- /**
- * The Forward algorithm (part of the Forward-Backward algorithm). Computes
- * forward probabilities for each state for each observation in the given data
- * sequence. The returned matrix has rows equal to the number of hidden
- * states and columns equal to the number of observations.
- *
- * @param dataSeq Data sequence to compute probabilities for.
- * @param scales Vector in which scaling factors will be saved.
- * @param forwardProb Matrix in which forward probabilities will be saved.
- */
- void Forward(const arma::mat& dataSeq,
- arma::vec& scales,
- arma::mat& forwardProb) const;
-
- /**
- * The Backward algorithm (part of the Forward-Backward algorithm). Computes
- * backward probabilities for each state for each observation in the given
- * data sequence, using the scaling factors found (presumably) by Forward().
- * The returned matrix has rows equal to the number of hidden states and
- * columns equal to the number of observations.
- *
- * @param dataSeq Data sequence to compute probabilities for.
- * @param scales Vector of scaling factors.
- * @param backwardProb Matrix in which backward probabilities will be saved.
- */
- void Backward(const arma::mat& dataSeq,
- const arma::vec& scales,
- arma::mat& backwardProb) const;
-
- //! Transition probability matrix.
- arma::mat transition;
-
- //! Set of emission probability distributions; one for each state.
- std::vector<Distribution> emission;
-
- //! Dimensionality of observations.
- size_t dimensionality;
-
- //! Tolerance of Baum-Welch algorithm.
- double tolerance;
-};
-
-}; // namespace hmm
-}; // namespace mlpack
-
-// Include implementation.
-#include "hmm_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,339 @@
+/**
+ * @file hmm.hpp
+ * @author Ryan Curtin
+ * @author Tran Quoc Long
+ *
+ * Definition of HMM class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_HPP
+#define __MLPACK_METHODS_HMM_HMM_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace hmm /** Hidden Markov Models. */ {
+
+/**
+ * A class that represents a Hidden Markov Model with an arbitrary type of
+ * emission distribution. This HMM class supports training (supervised and
+ * unsupervised), prediction of state sequences via the Viterbi algorithm,
+ * estimation of state probabilities, generation of random sequences, and
+ * calculation of the log-likelihood of a given sequence.
+ *
+ * The template parameter, Distribution, specifies the distribution which the
+ * emissions follow. The class should implement the following functions:
+ *
+ * @code
+ * class Distribution
+ * {
+ * public:
+ * // The type of observation used by this distribution.
+ * typedef something DataType;
+ *
+ * // Return the probability of the given observation.
+ * double Probability(const DataType& observation) const;
+ *
+ * // Estimate the distribution based on the given observations.
+ * void Estimate(const std::vector<DataType>& observations);
+ *
+ * // Estimate the distribution based on the given observations, given also
+ * // the probability of each observation coming from this distribution.
+ * void Estimate(const std::vector<DataType>& observations,
+ * const std::vector<double>& probabilities);
+ * };
+ * @endcode
+ *
+ * See the mlpack::distribution::DiscreteDistribution class for an example. One
+ * would use the DiscreteDistribution class when the observations are
+ * non-negative integers. Other distributions could be Gaussians, a mixture of
+ * Gaussians (GMM), or any other probability distribution implementing the
+ * four Distribution functions.
+ *
+ * Usage of the HMM class generally involves either training an HMM or loading
+ * an already-known HMM and taking probability measurements of sequences.
+ * Example code for supervised training of a Gaussian HMM (that is, where the
+ * emission output distribution is a single Gaussian for each hidden state) is
+ * given below.
+ *
+ * @code
+ * extern arma::mat observations; // Each column is an observation.
+ * extern arma::Col<size_t> states; // Hidden states for each observation.
+ * // Create an untrained HMM with 5 hidden states and default (N(0, 1))
+ * // Gaussian distributions with the dimensionality of the dataset.
+ * HMM<GaussianDistribution> hmm(5, GaussianDistribution(observations.n_rows));
+ *
+ * // Train the HMM (the labels could be omitted to perform unsupervised
+ * // training).
+ * hmm.Train(observations, states);
+ * @endcode
+ *
+ * Once initialized, the HMM can evaluate the probability of a certain sequence
+ * (with LogLikelihood()), predict the most likely sequence of hidden states
+ * (with Predict()), generate a sequence (with Generate()), or estimate the
+ * probabilities of each state for a sequence of observations (with Estimate()).
+ *
+ * @tparam Distribution Type of emission distribution for this HMM.
+ */
+template<typename Distribution = distribution::DiscreteDistribution>
+class HMM
+{
+ public:
+ /**
+ * Create the Hidden Markov Model with the given number of hidden states and
+ * the given default distribution for emissions. The dimensionality of the
+ * observations is taken from the emissions variable, so it is important that
+ * the given default emission distribution is set with the correct
+ * dimensionality. Alternately, set the dimensionality with Dimensionality().
+ * Optionally, the tolerance for convergence of the Baum-Welch algorithm can
+ * be set.
+ *
+ * @param states Number of states.
+ * @param emissions Default distribution for emissions.
+ * @param tolerance Tolerance for convergence of training algorithm
+ * (Baum-Welch).
+ */
+ HMM(const size_t states,
+ const Distribution emissions,
+ const double tolerance = 1e-5);
+
+ /**
+ * Create the Hidden Markov Model with the given transition matrix and the
+ * given emission distributions. The dimensionality of the observations of
+ * the HMM are taken from the given emission distributions. Alternately, the
+ * dimensionality can be set with Dimensionality().
+ *
+ * The transition matrix should be such that T(i, j) is the probability of
+ * transition to state i from state j. The columns of the matrix should sum
+ * to 1.
+ *
+ * The emission matrix should be such that E(i, j) is the probability of
+ * emission i while in state j. The columns of the matrix should sum to 1.
+ *
+ * Optionally, the tolerance for convergence of the Baum-Welch algorithm can
+ * be set.
+ *
+ * @param transition Transition matrix.
+ * @param emission Emission distributions.
+ * @param tolerance Tolerance for convergence of training algorithm
+ * (Baum-Welch).
+ */
+ HMM(const arma::mat& transition,
+ const std::vector<Distribution>& emission,
+ const double tolerance = 1e-5);
+
+ /**
+ * Train the model using the Baum-Welch algorithm, with only the given
+ * unlabeled observations. Instead of giving a guess transition and emission
+ * matrix here, do that in the constructor. Each matrix in the vector of data
+ * sequences holds an individual data sequence; each point in each individual
+ * data sequence should be a column in the matrix. The number of rows in each
+ * matrix should be equal to the dimensionality of the HMM (which is set in
+ * the constructor).
+ *
+ * It is preferable to use the other overload of Train(), with labeled data.
+ * That will produce much better results. However, if labeled data is
+ * unavailable, this will work. In addition, it is possible to use Train()
+ * with labeled data first, and then continue to train the model using this
+ * overload of Train() with unlabeled data.
+ *
+ * The tolerance of the Baum-Welch algorithm can be set either in the
+ * constructor or with the Tolerance() method. When the change in
+ * log-likelihood of the model between iterations is less than the tolerance,
+ * the Baum-Welch algorithm terminates.
+ *
+ * @note
+ * Train() can be called multiple times with different sequences; each time it
+ * is called, it uses the current parameters of the HMM as a starting point
+ * for training.
+ *
+ *
+ * @param dataSeq Vector of observation sequences.
+ */
+ void Train(const std::vector<arma::mat>& dataSeq);
+
+ /**
+ * Train the model using the given labeled observations; the transition and
+ * emission matrices are directly estimated. Each matrix in the vector of
+ * data sequences corresponds to a vector in the vector of state sequences.
+ * Each point in each individual data sequence should be a column in the
+ * matrix, and its state should be the corresponding element in the state
+ * sequence vector. For instance, dataSeq[0].col(3) corresponds to the fourth
+ * observation in the first data sequence, and its state is stateSeq[0][3].
+ * The number of rows in each matrix should be equal to the dimensionality of
+ * the HMM (which is set in the constructor).
+ *
+ * @note
+ * Train() can be called multiple times with different sequences; each time it
+ * is called, it uses the current parameters of the HMM as a starting point
+ * for training.
+ *
+ *
+ * @param dataSeq Vector of observation sequences.
+ * @param stateSeq Vector of state sequences, corresponding to each
+ * observation.
+ */
+ void Train(const std::vector<arma::mat>& dataSeq,
+ const std::vector<arma::Col<size_t> >& stateSeq);
+
+ /**
+ * Estimate the probabilities of each hidden state at each time step for each
+ * given data observation, using the Forward-Backward algorithm. Each matrix
+ * which is returned has columns equal to the number of data observations, and
+ * rows equal to the number of hidden states in the model. The log-likelihood
+ * of the most probable sequence is returned.
+ *
+ * @param dataSeq Sequence of observations.
+ * @param stateProb Matrix in which the probabilities of each state at each
+ * time interval will be stored.
+ * @param forwardProb Matrix in which the forward probabilities of each state
+ * at each time interval will be stored.
+ * @param backwardProb Matrix in which the backward probabilities of each
+ * state at each time interval will be stored.
+ * @param scales Vector in which the scaling factors at each time interval
+ * will be stored.
+ * @return Log-likelihood of most likely state sequence.
+ */
+ double Estimate(const arma::mat& dataSeq,
+ arma::mat& stateProb,
+ arma::mat& forwardProb,
+ arma::mat& backwardProb,
+ arma::vec& scales) const;
+
+ /**
+ * Estimate the probabilities of each hidden state at each time step of each
+ * given data observation, using the Forward-Backward algorithm. The returned
+ * matrix of state probabilities has columns equal to the number of data
+ * observations, and rows equal to the number of hidden states in the model.
+ * The log-likelihood of the most probable sequence is returned.
+ *
+ * @param dataSeq Sequence of observations.
+ * @param stateProb Probabilities of each state at each time interval.
+ * @return Log-likelihood of most likely state sequence.
+ */
+ double Estimate(const arma::mat& dataSeq,
+ arma::mat& stateProb) const;
+
+ /**
+ * Generate a random data sequence of the given length. The data sequence is
+ * stored in the dataSequence parameter, and the state sequence is stored in
+ * the stateSequence parameter. Each column of dataSequence represents a
+ * random observation.
+ *
+ * @param length Length of random sequence to generate.
+ * @param dataSequence Vector to store data in.
+ * @param stateSequence Vector to store states in.
+ * @param startState Hidden state to start sequence in (default 0).
+ */
+ void Generate(const size_t length,
+ arma::mat& dataSequence,
+ arma::Col<size_t>& stateSequence,
+ const size_t startState = 0) const;
+
+ /**
+ * Compute the most probable hidden state sequence for the given data
+ * sequence, using the Viterbi algorithm, returning the log-likelihood of the
+ * most likely state sequence.
+ *
+ * @param dataSeq Sequence of observations.
+ * @param stateSeq Vector in which the most probable state sequence will be
+ * stored.
+ * @return Log-likelihood of most probable state sequence.
+ */
+ double Predict(const arma::mat& dataSeq,
+ arma::Col<size_t>& stateSeq) const;
+
+ /**
+ * Compute the log-likelihood of the given data sequence.
+ *
+ * @param dataSeq Data sequence to evaluate the likelihood of.
+ * @return Log-likelihood of the given sequence.
+ */
+ double LogLikelihood(const arma::mat& dataSeq) const;
+
+ //! Return the transition matrix.
+ const arma::mat& Transition() const { return transition; }
+ //! Return a modifiable transition matrix reference.
+ arma::mat& Transition() { return transition; }
+
+ //! Return the emission distributions.
+ const std::vector<Distribution>& Emission() const { return emission; }
+ //! Return a modifiable emission probability matrix reference.
+ std::vector<Distribution>& Emission() { return emission; }
+
+ //! Get the dimensionality of observations.
+ size_t Dimensionality() const { return dimensionality; }
+ //! Set the dimensionality of observations.
+ size_t& Dimensionality() { return dimensionality; }
+
+ //! Get the tolerance of the Baum-Welch algorithm.
+ double Tolerance() const { return tolerance; }
+ //! Modify the tolerance of the Baum-Welch algorithm.
+ double& Tolerance() { return tolerance; }
+
+ private:
+ // Helper functions.
+
+ /**
+ * The Forward algorithm (part of the Forward-Backward algorithm). Computes
+ * forward probabilities for each state for each observation in the given data
+ * sequence. The returned matrix has rows equal to the number of hidden
+ * states and columns equal to the number of observations.
+ *
+ * @param dataSeq Data sequence to compute probabilities for.
+ * @param scales Vector in which scaling factors will be saved.
+ * @param forwardProb Matrix in which forward probabilities will be saved.
+ */
+ void Forward(const arma::mat& dataSeq,
+ arma::vec& scales,
+ arma::mat& forwardProb) const;
+
+ /**
+ * The Backward algorithm (part of the Forward-Backward algorithm). Computes
+ * backward probabilities for each state for each observation in the given
+ * data sequence, using the scaling factors found (presumably) by Forward().
+ * The returned matrix has rows equal to the number of hidden states and
+ * columns equal to the number of observations.
+ *
+ * @param dataSeq Data sequence to compute probabilities for.
+ * @param scales Vector of scaling factors.
+ * @param backwardProb Matrix in which backward probabilities will be saved.
+ */
+ void Backward(const arma::mat& dataSeq,
+ const arma::vec& scales,
+ arma::mat& backwardProb) const;
+
+ //! Transition probability matrix.
+ arma::mat transition;
+
+ //! Set of emission probability distributions; one for each state.
+ std::vector<Distribution> emission;
+
+ //! Dimensionality of observations.
+ size_t dimensionality;
+
+ //! Tolerance of Baum-Welch algorithm.
+ double tolerance;
+};
+
+}; // namespace hmm
+}; // namespace mlpack
+
+// Include implementation.
+#include "hmm_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_generate_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_generate_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_generate_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,144 +0,0 @@
-/**
- * @file hmm_viterbi_main.cpp
- * @author Ryan Curtin
- *
- * Compute the most probably hidden state sequence of a given observation
- * sequence for a given HMM.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "hmm.hpp"
-#include "hmm_util.hpp"
-
-#include <mlpack/methods/gmm/gmm.hpp>
-
-PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Generator", "This "
- "utility takes an already-trained HMM (--model_file) and generates a "
- "random observation sequence and hidden state sequence based on its "
- "parameters, saving them to the specified files (--output_file and "
- "--state_file)");
-
-PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
-PARAM_INT_REQ("length", "Length of sequence to generate.", "l");
-
-PARAM_INT("start_state", "Starting state of sequence.", "t", 0);
-PARAM_STRING("output_file", "File to save observation sequence to.", "o",
- "output.csv");
-PARAM_STRING("state_file", "File to save hidden state sequence to (may be left "
- "unspecified.", "S", "");
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-
-using namespace mlpack;
-using namespace mlpack::hmm;
-using namespace mlpack::distribution;
-using namespace mlpack::util;
-using namespace mlpack::gmm;
-using namespace mlpack::math;
-using namespace arma;
-using namespace std;
-
-int main(int argc, char** argv)
-{
- // Parse command line options.
- CLI::ParseCommandLine(argc, argv);
-
- // Set random seed.
- if (CLI::GetParam<int>("seed") != 0)
- RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- RandomSeed((size_t) time(NULL));
-
- // Load observations.
- const string modelFile = CLI::GetParam<string>("model_file");
- const int length = CLI::GetParam<int>("length");
- const int startState = CLI::GetParam<int>("start_state");
-
- if (length <= 0)
- {
- Log::Fatal << "Invalid sequence length (" << length << "); must be greater "
- << "than or equal to 0!" << endl;
- }
-
- // Load model, but first we have to determine its type.
- SaveRestoreUtility sr;
- sr.ReadFile(modelFile);
- string type;
- sr.LoadParameter(type, "hmm_type");
-
- mat observations;
- Col<size_t> sequence;
- if (type == "discrete")
- {
- HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
-
- LoadHMM(hmm, sr);
-
- if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
- {
- Log::Fatal << "Invalid start state (" << startState << "); must be "
- << "between 0 and number of states (" << hmm.Transition().n_rows
- << ")!" << endl;
- }
-
- hmm.Generate(size_t(length), observations, sequence, size_t(startState));
- }
- else if (type == "gaussian")
- {
- HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
-
- LoadHMM(hmm, sr);
-
- if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
- {
- Log::Fatal << "Invalid start state (" << startState << "); must be "
- << "between 0 and number of states (" << hmm.Transition().n_rows
- << ")!" << endl;
- }
-
- hmm.Generate(size_t(length), observations, sequence, size_t(startState));
- }
- else if (type == "gmm")
- {
- HMM<GMM<> > hmm(1, GMM<>(1, 1));
-
- LoadHMM(hmm, sr);
-
- if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
- {
- Log::Fatal << "Invalid start state (" << startState << "); must be "
- << "between 0 and number of states (" << hmm.Transition().n_rows
- << ")!" << endl;
- }
-
- hmm.Generate(size_t(length), observations, sequence, size_t(startState));
- }
- else
- {
- Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
- << "'!" << endl;
- }
-
- // Save observations.
- const string outputFile = CLI::GetParam<string>("output_file");
- data::Save(outputFile, observations, true);
-
- // Do we want to save the hidden sequence?
- const string sequenceFile = CLI::GetParam<string>("state_file");
- if (sequenceFile != "")
- data::Save(sequenceFile, sequence, true);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_generate_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_generate_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_generate_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_generate_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,144 @@
+/**
+ * @file hmm_generate_main.cpp
+ * @author Ryan Curtin
+ *
+ * Generate a random observation sequence and hidden state sequence from a
+ * given trained HMM.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "hmm.hpp"
+#include "hmm_util.hpp"
+
+#include <mlpack/methods/gmm/gmm.hpp>
+
+PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Generator", "This "
+ "utility takes an already-trained HMM (--model_file) and generates a "
+ "random observation sequence and hidden state sequence based on its "
+ "parameters, saving them to the specified files (--output_file and "
+ "--state_file)");
+
+PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_INT_REQ("length", "Length of sequence to generate.", "l");
+
+PARAM_INT("start_state", "Starting state of sequence.", "t", 0);
+PARAM_STRING("output_file", "File to save observation sequence to.", "o",
+ "output.csv");
+PARAM_STRING("state_file", "File to save hidden state sequence to (may be left "
+ "unspecified.", "S", "");
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::util;
+using namespace mlpack::gmm;
+using namespace mlpack::math;
+using namespace arma;
+using namespace std;
+
+int main(int argc, char** argv)
+{
+ // Parse command line options.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Set random seed.
+ if (CLI::GetParam<int>("seed") != 0)
+ RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ RandomSeed((size_t) time(NULL));
+
+ // Load observations.
+ const string modelFile = CLI::GetParam<string>("model_file");
+ const int length = CLI::GetParam<int>("length");
+ const int startState = CLI::GetParam<int>("start_state");
+
+ if (length <= 0)
+ {
+ Log::Fatal << "Invalid sequence length (" << length << "); must be greater "
+ << "than or equal to 0!" << endl;
+ }
+
+ // Load model, but first we have to determine its type.
+ SaveRestoreUtility sr;
+ sr.ReadFile(modelFile);
+ string type;
+ sr.LoadParameter(type, "hmm_type");
+
+ mat observations;
+ Col<size_t> sequence;
+ if (type == "discrete")
+ {
+ HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
+
+ LoadHMM(hmm, sr);
+
+ if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
+ {
+ Log::Fatal << "Invalid start state (" << startState << "); must be "
+ << "between 0 and number of states (" << hmm.Transition().n_rows
+ << ")!" << endl;
+ }
+
+ hmm.Generate(size_t(length), observations, sequence, size_t(startState));
+ }
+ else if (type == "gaussian")
+ {
+ HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
+
+ LoadHMM(hmm, sr);
+
+ if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
+ {
+ Log::Fatal << "Invalid start state (" << startState << "); must be "
+ << "between 0 and number of states (" << hmm.Transition().n_rows
+ << ")!" << endl;
+ }
+
+ hmm.Generate(size_t(length), observations, sequence, size_t(startState));
+ }
+ else if (type == "gmm")
+ {
+ HMM<GMM<> > hmm(1, GMM<>(1, 1));
+
+ LoadHMM(hmm, sr);
+
+ if (startState < 0 || startState >= (int) hmm.Transition().n_rows)
+ {
+ Log::Fatal << "Invalid start state (" << startState << "); must be "
+ << "between 0 and number of states (" << hmm.Transition().n_rows
+ << ")!" << endl;
+ }
+
+ hmm.Generate(size_t(length), observations, sequence, size_t(startState));
+ }
+ else
+ {
+ Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
+ << "'!" << endl;
+ }
+
+ // Save observations.
+ const string outputFile = CLI::GetParam<string>("output_file");
+ data::Save(outputFile, observations, true);
+
+ // Do we want to save the hidden sequence?
+ const string sequenceFile = CLI::GetParam<string>("state_file");
+ if (sequenceFile != "")
+ data::Save(sequenceFile, sequence, true);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,505 +0,0 @@
-/**
- * @file hmm_impl.hpp
- * @author Ryan Curtin
- * @author Tran Quoc Long
- *
- * Implementation of HMM class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_HMM_HMM_IMPL_HPP
-#define __MLPACK_METHODS_HMM_HMM_IMPL_HPP
-
-// Just in case...
-#include "hmm.hpp"
-
-namespace mlpack {
-namespace hmm {
-
-/**
- * Create the Hidden Markov Model with the given number of hidden states and the
- * given number of emission states.
- */
-template<typename Distribution>
-HMM<Distribution>::HMM(const size_t states,
- const Distribution emissions,
- const double tolerance) :
- transition(arma::ones<arma::mat>(states, states) / (double) states),
- emission(states, /* default distribution */ emissions),
- dimensionality(emissions.Dimensionality()),
- tolerance(tolerance)
-{ /* nothing to do */ }
-
-/**
- * Create the Hidden Markov Model with the given transition matrix and the given
- * emission probability matrix.
- */
-template<typename Distribution>
-HMM<Distribution>::HMM(const arma::mat& transition,
- const std::vector<Distribution>& emission,
- const double tolerance) :
- transition(transition),
- emission(emission),
- tolerance(tolerance)
-{
- // Set the dimensionality, if we can.
- if (emission.size() > 0)
- dimensionality = emission[0].Dimensionality();
- else
- {
- Log::Warn << "HMM::HMM(): no emission distributions given; assuming a "
- << "dimensionality of 0 and hoping it gets set right later."
- << std::endl;
- dimensionality = 0;
- }
-}
-
-/**
- * Train the model using the Baum-Welch algorithm, with only the given unlabeled
- * observations. Each matrix in the vector of data sequences holds an
- * individual data sequence; each point in each individual data sequence should
- * be a column in the matrix. The number of rows in each matrix should be equal
- * to the dimensionality of the HMM.
- *
- * It is preferable to use the other overload of Train(), with labeled data.
- * That will produce much better results. However, if labeled data is
- * unavailable, this will work. In addition, it is possible to use Train() with
- * labeled data first, and then continue to train the model using this overload
- * of Train() with unlabeled data.
- *
- * @param dataSeq Set of data sequences to train on.
- */
-template<typename Distribution>
-void HMM<Distribution>::Train(const std::vector<arma::mat>& dataSeq)
-{
- // We should allow a guess at the transition and emission matrices.
- double loglik = 0;
- double oldLoglik = 0;
-
- // Maximum iterations?
- size_t iterations = 1000;
-
- // Find length of all sequences and ensure they are the correct size.
- size_t totalLength = 0;
- for (size_t seq = 0; seq < dataSeq.size(); seq++)
- {
- totalLength += dataSeq[seq].n_cols;
-
- if (dataSeq[seq].n_rows != dimensionality)
- Log::Fatal << "HMM::Train(): data sequence " << seq << " has "
- << "dimensionality " << dataSeq[seq].n_rows << " (expected "
- << dimensionality << " dimensions)." << std::endl;
- }
-
- // These are used later for training of each distribution. We initialize it
- // all now so we don't have to do any allocation later on.
- std::vector<arma::vec> emissionProb(transition.n_cols,
- arma::vec(totalLength));
- arma::mat emissionList(dimensionality, totalLength);
-
- // This should be the Baum-Welch algorithm (EM for HMM estimation). This
- // follows the procedure outlined in Elliot, Aggoun, and Moore's book "Hidden
- // Markov Models: Estimation and Control", pp. 36-40.
- for (size_t iter = 0; iter < iterations; iter++)
- {
- // Clear new transition matrix and emission probabilities.
- arma::mat newTransition(transition.n_rows, transition.n_cols);
- newTransition.zeros();
-
- // Reset log likelihood.
- loglik = 0;
-
- // Sum over time.
- size_t sumTime = 0;
-
- // Loop over each sequence.
- for (size_t seq = 0; seq < dataSeq.size(); seq++)
- {
- arma::mat stateProb;
- arma::mat forward;
- arma::mat backward;
- arma::vec scales;
-
- // Add the log-likelihood of this sequence. This is the E-step.
- loglik += Estimate(dataSeq[seq], stateProb, forward, backward, scales);
-
- // Now re-estimate the parameters. This is the M-step.
- // T_ij = sum_d ((1 / P(seq[d])) sum_t (f(i, t) T_ij E_i(seq[d][t]) b(i,
- // t + 1)))
- // E_ij = sum_d ((1 / P(seq[d])) sum_{t | seq[d][t] = j} f(i, t) b(i, t)
- // We store the new estimates in a different matrix.
- for (size_t t = 0; t < dataSeq[seq].n_cols; t++)
- {
- for (size_t j = 0; j < transition.n_cols; j++)
- {
- if (t < dataSeq[seq].n_cols - 1)
- {
- // Estimate of T_ij (probability of transition from state j to state
- // i). We postpone multiplication of the old T_ij until later.
- for (size_t i = 0; i < transition.n_rows; i++)
- newTransition(i, j) += forward(j, t) * backward(i, t + 1) *
- emission[i].Probability(dataSeq[seq].unsafe_col(t + 1)) /
- scales[t + 1];
- }
-
- // Add to list of emission observations, for Distribution::Estimate().
- emissionList.col(sumTime) = dataSeq[seq].col(t);
- emissionProb[j][sumTime] = stateProb(j, t);
- }
- sumTime++;
- }
- }
-
- // Assign the new transition matrix. We use %= (element-wise
- // multiplication) because every element of the new transition matrix must
- // still be multiplied by the old elements (this is the multiplication we
- // earlier postponed).
- transition %= newTransition;
-
- // Now we normalize the transition matrix.
- for (size_t i = 0; i < transition.n_cols; i++)
- transition.col(i) /= accu(transition.col(i));
-
- // Now estimate emission probabilities.
- for (size_t state = 0; state < transition.n_cols; state++)
- emission[state].Estimate(emissionList, emissionProb[state]);
-
- Log::Debug << "Iteration " << iter << ": log-likelihood " << loglik
- << std::endl;
-
- if (std::abs(oldLoglik - loglik) < tolerance)
- {
- Log::Debug << "Converged after " << iter << " iterations." << std::endl;
- break;
- }
-
- oldLoglik = loglik;
- }
-}
-
-/**
- * Train the model using the given labeled observations; the transition and
- * emission matrices are directly estimated.
- */
-template<typename Distribution>
-void HMM<Distribution>::Train(const std::vector<arma::mat>& dataSeq,
- const std::vector<arma::Col<size_t> >& stateSeq)
-{
- // Simple error checking.
- if (dataSeq.size() != stateSeq.size())
- {
- Log::Fatal << "HMM::Train(): number of data sequences (" << dataSeq.size()
- << ") not equal to number of state sequences (" << stateSeq.size()
- << ")." << std::endl;
- }
-
- transition.zeros();
-
- // Estimate the transition and emission matrices directly from the
- // observations. The emission list holds the time indices for observations
- // from each state.
- std::vector<std::vector<std::pair<size_t, size_t> > >
- emissionList(transition.n_cols);
- for (size_t seq = 0; seq < dataSeq.size(); seq++)
- {
- // Simple error checking.
- if (dataSeq[seq].n_cols != stateSeq[seq].n_elem)
- {
- Log::Fatal << "HMM::Train(): number of observations ("
- << dataSeq[seq].n_cols << ") in sequence " << seq
- << " not equal to number of states (" << stateSeq[seq].n_cols
- << ") in sequence " << seq << "." << std::endl;
- }
-
- if (dataSeq[seq].n_rows != dimensionality)
- {
- Log::Fatal << "HMM::Train(): data sequence " << seq << " has "
- << "dimensionality " << dataSeq[seq].n_rows << " (expected "
- << dimensionality << " dimensions)." << std::endl;
- }
-
- // Loop over each observation in the sequence. For estimation of the
- // transition matrix, we must ignore the last observation.
- for (size_t t = 0; t < dataSeq[seq].n_cols - 1; t++)
- {
- transition(stateSeq[seq][t + 1], stateSeq[seq][t])++;
- emissionList[stateSeq[seq][t]].push_back(std::make_pair(seq, t));
- }
-
- // Last observation.
- emissionList[stateSeq[seq][stateSeq[seq].n_elem - 1]].push_back(
- std::make_pair(seq, stateSeq[seq].n_elem - 1));
- }
-
- // Normalize transition matrix.
- for (size_t col = 0; col < transition.n_cols; col++)
- {
- // If the transition probability sum is greater than 0 in this column, the
- // emission probability sum will also be greater than 0. We want to avoid
- // division by 0.
- double sum = accu(transition.col(col));
- if (sum > 0)
- transition.col(col) /= sum;
- }
-
- // Estimate emission matrix.
- for (size_t state = 0; state < transition.n_cols; state++)
- {
- // Generate full sequence of observations for this state from the list of
- // emissions that are from this state.
- arma::mat emissions(dimensionality, emissionList[state].size());
- for (size_t i = 0; i < emissions.n_cols; i++)
- {
- emissions.col(i) = dataSeq[emissionList[state][i].first].col(
- emissionList[state][i].second);
- }
-
- emission[state].Estimate(emissions);
- }
-}
-
-/**
- * Estimate the probabilities of each hidden state at each time step for each
- * given data observation.
- */
-template<typename Distribution>
-double HMM<Distribution>::Estimate(const arma::mat& dataSeq,
- arma::mat& stateProb,
- arma::mat& forwardProb,
- arma::mat& backwardProb,
- arma::vec& scales) const
-{
- // First run the forward-backward algorithm.
- Forward(dataSeq, scales, forwardProb);
- Backward(dataSeq, scales, backwardProb);
-
- // Now assemble the state probability matrix based on the forward and backward
- // probabilities.
- stateProb = forwardProb % backwardProb;
-
- // Finally assemble the log-likelihood and return it.
- return accu(log(scales));
-}
-
-/**
- * Estimate the probabilities of each hidden state at each time step for each
- * given data observation.
- */
-template<typename Distribution>
-double HMM<Distribution>::Estimate(const arma::mat& dataSeq,
- arma::mat& stateProb) const
-{
- // We don't need to save these.
- arma::mat forwardProb, backwardProb;
- arma::vec scales;
-
- return Estimate(dataSeq, stateProb, forwardProb, backwardProb, scales);
-}
-
-/**
- * Generate a random data sequence of a given length. The data sequence is
- * stored in the dataSequence parameter, and the state sequence is stored in
- * the stateSequence parameter.
- */
-template<typename Distribution>
-void HMM<Distribution>::Generate(const size_t length,
- arma::mat& dataSequence,
- arma::Col<size_t>& stateSequence,
- const size_t startState) const
-{
- // Set vectors to the right size.
- stateSequence.set_size(length);
- dataSequence.set_size(dimensionality, length);
-
- // Set start state (default is 0).
- stateSequence[0] = startState;
-
- // Choose first emission state.
- double randValue = math::Random();
-
- // We just have to find where our random value sits in the probability
- // distribution of emissions for our starting state.
- dataSequence.col(0) = emission[startState].Random();
-
- // Now choose the states and emissions for the rest of the sequence.
- for (size_t t = 1; t < length; t++)
- {
- // First choose the hidden state.
- randValue = math::Random();
-
- // Now find where our random value sits in the probability distribution of
- // state changes.
- double probSum = 0;
- for (size_t st = 0; st < transition.n_rows; st++)
- {
- probSum += transition(st, stateSequence[t - 1]);
- if (randValue <= probSum)
- {
- stateSequence[t] = st;
- break;
- }
- }
-
- // Now choose the emission.
- dataSequence.col(t) = emission[stateSequence[t]].Random();
- }
-}
-
-/**
- * Compute the most probable hidden state sequence for the given observation
- * using the Viterbi algorithm. Returns the log-likelihood of the most likely
- * sequence.
- */
-template<typename Distribution>
-double HMM<Distribution>::Predict(const arma::mat& dataSeq,
- arma::Col<size_t>& stateSeq) const
-{
- // This is an implementation of the Viterbi algorithm for finding the most
- // probable sequence of states to produce the observed data sequence. We
- // don't use log-likelihoods to save that little bit of time, but we'll
- // calculate the log-likelihood at the end of it all.
- stateSeq.set_size(dataSeq.n_cols);
- arma::mat logStateProb(transition.n_rows, dataSeq.n_cols);
-
- // Store the logs of the transposed transition matrix. This is because we
- // will be using the rows of the transition matrix.
- arma::mat logTrans(log(trans(transition)));
-
- // The calculation of the first state is slightly different; the probability
- // of the first state being state j is the maximum probability that the state
- // came to be j from another state.
- logStateProb.col(0).zeros();
- for (size_t state = 0; state < transition.n_rows; state++)
- logStateProb[state] = log(transition(state, 0) *
- emission[state].Probability(dataSeq.unsafe_col(0)));
-
- // Store the best first state.
- arma::uword index;
- logStateProb.unsafe_col(0).max(index);
- stateSeq[0] = index;
-
- for (size_t t = 1; t < dataSeq.n_cols; t++)
- {
- // Assemble the state probability for this element.
- // Given that we are in state j, we state with the highest probability of
- // being the previous state.
- for (size_t j = 0; j < transition.n_rows; j++)
- {
- arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j);
- logStateProb(j, t) = prob.max() +
- log(emission[j].Probability(dataSeq.unsafe_col(t)));
- }
-
- // Store the best state.
- logStateProb.unsafe_col(t).max(index);
- stateSeq[t] = index;
- }
-
- return logStateProb(stateSeq(dataSeq.n_cols - 1), dataSeq.n_cols - 1);
-}
-
-/**
- * Compute the log-likelihood of the given data sequence.
- */
-template<typename Distribution>
-double HMM<Distribution>::LogLikelihood(const arma::mat& dataSeq) const
-{
- arma::mat forward;
- arma::vec scales;
-
- Forward(dataSeq, scales, forward);
-
- // The log-likelihood is the log of the scales for each time step.
- return accu(log(scales));
-}
-
-/**
- * The Forward procedure (part of the Forward-Backward algorithm).
- */
-template<typename Distribution>
-void HMM<Distribution>::Forward(const arma::mat& dataSeq,
- arma::vec& scales,
- arma::mat& forwardProb) const
-{
- // Our goal is to calculate the forward probabilities:
- // P(X_k | o_{1:k}) for all possible states X_k, for each time point k.
- forwardProb.zeros(transition.n_rows, dataSeq.n_cols);
- scales.zeros(dataSeq.n_cols);
-
- // Starting state (at t = -1) is assumed to be state 0. This is what MATLAB
- // does in their hmmdecode() function, so we will emulate that behavior.
- for (size_t state = 0; state < transition.n_rows; state++)
- forwardProb(state, 0) = transition(state, 0) *
- emission[state].Probability(dataSeq.unsafe_col(0));
-
- // Then normalize the column.
- scales[0] = accu(forwardProb.col(0));
- forwardProb.col(0) /= scales[0];
-
- // Now compute the probabilities for each successive observation.
- for (size_t t = 1; t < dataSeq.n_cols; t++)
- {
- for (size_t j = 0; j < transition.n_rows; j++)
- {
- // The forward probability of state j at time t is the sum over all states
- // of the probability of the previous state transitioning to the current
- // state and emitting the given observation.
- forwardProb(j, t) = accu(forwardProb.col(t - 1) %
- trans(transition.row(j))) *
- emission[j].Probability(dataSeq.unsafe_col(t));
- }
-
- // Normalize probability.
- scales[t] = accu(forwardProb.col(t));
- forwardProb.col(t) /= scales[t];
- }
-}
-
-template<typename Distribution>
-void HMM<Distribution>::Backward(const arma::mat& dataSeq,
- const arma::vec& scales,
- arma::mat& backwardProb) const
-{
- // Our goal is to calculate the backward probabilities:
- // P(X_k | o_{k + 1:T}) for all possible states X_k, for each time point k.
- backwardProb.zeros(transition.n_rows, dataSeq.n_cols);
-
- // The last element probability is 1.
- backwardProb.col(dataSeq.n_cols - 1).fill(1);
-
- // Now step backwards through all other observations.
- for (size_t t = dataSeq.n_cols - 2; t + 1 > 0; t--)
- {
- for (size_t j = 0; j < transition.n_rows; j++)
- {
- // The backward probability of state j at time t is the sum over all state
- // of the probability of the next state having been a transition from the
- // current state multiplied by the probability of each of those states
- // emitting the given observation.
- for (size_t state = 0; state < transition.n_rows; state++)
- backwardProb(j, t) += transition(state, j) * backwardProb(state, t + 1)
- * emission[state].Probability(dataSeq.unsafe_col(t + 1));
-
- // Normalize by the weights from the forward algorithm.
- backwardProb(j, t) /= scales[t + 1];
- }
- }
-}
-
-}; // namespace hmm
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,505 @@
+/**
+ * @file hmm_impl.hpp
+ * @author Ryan Curtin
+ * @author Tran Quoc Long
+ *
+ * Implementation of HMM class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_IMPL_HPP
+#define __MLPACK_METHODS_HMM_HMM_IMPL_HPP
+
+// Just in case...
+#include "hmm.hpp"
+
+namespace mlpack {
+namespace hmm {
+
+/**
+ * Create the Hidden Markov Model with the given number of hidden states and the
+ * given number of emission states.
+ */
+template<typename Distribution>
+HMM<Distribution>::HMM(const size_t states,
+ const Distribution emissions,
+ const double tolerance) :
+ transition(arma::ones<arma::mat>(states, states) / (double) states),
+ emission(states, /* default distribution */ emissions),
+ dimensionality(emissions.Dimensionality()),
+ tolerance(tolerance)
+{ /* nothing to do */ }
+
+/**
+ * Create the Hidden Markov Model with the given transition matrix and the given
+ * emission probability matrix.
+ */
+template<typename Distribution>
+HMM<Distribution>::HMM(const arma::mat& transition,
+ const std::vector<Distribution>& emission,
+ const double tolerance) :
+ transition(transition),
+ emission(emission),
+ tolerance(tolerance)
+{
+ // Set the dimensionality, if we can.
+ if (emission.size() > 0)
+ dimensionality = emission[0].Dimensionality();
+ else
+ {
+ Log::Warn << "HMM::HMM(): no emission distributions given; assuming a "
+ << "dimensionality of 0 and hoping it gets set right later."
+ << std::endl;
+ dimensionality = 0;
+ }
+}
+
+/**
+ * Train the model using the Baum-Welch algorithm, with only the given unlabeled
+ * observations. Each matrix in the vector of data sequences holds an
+ * individual data sequence; each point in each individual data sequence should
+ * be a column in the matrix. The number of rows in each matrix should be equal
+ * to the dimensionality of the HMM.
+ *
+ * It is preferable to use the other overload of Train(), with labeled data.
+ * That will produce much better results. However, if labeled data is
+ * unavailable, this will work. In addition, it is possible to use Train() with
+ * labeled data first, and then continue to train the model using this overload
+ * of Train() with unlabeled data.
+ *
+ * @param dataSeq Set of data sequences to train on.
+ */
+template<typename Distribution>
+void HMM<Distribution>::Train(const std::vector<arma::mat>& dataSeq)
+{
+ // We should allow a guess at the transition and emission matrices.
+ double loglik = 0;
+ double oldLoglik = 0;
+
+ // Maximum iterations?
+ size_t iterations = 1000;
+
+ // Find length of all sequences and ensure they are the correct size.
+ size_t totalLength = 0;
+ for (size_t seq = 0; seq < dataSeq.size(); seq++)
+ {
+ totalLength += dataSeq[seq].n_cols;
+
+ if (dataSeq[seq].n_rows != dimensionality)
+ Log::Fatal << "HMM::Train(): data sequence " << seq << " has "
+ << "dimensionality " << dataSeq[seq].n_rows << " (expected "
+ << dimensionality << " dimensions)." << std::endl;
+ }
+
+ // These are used later for training of each distribution. We initialize it
+ // all now so we don't have to do any allocation later on.
+ std::vector<arma::vec> emissionProb(transition.n_cols,
+ arma::vec(totalLength));
+ arma::mat emissionList(dimensionality, totalLength);
+
+ // This should be the Baum-Welch algorithm (EM for HMM estimation). This
+ // follows the procedure outlined in Elliot, Aggoun, and Moore's book "Hidden
+ // Markov Models: Estimation and Control", pp. 36-40.
+ for (size_t iter = 0; iter < iterations; iter++)
+ {
+ // Clear new transition matrix and emission probabilities.
+ arma::mat newTransition(transition.n_rows, transition.n_cols);
+ newTransition.zeros();
+
+ // Reset log likelihood.
+ loglik = 0;
+
+ // Sum over time.
+ size_t sumTime = 0;
+
+ // Loop over each sequence.
+ for (size_t seq = 0; seq < dataSeq.size(); seq++)
+ {
+ arma::mat stateProb;
+ arma::mat forward;
+ arma::mat backward;
+ arma::vec scales;
+
+ // Add the log-likelihood of this sequence. This is the E-step.
+ loglik += Estimate(dataSeq[seq], stateProb, forward, backward, scales);
+
+ // Now re-estimate the parameters. This is the M-step.
+ // T_ij = sum_d ((1 / P(seq[d])) sum_t (f(i, t) T_ij E_i(seq[d][t]) b(i,
+ // t + 1)))
+ // E_ij = sum_d ((1 / P(seq[d])) sum_{t | seq[d][t] = j} f(i, t) b(i, t)
+ // We store the new estimates in a different matrix.
+ for (size_t t = 0; t < dataSeq[seq].n_cols; t++)
+ {
+ for (size_t j = 0; j < transition.n_cols; j++)
+ {
+ if (t < dataSeq[seq].n_cols - 1)
+ {
+ // Estimate of T_ij (probability of transition from state j to state
+ // i). We postpone multiplication of the old T_ij until later.
+ for (size_t i = 0; i < transition.n_rows; i++)
+ newTransition(i, j) += forward(j, t) * backward(i, t + 1) *
+ emission[i].Probability(dataSeq[seq].unsafe_col(t + 1)) /
+ scales[t + 1];
+ }
+
+ // Add to list of emission observations, for Distribution::Estimate().
+ emissionList.col(sumTime) = dataSeq[seq].col(t);
+ emissionProb[j][sumTime] = stateProb(j, t);
+ }
+ sumTime++;
+ }
+ }
+
+ // Assign the new transition matrix. We use %= (element-wise
+ // multiplication) because every element of the new transition matrix must
+ // still be multiplied by the old elements (this is the multiplication we
+ // earlier postponed).
+ transition %= newTransition;
+
+ // Now we normalize the transition matrix.
+ for (size_t i = 0; i < transition.n_cols; i++)
+ transition.col(i) /= accu(transition.col(i));
+
+ // Now estimate emission probabilities.
+ for (size_t state = 0; state < transition.n_cols; state++)
+ emission[state].Estimate(emissionList, emissionProb[state]);
+
+ Log::Debug << "Iteration " << iter << ": log-likelihood " << loglik
+ << std::endl;
+
+ if (std::abs(oldLoglik - loglik) < tolerance)
+ {
+ Log::Debug << "Converged after " << iter << " iterations." << std::endl;
+ break;
+ }
+
+ oldLoglik = loglik;
+ }
+}
+
+/**
+ * Train the model using the given labeled observations; the transition and
+ * emission matrices are directly estimated.
+ */
+template<typename Distribution>
+void HMM<Distribution>::Train(const std::vector<arma::mat>& dataSeq,
+ const std::vector<arma::Col<size_t> >& stateSeq)
+{
+ // Simple error checking.
+ if (dataSeq.size() != stateSeq.size())
+ {
+ Log::Fatal << "HMM::Train(): number of data sequences (" << dataSeq.size()
+ << ") not equal to number of state sequences (" << stateSeq.size()
+ << ")." << std::endl;
+ }
+
+ transition.zeros();
+
+ // Estimate the transition and emission matrices directly from the
+ // observations. The emission list holds the time indices for observations
+ // from each state.
+ std::vector<std::vector<std::pair<size_t, size_t> > >
+ emissionList(transition.n_cols);
+ for (size_t seq = 0; seq < dataSeq.size(); seq++)
+ {
+ // Simple error checking.
+ if (dataSeq[seq].n_cols != stateSeq[seq].n_elem)
+ {
+ Log::Fatal << "HMM::Train(): number of observations ("
+ << dataSeq[seq].n_cols << ") in sequence " << seq
+ << " not equal to number of states (" << stateSeq[seq].n_cols
+ << ") in sequence " << seq << "." << std::endl;
+ }
+
+ if (dataSeq[seq].n_rows != dimensionality)
+ {
+ Log::Fatal << "HMM::Train(): data sequence " << seq << " has "
+ << "dimensionality " << dataSeq[seq].n_rows << " (expected "
+ << dimensionality << " dimensions)." << std::endl;
+ }
+
+ // Loop over each observation in the sequence. For estimation of the
+ // transition matrix, we must ignore the last observation.
+ for (size_t t = 0; t < dataSeq[seq].n_cols - 1; t++)
+ {
+ transition(stateSeq[seq][t + 1], stateSeq[seq][t])++;
+ emissionList[stateSeq[seq][t]].push_back(std::make_pair(seq, t));
+ }
+
+ // Last observation.
+ emissionList[stateSeq[seq][stateSeq[seq].n_elem - 1]].push_back(
+ std::make_pair(seq, stateSeq[seq].n_elem - 1));
+ }
+
+ // Normalize transition matrix.
+ for (size_t col = 0; col < transition.n_cols; col++)
+ {
+ // If the transition probability sum is greater than 0 in this column, the
+ // emission probability sum will also be greater than 0. We want to avoid
+ // division by 0.
+ double sum = accu(transition.col(col));
+ if (sum > 0)
+ transition.col(col) /= sum;
+ }
+
+ // Estimate emission matrix.
+ for (size_t state = 0; state < transition.n_cols; state++)
+ {
+ // Generate full sequence of observations for this state from the list of
+ // emissions that are from this state.
+ arma::mat emissions(dimensionality, emissionList[state].size());
+ for (size_t i = 0; i < emissions.n_cols; i++)
+ {
+ emissions.col(i) = dataSeq[emissionList[state][i].first].col(
+ emissionList[state][i].second);
+ }
+
+ emission[state].Estimate(emissions);
+ }
+}
+
+/**
+ * Estimate the probabilities of each hidden state at each time step for each
+ * given data observation.
+ */
+template<typename Distribution>
+double HMM<Distribution>::Estimate(const arma::mat& dataSeq,
+ arma::mat& stateProb,
+ arma::mat& forwardProb,
+ arma::mat& backwardProb,
+ arma::vec& scales) const
+{
+ // First run the forward-backward algorithm.
+ Forward(dataSeq, scales, forwardProb);
+ Backward(dataSeq, scales, backwardProb);
+
+ // Now assemble the state probability matrix based on the forward and backward
+ // probabilities.
+ stateProb = forwardProb % backwardProb;
+
+ // Finally assemble the log-likelihood and return it.
+ return accu(log(scales));
+}
+
+/**
+ * Estimate the probabilities of each hidden state at each time step for each
+ * given data observation.
+ */
+template<typename Distribution>
+double HMM<Distribution>::Estimate(const arma::mat& dataSeq,
+ arma::mat& stateProb) const
+{
+ // We don't need to save these.
+ arma::mat forwardProb, backwardProb;
+ arma::vec scales;
+
+ return Estimate(dataSeq, stateProb, forwardProb, backwardProb, scales);
+}
+
+/**
+ * Generate a random data sequence of a given length. The data sequence is
+ * stored in the dataSequence parameter, and the state sequence is stored in
+ * the stateSequence parameter.
+ */
+template<typename Distribution>
+void HMM<Distribution>::Generate(const size_t length,
+ arma::mat& dataSequence,
+ arma::Col<size_t>& stateSequence,
+ const size_t startState) const
+{
+ // Set vectors to the right size.
+ stateSequence.set_size(length);
+ dataSequence.set_size(dimensionality, length);
+
+ // Set start state (default is 0).
+ stateSequence[0] = startState;
+
+ // Choose first emission state.
+ double randValue = math::Random();
+
+ // We just have to find where our random value sits in the probability
+ // distribution of emissions for our starting state.
+ dataSequence.col(0) = emission[startState].Random();
+
+ // Now choose the states and emissions for the rest of the sequence.
+ for (size_t t = 1; t < length; t++)
+ {
+ // First choose the hidden state.
+ randValue = math::Random();
+
+ // Now find where our random value sits in the probability distribution of
+ // state changes.
+ double probSum = 0;
+ for (size_t st = 0; st < transition.n_rows; st++)
+ {
+ probSum += transition(st, stateSequence[t - 1]);
+ if (randValue <= probSum)
+ {
+ stateSequence[t] = st;
+ break;
+ }
+ }
+
+ // Now choose the emission.
+ dataSequence.col(t) = emission[stateSequence[t]].Random();
+ }
+}
+
+/**
+ * Compute the most probable hidden state sequence for the given observation
+ * using the Viterbi algorithm. Returns the log-likelihood of the most likely
+ * sequence.
+ */
+template<typename Distribution>
+double HMM<Distribution>::Predict(const arma::mat& dataSeq,
+ arma::Col<size_t>& stateSeq) const
+{
+ // This is an implementation of the Viterbi algorithm for finding the most
+ // probable sequence of states to produce the observed data sequence. We
+ // don't use log-likelihoods to save that little bit of time, but we'll
+ // calculate the log-likelihood at the end of it all.
+ stateSeq.set_size(dataSeq.n_cols);
+ arma::mat logStateProb(transition.n_rows, dataSeq.n_cols);
+
+ // Store the logs of the transposed transition matrix. This is because we
+ // will be using the rows of the transition matrix.
+ arma::mat logTrans(log(trans(transition)));
+
+  // The calculation of the first state is slightly different; the probability
+  // of the first state being state j is the probability of transitioning from
+  // the assumed initial state 0 to state j, times the emission probability.
+ logStateProb.col(0).zeros();
+ for (size_t state = 0; state < transition.n_rows; state++)
+ logStateProb[state] = log(transition(state, 0) *
+ emission[state].Probability(dataSeq.unsafe_col(0)));
+
+ // Store the best first state.
+ arma::uword index;
+ logStateProb.unsafe_col(0).max(index);
+ stateSeq[0] = index;
+
+ for (size_t t = 1; t < dataSeq.n_cols; t++)
+ {
+ // Assemble the state probability for this element.
+    // Given that we are in state j, we choose the previous state with the
+    // highest probability of leading to state j.
+ for (size_t j = 0; j < transition.n_rows; j++)
+ {
+ arma::vec prob = logStateProb.col(t - 1) + logTrans.col(j);
+ logStateProb(j, t) = prob.max() +
+ log(emission[j].Probability(dataSeq.unsafe_col(t)));
+ }
+
+ // Store the best state.
+ logStateProb.unsafe_col(t).max(index);
+ stateSeq[t] = index;
+ }
+
+ return logStateProb(stateSeq(dataSeq.n_cols - 1), dataSeq.n_cols - 1);
+}
+
+/**
+ * Compute the log-likelihood of the given data sequence.
+ */
+template<typename Distribution>
+double HMM<Distribution>::LogLikelihood(const arma::mat& dataSeq) const
+{
+ arma::mat forward;
+ arma::vec scales;
+
+ Forward(dataSeq, scales, forward);
+
+ // The log-likelihood is the log of the scales for each time step.
+ return accu(log(scales));
+}
+
+/**
+ * The Forward procedure (part of the Forward-Backward algorithm).
+ */
+template<typename Distribution>
+void HMM<Distribution>::Forward(const arma::mat& dataSeq,
+ arma::vec& scales,
+ arma::mat& forwardProb) const
+{
+ // Our goal is to calculate the forward probabilities:
+ // P(X_k | o_{1:k}) for all possible states X_k, for each time point k.
+ forwardProb.zeros(transition.n_rows, dataSeq.n_cols);
+ scales.zeros(dataSeq.n_cols);
+
+ // Starting state (at t = -1) is assumed to be state 0. This is what MATLAB
+ // does in their hmmdecode() function, so we will emulate that behavior.
+ for (size_t state = 0; state < transition.n_rows; state++)
+ forwardProb(state, 0) = transition(state, 0) *
+ emission[state].Probability(dataSeq.unsafe_col(0));
+
+ // Then normalize the column.
+ scales[0] = accu(forwardProb.col(0));
+ forwardProb.col(0) /= scales[0];
+
+ // Now compute the probabilities for each successive observation.
+ for (size_t t = 1; t < dataSeq.n_cols; t++)
+ {
+ for (size_t j = 0; j < transition.n_rows; j++)
+ {
+ // The forward probability of state j at time t is the sum over all states
+ // of the probability of the previous state transitioning to the current
+ // state and emitting the given observation.
+ forwardProb(j, t) = accu(forwardProb.col(t - 1) %
+ trans(transition.row(j))) *
+ emission[j].Probability(dataSeq.unsafe_col(t));
+ }
+
+ // Normalize probability.
+ scales[t] = accu(forwardProb.col(t));
+ forwardProb.col(t) /= scales[t];
+ }
+}
+
+template<typename Distribution>
+void HMM<Distribution>::Backward(const arma::mat& dataSeq,
+ const arma::vec& scales,
+ arma::mat& backwardProb) const
+{
+ // Our goal is to calculate the backward probabilities:
+ // P(X_k | o_{k + 1:T}) for all possible states X_k, for each time point k.
+ backwardProb.zeros(transition.n_rows, dataSeq.n_cols);
+
+ // The last element probability is 1.
+ backwardProb.col(dataSeq.n_cols - 1).fill(1);
+
+ // Now step backwards through all other observations.
+ for (size_t t = dataSeq.n_cols - 2; t + 1 > 0; t--)
+ {
+ for (size_t j = 0; j < transition.n_rows; j++)
+ {
+      // The backward probability of state j at time t is the sum over all states
+ // of the probability of the next state having been a transition from the
+ // current state multiplied by the probability of each of those states
+ // emitting the given observation.
+ for (size_t state = 0; state < transition.n_rows; state++)
+ backwardProb(j, t) += transition(state, j) * backwardProb(state, t + 1)
+ * emission[state].Probability(dataSeq.unsafe_col(t + 1));
+
+ // Normalize by the weights from the forward algorithm.
+ backwardProb(j, t) /= scales[t + 1];
+ }
+ }
+}
+
+}; // namespace hmm
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_loglik_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_loglik_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_loglik_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,115 +0,0 @@
-/**
- * @file hmm_loglik_main.cpp
- * @author Ryan Curtin
- *
- * Compute the log-likelihood of a given sequence for a given HMM.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "hmm.hpp"
-#include "hmm_util.hpp"
-
-#include <mlpack/methods/gmm/gmm.hpp>
-
-PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Log-Likelihood", "This "
- "utility takes an already-trained HMM (--model_file) and evaluates the "
- "log-likelihood of a given sequence of observations (--input_file). The "
- "computed log-likelihood is given directly to stdout.");
-
-PARAM_STRING_REQ("input_file", "File containing observations,", "i");
-PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
-
-using namespace mlpack;
-using namespace mlpack::hmm;
-using namespace mlpack::distribution;
-using namespace mlpack::util;
-using namespace mlpack::gmm;
-using namespace arma;
-using namespace std;
-
-int main(int argc, char** argv)
-{
- // Parse command line options.
- CLI::ParseCommandLine(argc, argv);
-
- // Load observations.
- const string inputFile = CLI::GetParam<string>("input_file");
- const string modelFile = CLI::GetParam<string>("model_file");
-
- mat dataSeq;
- data::Load(inputFile, dataSeq, true);
-
- // Load model, but first we have to determine its type.
- SaveRestoreUtility sr;
- sr.ReadFile(modelFile);
- string type;
- sr.LoadParameter(type, "hmm_type");
-
- double loglik = 0;
- if (type == "discrete")
- {
- HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
-
- LoadHMM(hmm, sr);
-
- // Verify only one row in observations.
- if (dataSeq.n_cols == 1)
- dataSeq = trans(dataSeq);
-
- if (dataSeq.n_rows > 1)
- Log::Fatal << "Only one-dimensional discrete observations allowed for "
- << "discrete HMMs!" << endl;
-
- loglik = hmm.LogLikelihood(dataSeq);
- }
- else if (type == "gaussian")
- {
- HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
-
- LoadHMM(hmm, sr);
-
- // Verify correct dimensionality.
- if (dataSeq.n_rows != hmm.Emission()[0].Mean().n_elem)
- Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
- << "does not match HMM Gaussian dimensionality ("
- << hmm.Emission()[0].Mean().n_elem << ")!" << endl;
-
- loglik = hmm.LogLikelihood(dataSeq);
- }
- else if (type == "gmm")
- {
- HMM<GMM<> > hmm(1, GMM<>(1, 1));
-
- LoadHMM(hmm, sr);
-
- // Verify correct dimensionality.
- if (dataSeq.n_rows != hmm.Emission()[0].Dimensionality())
- Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
- << "does not match HMM Gaussian dimensionality ("
- << hmm.Emission()[0].Dimensionality() << ")!" << endl;
-
- loglik = hmm.LogLikelihood(dataSeq);
- }
- else
- {
- Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
- << "'!" << endl;
- }
-
- cout << loglik << endl;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_loglik_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_loglik_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_loglik_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_loglik_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,115 @@
+/**
+ * @file hmm_loglik_main.cpp
+ * @author Ryan Curtin
+ *
+ * Compute the log-likelihood of a given sequence for a given HMM.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "hmm.hpp"
+#include "hmm_util.hpp"
+
+#include <mlpack/methods/gmm/gmm.hpp>
+
+PROGRAM_INFO("Hidden Markov Model (HMM) Sequence Log-Likelihood", "This "
+ "utility takes an already-trained HMM (--model_file) and evaluates the "
+ "log-likelihood of a given sequence of observations (--input_file). The "
+ "computed log-likelihood is given directly to stdout.");
+
+PARAM_STRING_REQ("input_file", "File containing observations,", "i");
+PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::util;
+using namespace mlpack::gmm;
+using namespace arma;
+using namespace std;
+
+int main(int argc, char** argv)
+{
+ // Parse command line options.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Load observations.
+ const string inputFile = CLI::GetParam<string>("input_file");
+ const string modelFile = CLI::GetParam<string>("model_file");
+
+ mat dataSeq;
+ data::Load(inputFile, dataSeq, true);
+
+ // Load model, but first we have to determine its type.
+ SaveRestoreUtility sr;
+ sr.ReadFile(modelFile);
+ string type;
+ sr.LoadParameter(type, "hmm_type");
+
+ double loglik = 0;
+ if (type == "discrete")
+ {
+ HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
+
+ LoadHMM(hmm, sr);
+
+ // Verify only one row in observations.
+ if (dataSeq.n_cols == 1)
+ dataSeq = trans(dataSeq);
+
+ if (dataSeq.n_rows > 1)
+ Log::Fatal << "Only one-dimensional discrete observations allowed for "
+ << "discrete HMMs!" << endl;
+
+ loglik = hmm.LogLikelihood(dataSeq);
+ }
+ else if (type == "gaussian")
+ {
+ HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
+
+ LoadHMM(hmm, sr);
+
+ // Verify correct dimensionality.
+ if (dataSeq.n_rows != hmm.Emission()[0].Mean().n_elem)
+ Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
+ << "does not match HMM Gaussian dimensionality ("
+ << hmm.Emission()[0].Mean().n_elem << ")!" << endl;
+
+ loglik = hmm.LogLikelihood(dataSeq);
+ }
+ else if (type == "gmm")
+ {
+ HMM<GMM<> > hmm(1, GMM<>(1, 1));
+
+ LoadHMM(hmm, sr);
+
+ // Verify correct dimensionality.
+ if (dataSeq.n_rows != hmm.Emission()[0].Dimensionality())
+ Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
+ << "does not match HMM Gaussian dimensionality ("
+ << hmm.Emission()[0].Dimensionality() << ")!" << endl;
+
+ loglik = hmm.LogLikelihood(dataSeq);
+ }
+ else
+ {
+ Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
+ << "'!" << endl;
+ }
+
+ cout << loglik << endl;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_train_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_train_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_train_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,378 +0,0 @@
-/**
- * @file hmm_train_main.cpp
- * @author Ryan Curtin
- *
- * Executable which trains an HMM and saves the trained HMM to file.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "hmm.hpp"
-#include "hmm_util.hpp"
-
-#include <mlpack/methods/gmm/gmm.hpp>
-
-PROGRAM_INFO("Hidden Markov Model (HMM) Training", "This program allows a "
- "Hidden Markov Model to be trained on labeled or unlabeled data. It "
- "support three types of HMMs: discrete HMMs, Gaussian HMMs, or GMM HMMs."
- "\n\n"
- "Either one input sequence can be specified (with --input_file), or, a "
- "file containing files in which input sequences can be found (when "
- "--input_file and --batch are used together). In addition, labels can be "
- "provided in the file specified by --label_file, and if --batch is used, "
- "the file given to --label_file should contain a list of files of labels "
- "corresponding to the sequences in the file given to --input_file."
- "\n\n"
- "The HMM is trained with the Baum-Welch algorithm if no labels are "
- "provided. The tolerance of the Baum-Welch algorithm can be set with the "
- "--tolerance option."
- "\n\n"
- "Optionally, a pre-created HMM model can be used as a guess for the "
- "transition matrix and emission probabilities; this is specifiable with "
- "--model_file.");
-
-PARAM_STRING_REQ("input_file", "File containing input observations.", "i");
-PARAM_STRING_REQ("type", "Type of HMM: discrete | gaussian | gmm.", "t");
-
-PARAM_FLAG("batch", "If true, input_file (and if passed, labels_file) are "
- "expected to contain a list of files to use as input observation sequences "
- " (and label sequences).", "b");
-PARAM_INT("states", "Number of hidden states in HMM (necessary, unless "
- "model_file is specified.", "n", 0);
-PARAM_INT("gaussians", "Number of gaussians in each GMM (necessary when type is"
- " 'gmm'.", "g", 0);
-PARAM_STRING("model_file", "Pre-existing HMM model (optional).", "m", "");
-PARAM_STRING("labels_file", "Optional file of hidden states, used for "
- "labeled training.", "l", "");
-PARAM_STRING("output_file", "File to save trained HMM to (XML).", "o",
- "output_hmm.xml");
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-PARAM_DOUBLE("tolerance", "Tolerance of the Baum-Welch algorithm.", "T", 1e-5);
-
-using namespace mlpack;
-using namespace mlpack::hmm;
-using namespace mlpack::distribution;
-using namespace mlpack::util;
-using namespace mlpack::gmm;
-using namespace mlpack::math;
-using namespace arma;
-using namespace std;
-
-int main(int argc, char** argv)
-{
- // Parse command line options.
- CLI::ParseCommandLine(argc, argv);
-
- // Set random seed.
- if (CLI::GetParam<int>("seed") != 0)
- RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- RandomSeed((size_t) time(NULL));
-
- // Validate parameters.
- const string inputFile = CLI::GetParam<string>("input_file");
- const string labelsFile = CLI::GetParam<string>("labels_file");
- const string modelFile = CLI::GetParam<string>("model_file");
- const string outputFile = CLI::GetParam<string>("output_file");
- const string type = CLI::GetParam<string>("type");
- const int states = CLI::GetParam<int>("states");
- const bool batch = CLI::HasParam("batch");
- const double tolerance = CLI::GetParam<double>("tolerance");
-
- // Validate number of states.
- if (states == 0 && modelFile == "")
- {
- Log::Fatal << "Must specify number of states if model file is not "
- << "specified!" << endl;
- }
-
- if (states < 0 && modelFile == "")
- {
- Log::Fatal << "Invalid number of states (" << states << "); must be greater"
- << " than or equal to 1." << endl;
- }
-
- // Load the dataset(s) and labels.
- vector<mat> trainSeq;
- vector<arma::Col<size_t> > labelSeq; // May be empty.
- if (batch)
- {
- // The input file contains a list of files to read.
- Log::Info << "Reading list of training sequences from '" << inputFile
- << "'." << endl;
-
- fstream f(inputFile.c_str(), ios_base::in);
-
- if (!f.is_open())
- Log::Fatal << "Could not open '" << inputFile << "' for reading." << endl;
-
- // Now read each line in.
- char lineBuf[1024]; // Max 1024 characters... hopefully that is long enough.
- f.getline(lineBuf, 1024, '\n');
- while (!f.eof())
- {
- Log::Info << "Adding training sequence from '" << lineBuf << "'." << endl;
-
- // Now read the matrix.
- trainSeq.push_back(mat());
- if (labelsFile == "") // Nonfatal in this case.
- {
- if (!data::Load(lineBuf, trainSeq.back(), false))
- {
- Log::Warn << "Loading training sequence from '" << lineBuf << "' "
- << "failed. Sequence ignored." << endl;
- trainSeq.pop_back(); // Remove last element which we did not use.
- }
- }
- else
- {
- data::Load(lineBuf, trainSeq.back(), true);
- }
-
- // See if we need to transpose the data.
- if (type == "discrete")
- {
- if (trainSeq.back().n_cols == 1)
- trainSeq.back() = trans(trainSeq.back());
- }
-
- f.getline(lineBuf, 1024, '\n');
- }
-
- f.close();
-
- // Now load labels, if we need to.
- if (labelsFile != "")
- {
- f.open(labelsFile.c_str(), ios_base::in);
-
- if (!f.is_open())
- Log::Fatal << "Could not open '" << labelsFile << "' for reading."
- << endl;
-
- // Now read each line in.
- f.getline(lineBuf, 1024, '\n');
- while (!f.eof())
- {
- Log::Info << "Adding training sequence labels from '" << lineBuf
- << "'." << endl;
-
- // Now read the matrix.
- Mat<size_t> label;
- data::Load(lineBuf, label, true); // Fatal on failure.
-
- // Ensure that matrix only has one column.
- if (label.n_rows == 1)
- label = trans(label);
-
- if (label.n_cols > 1)
- Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
-
- labelSeq.push_back(label.col(0));
-
- f.getline(lineBuf, 1024, '\n');
- }
- }
- }
- else
- {
- // Only one input file.
- trainSeq.resize(1);
- data::Load(inputFile.c_str(), trainSeq[0], true);
-
- // Do we need to load labels?
- if (labelsFile != "")
- {
- Mat<size_t> label;
- data::Load(labelsFile, label, true);
-
- // Ensure that matrix only has one column.
- if (label.n_rows == 1)
- label = trans(label);
-
- if (label.n_cols > 1)
- Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
-
- // Verify the same number of observations as the data.
- if (label.n_elem != trainSeq[labelSeq.size()].n_cols)
- Log::Fatal << "Label sequence " << labelSeq.size() << " does not have "
- << "the same number of points as observation sequence "
- << labelSeq.size() << "!" << endl;
-
- labelSeq.push_back(label.col(0));
- }
- }
-
- // Now, train the HMM, since we have loaded the input data.
- if (type == "discrete")
- {
- // Verify observations are valid.
- for (size_t i = 0; i < trainSeq.size(); ++i)
- if (trainSeq[i].n_rows > 1)
- Log::Fatal << "Error in training sequence " << i << ": only "
- << "one-dimensional discrete observations allowed for discrete "
- << "HMMs!" << endl;
-
- // Do we have a model to preload?
- HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1), tolerance);
-
- if (modelFile != "")
- {
- SaveRestoreUtility loader;
- loader.ReadFile(modelFile);
- LoadHMM(hmm, loader);
- }
- else // New model.
- {
- // Maximum observation is necessary so we know how to train the discrete
- // distribution.
- size_t maxEmission = 0;
- for (vector<mat>::iterator it = trainSeq.begin(); it != trainSeq.end();
- ++it)
- {
- size_t maxSeq = size_t(as_scalar(max(trainSeq[0], 1))) + 1;
- if (maxSeq > maxEmission)
- maxEmission = maxSeq;
- }
-
- Log::Info << maxEmission << " discrete observations in the input data."
- << endl;
-
- // Create HMM object.
- hmm = HMM<DiscreteDistribution>(size_t(states),
- DiscreteDistribution(maxEmission), tolerance);
- }
-
- // Do we have labels?
- if (labelsFile == "")
- hmm.Train(trainSeq); // Unsupervised training.
- else
- hmm.Train(trainSeq, labelSeq); // Supervised training.
-
- // Finally, save the model. This should later be integrated into the HMM
- // class itself.
- SaveRestoreUtility sr;
- SaveHMM(hmm, sr);
- sr.WriteFile(outputFile);
- }
- else if (type == "gaussian")
- {
- // Create HMM object.
- HMM<GaussianDistribution> hmm(1, GaussianDistribution(1), tolerance);
-
- // Do we have a model to load?
- size_t dimensionality = 0;
- if (modelFile != "")
- {
- SaveRestoreUtility loader;
- loader.ReadFile(modelFile);
- LoadHMM(hmm, loader);
-
- dimensionality = hmm.Emission()[0].Mean().n_elem;
- }
- else
- {
- // Find dimension of the data.
- dimensionality = trainSeq[0].n_rows;
-
- hmm = HMM<GaussianDistribution>(size_t(states),
- GaussianDistribution(dimensionality), tolerance);
- }
-
- // Verify dimensionality of data.
- for (size_t i = 0; i < trainSeq.size(); ++i)
- if (trainSeq[i].n_rows != dimensionality)
- Log::Fatal << "Observation sequence " << i << " dimensionality ("
- << trainSeq[i].n_rows << " is incorrect (should be "
- << dimensionality << ")!" << endl;
-
- // Now run the training.
- if (labelsFile == "")
- hmm.Train(trainSeq); // Unsupervised training.
- else
- hmm.Train(trainSeq, labelSeq); // Supervised training.
-
- // Finally, save the model. This should later be integrated into th HMM
- // class itself.
- SaveRestoreUtility sr;
- SaveHMM(hmm, sr);
- sr.WriteFile(outputFile);
- }
- else if (type == "gmm")
- {
- // Create HMM object.
- HMM<GMM<> > hmm(1, GMM<>(1, 1));
-
- // Do we have a model to load?
- size_t dimensionality = 0;
- if (modelFile != "")
- {
- SaveRestoreUtility loader;
- loader.ReadFile(modelFile);
- LoadHMM(hmm, loader);
-
- dimensionality = hmm.Emission()[0].Dimensionality();
- }
- else
- {
- // Find dimension of the data.
- dimensionality = trainSeq[0].n_rows;
-
- const int gaussians = CLI::GetParam<int>("gaussians");
-
- if (gaussians == 0)
- Log::Fatal << "Number of gaussians for each GMM must be specified (-g) "
- << "when type = 'gmm'!" << endl;
-
- if (gaussians < 0)
- Log::Fatal << "Invalid number of gaussians (" << gaussians << "); must "
- << "be greater than or equal to 1." << endl;
-
- hmm = HMM<GMM<> >(size_t(states), GMM<>(size_t(gaussians),
- dimensionality), tolerance);
- }
-
- // Verify dimensionality of data.
- for (size_t i = 0; i < trainSeq.size(); ++i)
- if (trainSeq[i].n_rows != dimensionality)
- Log::Fatal << "Observation sequence " << i << " dimensionality ("
- << trainSeq[i].n_rows << " is incorrect (should be "
- << dimensionality << ")!" << endl;
-
- // Now run the training.
- if (labelsFile == "")
- {
- Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "
- << "going to produce good results!" << endl;
- hmm.Train(trainSeq);
- }
- else
- {
- hmm.Train(trainSeq, labelSeq);
- }
-
- // Save model.
- SaveRestoreUtility sr;
- SaveHMM(hmm, sr);
- sr.WriteFile(outputFile);
- }
- else
- {
- Log::Fatal << "Unknown HMM type: " << type << "; must be 'discrete', "
- << "'gaussian', or 'gmm'." << endl;
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_train_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_train_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_train_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_train_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,378 @@
+/**
+ * @file hmm_train_main.cpp
+ * @author Ryan Curtin
+ *
+ * Executable which trains an HMM and saves the trained HMM to file.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "hmm.hpp"
+#include "hmm_util.hpp"
+
+#include <mlpack/methods/gmm/gmm.hpp>
+
+PROGRAM_INFO("Hidden Markov Model (HMM) Training", "This program allows a "
+ "Hidden Markov Model to be trained on labeled or unlabeled data. It "
+ "support three types of HMMs: discrete HMMs, Gaussian HMMs, or GMM HMMs."
+ "\n\n"
+ "Either one input sequence can be specified (with --input_file), or, a "
+ "file containing files in which input sequences can be found (when "
+ "--input_file and --batch are used together). In addition, labels can be "
+ "provided in the file specified by --label_file, and if --batch is used, "
+ "the file given to --label_file should contain a list of files of labels "
+ "corresponding to the sequences in the file given to --input_file."
+ "\n\n"
+ "The HMM is trained with the Baum-Welch algorithm if no labels are "
+ "provided. The tolerance of the Baum-Welch algorithm can be set with the "
+ "--tolerance option."
+ "\n\n"
+ "Optionally, a pre-created HMM model can be used as a guess for the "
+ "transition matrix and emission probabilities; this is specifiable with "
+ "--model_file.");
+
+PARAM_STRING_REQ("input_file", "File containing input observations.", "i");
+PARAM_STRING_REQ("type", "Type of HMM: discrete | gaussian | gmm.", "t");
+
+PARAM_FLAG("batch", "If true, input_file (and if passed, labels_file) are "
+ "expected to contain a list of files to use as input observation sequences "
+ " (and label sequences).", "b");
+PARAM_INT("states", "Number of hidden states in HMM (necessary, unless "
+ "model_file is specified.", "n", 0);
+PARAM_INT("gaussians", "Number of gaussians in each GMM (necessary when type is"
+ " 'gmm'.", "g", 0);
+PARAM_STRING("model_file", "Pre-existing HMM model (optional).", "m", "");
+PARAM_STRING("labels_file", "Optional file of hidden states, used for "
+ "labeled training.", "l", "");
+PARAM_STRING("output_file", "File to save trained HMM to (XML).", "o",
+ "output_hmm.xml");
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+PARAM_DOUBLE("tolerance", "Tolerance of the Baum-Welch algorithm.", "T", 1e-5);
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::util;
+using namespace mlpack::gmm;
+using namespace mlpack::math;
+using namespace arma;
+using namespace std;
+
+int main(int argc, char** argv)
+{
+ // Parse command line options.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Set random seed.
+ if (CLI::GetParam<int>("seed") != 0)
+ RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ RandomSeed((size_t) time(NULL));
+
+ // Validate parameters.
+ const string inputFile = CLI::GetParam<string>("input_file");
+ const string labelsFile = CLI::GetParam<string>("labels_file");
+ const string modelFile = CLI::GetParam<string>("model_file");
+ const string outputFile = CLI::GetParam<string>("output_file");
+ const string type = CLI::GetParam<string>("type");
+ const int states = CLI::GetParam<int>("states");
+ const bool batch = CLI::HasParam("batch");
+ const double tolerance = CLI::GetParam<double>("tolerance");
+
+ // Validate number of states.
+ if (states == 0 && modelFile == "")
+ {
+ Log::Fatal << "Must specify number of states if model file is not "
+ << "specified!" << endl;
+ }
+
+ if (states < 0 && modelFile == "")
+ {
+ Log::Fatal << "Invalid number of states (" << states << "); must be greater"
+ << " than or equal to 1." << endl;
+ }
+
+ // Load the dataset(s) and labels.
+ vector<mat> trainSeq;
+ vector<arma::Col<size_t> > labelSeq; // May be empty.
+ if (batch)
+ {
+ // The input file contains a list of files to read.
+ Log::Info << "Reading list of training sequences from '" << inputFile
+ << "'." << endl;
+
+ fstream f(inputFile.c_str(), ios_base::in);
+
+ if (!f.is_open())
+ Log::Fatal << "Could not open '" << inputFile << "' for reading." << endl;
+
+ // Now read each line in.
+ char lineBuf[1024]; // Max 1024 characters... hopefully that is long enough.
+ f.getline(lineBuf, 1024, '\n');
+ while (!f.eof())
+ {
+ Log::Info << "Adding training sequence from '" << lineBuf << "'." << endl;
+
+ // Now read the matrix.
+ trainSeq.push_back(mat());
+ if (labelsFile == "") // Nonfatal in this case.
+ {
+ if (!data::Load(lineBuf, trainSeq.back(), false))
+ {
+ Log::Warn << "Loading training sequence from '" << lineBuf << "' "
+ << "failed. Sequence ignored." << endl;
+ trainSeq.pop_back(); // Remove last element which we did not use.
+ }
+ }
+ else
+ {
+ data::Load(lineBuf, trainSeq.back(), true);
+ }
+
+ // See if we need to transpose the data.
+ if (type == "discrete")
+ {
+ if (trainSeq.back().n_cols == 1)
+ trainSeq.back() = trans(trainSeq.back());
+ }
+
+ f.getline(lineBuf, 1024, '\n');
+ }
+
+ f.close();
+
+ // Now load labels, if we need to.
+ if (labelsFile != "")
+ {
+ f.open(labelsFile.c_str(), ios_base::in);
+
+ if (!f.is_open())
+ Log::Fatal << "Could not open '" << labelsFile << "' for reading."
+ << endl;
+
+ // Now read each line in.
+ f.getline(lineBuf, 1024, '\n');
+ while (!f.eof())
+ {
+ Log::Info << "Adding training sequence labels from '" << lineBuf
+ << "'." << endl;
+
+ // Now read the matrix.
+ Mat<size_t> label;
+ data::Load(lineBuf, label, true); // Fatal on failure.
+
+ // Ensure that matrix only has one column.
+ if (label.n_rows == 1)
+ label = trans(label);
+
+ if (label.n_cols > 1)
+ Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
+
+ labelSeq.push_back(label.col(0));
+
+ f.getline(lineBuf, 1024, '\n');
+ }
+ }
+ }
+ else
+ {
+ // Only one input file.
+ trainSeq.resize(1);
+ data::Load(inputFile.c_str(), trainSeq[0], true);
+
+ // Do we need to load labels?
+ if (labelsFile != "")
+ {
+ Mat<size_t> label;
+ data::Load(labelsFile, label, true);
+
+ // Ensure that matrix only has one column.
+ if (label.n_rows == 1)
+ label = trans(label);
+
+ if (label.n_cols > 1)
+ Log::Fatal << "Invalid labels; must be one-dimensional." << endl;
+
+ // Verify the same number of observations as the data.
+ if (label.n_elem != trainSeq[labelSeq.size()].n_cols)
+ Log::Fatal << "Label sequence " << labelSeq.size() << " does not have "
+ << "the same number of points as observation sequence "
+ << labelSeq.size() << "!" << endl;
+
+ labelSeq.push_back(label.col(0));
+ }
+ }
+
+ // Now, train the HMM, since we have loaded the input data.
+ if (type == "discrete")
+ {
+ // Verify observations are valid.
+ for (size_t i = 0; i < trainSeq.size(); ++i)
+ if (trainSeq[i].n_rows > 1)
+ Log::Fatal << "Error in training sequence " << i << ": only "
+ << "one-dimensional discrete observations allowed for discrete "
+ << "HMMs!" << endl;
+
+ // Do we have a model to preload?
+ HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1), tolerance);
+
+ if (modelFile != "")
+ {
+ SaveRestoreUtility loader;
+ loader.ReadFile(modelFile);
+ LoadHMM(hmm, loader);
+ }
+ else // New model.
+ {
+ // Maximum observation is necessary so we know how to train the discrete
+ // distribution.
+ size_t maxEmission = 0;
+ for (vector<mat>::iterator it = trainSeq.begin(); it != trainSeq.end();
+ ++it)
+ {
+ size_t maxSeq = size_t(as_scalar(max(trainSeq[0], 1))) + 1;
+ if (maxSeq > maxEmission)
+ maxEmission = maxSeq;
+ }
+
+ Log::Info << maxEmission << " discrete observations in the input data."
+ << endl;
+
+ // Create HMM object.
+ hmm = HMM<DiscreteDistribution>(size_t(states),
+ DiscreteDistribution(maxEmission), tolerance);
+ }
+
+ // Do we have labels?
+ if (labelsFile == "")
+ hmm.Train(trainSeq); // Unsupervised training.
+ else
+ hmm.Train(trainSeq, labelSeq); // Supervised training.
+
+ // Finally, save the model. This should later be integrated into the HMM
+ // class itself.
+ SaveRestoreUtility sr;
+ SaveHMM(hmm, sr);
+ sr.WriteFile(outputFile);
+ }
+ else if (type == "gaussian")
+ {
+ // Create HMM object.
+ HMM<GaussianDistribution> hmm(1, GaussianDistribution(1), tolerance);
+
+ // Do we have a model to load?
+ size_t dimensionality = 0;
+ if (modelFile != "")
+ {
+ SaveRestoreUtility loader;
+ loader.ReadFile(modelFile);
+ LoadHMM(hmm, loader);
+
+ dimensionality = hmm.Emission()[0].Mean().n_elem;
+ }
+ else
+ {
+ // Find dimension of the data.
+ dimensionality = trainSeq[0].n_rows;
+
+ hmm = HMM<GaussianDistribution>(size_t(states),
+ GaussianDistribution(dimensionality), tolerance);
+ }
+
+ // Verify dimensionality of data.
+ for (size_t i = 0; i < trainSeq.size(); ++i)
+ if (trainSeq[i].n_rows != dimensionality)
+ Log::Fatal << "Observation sequence " << i << " dimensionality ("
+ << trainSeq[i].n_rows << " is incorrect (should be "
+ << dimensionality << ")!" << endl;
+
+ // Now run the training.
+ if (labelsFile == "")
+ hmm.Train(trainSeq); // Unsupervised training.
+ else
+ hmm.Train(trainSeq, labelSeq); // Supervised training.
+
+ // Finally, save the model. This should later be integrated into th HMM
+ // class itself.
+ SaveRestoreUtility sr;
+ SaveHMM(hmm, sr);
+ sr.WriteFile(outputFile);
+ }
+ else if (type == "gmm")
+ {
+ // Create HMM object.
+ HMM<GMM<> > hmm(1, GMM<>(1, 1));
+
+ // Do we have a model to load?
+ size_t dimensionality = 0;
+ if (modelFile != "")
+ {
+ SaveRestoreUtility loader;
+ loader.ReadFile(modelFile);
+ LoadHMM(hmm, loader);
+
+ dimensionality = hmm.Emission()[0].Dimensionality();
+ }
+ else
+ {
+ // Find dimension of the data.
+ dimensionality = trainSeq[0].n_rows;
+
+ const int gaussians = CLI::GetParam<int>("gaussians");
+
+ if (gaussians == 0)
+ Log::Fatal << "Number of gaussians for each GMM must be specified (-g) "
+ << "when type = 'gmm'!" << endl;
+
+ if (gaussians < 0)
+ Log::Fatal << "Invalid number of gaussians (" << gaussians << "); must "
+ << "be greater than or equal to 1." << endl;
+
+ hmm = HMM<GMM<> >(size_t(states), GMM<>(size_t(gaussians),
+ dimensionality), tolerance);
+ }
+
+ // Verify dimensionality of data.
+ for (size_t i = 0; i < trainSeq.size(); ++i)
+ if (trainSeq[i].n_rows != dimensionality)
+ Log::Fatal << "Observation sequence " << i << " dimensionality ("
+ << trainSeq[i].n_rows << " is incorrect (should be "
+ << dimensionality << ")!" << endl;
+
+ // Now run the training.
+ if (labelsFile == "")
+ {
+ Log::Warn << "Unlabeled training of GMM HMMs is almost certainly not "
+ << "going to produce good results!" << endl;
+ hmm.Train(trainSeq);
+ }
+ else
+ {
+ hmm.Train(trainSeq, labelSeq);
+ }
+
+ // Save model.
+ SaveRestoreUtility sr;
+ SaveHMM(hmm, sr);
+ sr.WriteFile(outputFile);
+ }
+ else
+ {
+ Log::Fatal << "Unknown HMM type: " << type << "; must be 'discrete', "
+ << "'gaussian', or 'gmm'." << endl;
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_util.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,57 +0,0 @@
-/**
- * @file hmm_util.hpp
- * @author Ryan Curtin
- *
- * Save/load utilities for HMMs. This should be eventually merged into the HMM
- * class itself.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_HPP
-#define __MLPACK_METHODS_HMM_HMM_UTIL_HPP
-
-#include "hmm.hpp"
-
-namespace mlpack {
-namespace hmm {
-
-/**
- * Save an HMM to file. This only works for GMMs, DiscreteDistributions, and
- * GaussianDistributions.
- *
- * @tparam Distribution Distribution type of HMM.
- * @param sr SaveRestoreUtility to use.
- */
-template<typename Distribution>
-void SaveHMM(const HMM<Distribution>& hmm, util::SaveRestoreUtility& sr);
-
-/**
- * Load an HMM from file. This only works for GMMs, DiscreteDistributions, and
- * GaussianDistributions.
- *
- * @tparam Distribution Distribution type of HMM.
- * @param sr SaveRestoreUtility to use.
- */
-template<typename Distribution>
-void LoadHMM(HMM<Distribution>& hmm, util::SaveRestoreUtility& sr);
-
-}; // namespace hmm
-}; // namespace mlpack
-
-// Include implementation.
-#include "hmm_util_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_util.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,57 @@
+/**
+ * @file hmm_util.hpp
+ * @author Ryan Curtin
+ *
+ * Save/load utilities for HMMs. This should be eventually merged into the HMM
+ * class itself.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_HPP
+
+#include "hmm.hpp"
+
+namespace mlpack {
+namespace hmm {
+
+/**
+ * Save an HMM to file. This only works for GMMs, DiscreteDistributions, and
+ * GaussianDistributions.
+ *
+ * @tparam Distribution Distribution type of HMM.
+ * @param sr SaveRestoreUtility to use.
+ */
+template<typename Distribution>
+void SaveHMM(const HMM<Distribution>& hmm, util::SaveRestoreUtility& sr);
+
+/**
+ * Load an HMM from file. This only works for GMMs, DiscreteDistributions, and
+ * GaussianDistributions.
+ *
+ * @tparam Distribution Distribution type of HMM.
+ * @param sr SaveRestoreUtility to use.
+ */
+template<typename Distribution>
+void LoadHMM(HMM<Distribution>& hmm, util::SaveRestoreUtility& sr);
+
+}; // namespace hmm
+}; // namespace mlpack
+
+// Include implementation.
+#include "hmm_util_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_util_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,256 +0,0 @@
-/**
- * @file hmm_util_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of HMM load/save functions.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
-#define __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "hmm_util.hpp"
-
-#include <mlpack/methods/gmm/gmm.hpp>
-
-namespace mlpack {
-namespace hmm {
-
-template<typename Distribution>
-void SaveHMM(const HMM<Distribution>& hmm, util::SaveRestoreUtility& sr)
-{
- Log::Fatal << "HMM save not implemented for arbitrary distributions."
- << std::endl;
-}
-
-template<>
-void SaveHMM(const HMM<distribution::DiscreteDistribution>& hmm,
- util::SaveRestoreUtility& sr)
-{
- std::string type = "discrete";
- size_t states = hmm.Transition().n_rows;
-
- sr.SaveParameter(type, "hmm_type");
- sr.SaveParameter(states, "hmm_states");
- sr.SaveParameter(hmm.Transition(), "hmm_transition");
-
- // Now the emissions.
- for (size_t i = 0; i < states; ++i)
- {
- // Generate name.
- std::stringstream s;
- s << "hmm_emission_distribution_" << i;
- sr.SaveParameter(hmm.Emission()[i].Probabilities(), s.str());
- }
-}
-
-template<>
-void SaveHMM(const HMM<distribution::GaussianDistribution>& hmm,
- util::SaveRestoreUtility& sr)
-{
- std::string type = "gaussian";
- size_t states = hmm.Transition().n_rows;
-
- sr.SaveParameter(type, "hmm_type");
- sr.SaveParameter(states, "hmm_states");
- sr.SaveParameter(hmm.Transition(), "hmm_transition");
-
- // Now the emissions.
- for (size_t i = 0; i < states; ++i)
- {
- // Generate name.
- std::stringstream s;
- s << "hmm_emission_mean_" << i;
- sr.SaveParameter(hmm.Emission()[i].Mean(), s.str());
-
- s.str("");
- s << "hmm_emission_covariance_" << i;
- sr.SaveParameter(hmm.Emission()[i].Covariance(), s.str());
- }
-}
-
-template<>
-void SaveHMM(const HMM<gmm::GMM<> >& hmm,
- util::SaveRestoreUtility& sr)
-{
- std::string type = "gmm";
- size_t states = hmm.Transition().n_rows;
-
- sr.SaveParameter(type, "hmm_type");
- sr.SaveParameter(states, "hmm_states");
- sr.SaveParameter(hmm.Transition(), "hmm_transition");
-
- // Now the emissions.
- for (size_t i = 0; i < states; ++i)
- {
- // Generate name.
- std::stringstream s;
- s << "hmm_emission_" << i << "_gaussians";
- sr.SaveParameter(hmm.Emission()[i].Gaussians(), s.str());
-
- s.str("");
- s << "hmm_emission_" << i << "_weights";
- sr.SaveParameter(hmm.Emission()[i].Weights(), s.str());
-
- for (size_t g = 0; g < hmm.Emission()[i].Gaussians(); ++g)
- {
- s.str("");
- s << "hmm_emission_" << i << "_gaussian_" << g << "_mean";
- sr.SaveParameter(hmm.Emission()[i].Means()[g], s.str());
-
- s.str("");
- s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
- sr.SaveParameter(hmm.Emission()[i].Covariances()[g], s.str());
- }
- }
-}
-
-template<typename Distribution>
-void LoadHMM(HMM<Distribution>& hmm, util::SaveRestoreUtility& sr)
-{
- Log::Fatal << "HMM load not implemented for arbitrary distributions."
- << std::endl;
-}
-
-template<>
-void LoadHMM(HMM<distribution::DiscreteDistribution>& hmm,
- util::SaveRestoreUtility& sr)
-{
- std::string type;
- size_t states;
-
- sr.LoadParameter(type, "hmm_type");
- if (type != "discrete")
- {
- Log::Fatal << "Cannot load non-discrete HMM (of type " << type << ") as "
- << "discrete HMM!" << std::endl;
- }
-
- sr.LoadParameter(states, "hmm_states");
-
- // Load transition matrix.
- sr.LoadParameter(hmm.Transition(), "hmm_transition");
-
- // Now each emission distribution.
- hmm.Emission().resize(states);
- for (size_t i = 0; i < states; ++i)
- {
- std::stringstream s;
- s << "hmm_emission_distribution_" << i;
- sr.LoadParameter(hmm.Emission()[i].Probabilities(), s.str());
- }
-
- hmm.Dimensionality() = 1;
-}
-
-template<>
-void LoadHMM(HMM<distribution::GaussianDistribution>& hmm,
- util::SaveRestoreUtility& sr)
-{
- std::string type;
- size_t states;
-
- sr.LoadParameter(type, "hmm_type");
- if (type != "gaussian")
- {
- Log::Fatal << "Cannot load non-Gaussian HMM (of type " << type << ") as "
- << "a Gaussian HMM!" << std::endl;
- }
-
- sr.LoadParameter(states, "hmm_states");
-
- // Load transition matrix.
- sr.LoadParameter(hmm.Transition(), "hmm_transition");
-
- // Now each emission distribution.
- hmm.Emission().resize(states);
- for (size_t i = 0; i < states; ++i)
- {
- std::stringstream s;
- s << "hmm_emission_mean_" << i;
- sr.LoadParameter(hmm.Emission()[i].Mean(), s.str());
-
- s.str("");
- s << "hmm_emission_covariance_" << i;
- sr.LoadParameter(hmm.Emission()[i].Covariance(), s.str());
- }
-
- hmm.Dimensionality() = hmm.Emission()[0].Mean().n_elem;
-}
-
-template<>
-void LoadHMM(HMM<gmm::GMM<> >& hmm,
- util::SaveRestoreUtility& sr)
-{
- std::string type;
- size_t states;
-
- sr.LoadParameter(type, "hmm_type");
- if (type != "gmm")
- {
- Log::Fatal << "Cannot load non-GMM HMM (of type " << type << ") as "
- << "a Gaussian Mixture Model HMM!" << std::endl;
- }
-
- sr.LoadParameter(states, "hmm_states");
-
- // Load transition matrix.
- sr.LoadParameter(hmm.Transition(), "hmm_transition");
-
- // Now each emission distribution.
- hmm.Emission().resize(states, gmm::GMM<>(1, 1));
- for (size_t i = 0; i < states; ++i)
- {
- std::stringstream s;
- s << "hmm_emission_" << i << "_gaussians";
- size_t gaussians;
- sr.LoadParameter(gaussians, s.str());
-
- s.str("");
- // Extract dimensionality.
- arma::vec meanzero;
- s << "hmm_emission_" << i << "_gaussian_0_mean";
- sr.LoadParameter(meanzero, s.str());
- size_t dimensionality = meanzero.n_elem;
-
- // Initialize GMM correctly.
- hmm.Emission()[i].Gaussians() = gaussians;
- hmm.Emission()[i].Dimensionality() = dimensionality;
-
- for (size_t g = 0; g < gaussians; ++g)
- {
- s.str("");
- s << "hmm_emission_" << i << "_gaussian_" << g << "_mean";
- sr.LoadParameter(hmm.Emission()[i].Means()[g], s.str());
-
- s.str("");
- s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
- sr.LoadParameter(hmm.Emission()[i].Covariances()[g], s.str());
- }
-
- s.str("");
- s << "hmm_emission_" << i << "_weights";
- sr.LoadParameter(hmm.Emission()[i].Weights(), s.str());
- }
-
- hmm.Dimensionality() = hmm.Emission()[0].Dimensionality();
-}
-
-}; // namespace hmm
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_util_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_util_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,256 @@
+/**
+ * @file hmm_util_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of HMM load/save functions.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+#define __MLPACK_METHODS_HMM_HMM_UTIL_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "hmm_util.hpp"
+
+#include <mlpack/methods/gmm/gmm.hpp>
+
+namespace mlpack {
+namespace hmm {
+
+template<typename Distribution>
+void SaveHMM(const HMM<Distribution>& hmm, util::SaveRestoreUtility& sr)
+{
+ Log::Fatal << "HMM save not implemented for arbitrary distributions."
+ << std::endl;
+}
+
+template<>
+void SaveHMM(const HMM<distribution::DiscreteDistribution>& hmm,
+ util::SaveRestoreUtility& sr)
+{
+ std::string type = "discrete";
+ size_t states = hmm.Transition().n_rows;
+
+ sr.SaveParameter(type, "hmm_type");
+ sr.SaveParameter(states, "hmm_states");
+ sr.SaveParameter(hmm.Transition(), "hmm_transition");
+
+ // Now the emissions.
+ for (size_t i = 0; i < states; ++i)
+ {
+ // Generate name.
+ std::stringstream s;
+ s << "hmm_emission_distribution_" << i;
+ sr.SaveParameter(hmm.Emission()[i].Probabilities(), s.str());
+ }
+}
+
+template<>
+void SaveHMM(const HMM<distribution::GaussianDistribution>& hmm,
+ util::SaveRestoreUtility& sr)
+{
+ std::string type = "gaussian";
+ size_t states = hmm.Transition().n_rows;
+
+ sr.SaveParameter(type, "hmm_type");
+ sr.SaveParameter(states, "hmm_states");
+ sr.SaveParameter(hmm.Transition(), "hmm_transition");
+
+ // Now the emissions.
+ for (size_t i = 0; i < states; ++i)
+ {
+ // Generate name.
+ std::stringstream s;
+ s << "hmm_emission_mean_" << i;
+ sr.SaveParameter(hmm.Emission()[i].Mean(), s.str());
+
+ s.str("");
+ s << "hmm_emission_covariance_" << i;
+ sr.SaveParameter(hmm.Emission()[i].Covariance(), s.str());
+ }
+}
+
+template<>
+void SaveHMM(const HMM<gmm::GMM<> >& hmm,
+ util::SaveRestoreUtility& sr)
+{
+ std::string type = "gmm";
+ size_t states = hmm.Transition().n_rows;
+
+ sr.SaveParameter(type, "hmm_type");
+ sr.SaveParameter(states, "hmm_states");
+ sr.SaveParameter(hmm.Transition(), "hmm_transition");
+
+ // Now the emissions.
+ for (size_t i = 0; i < states; ++i)
+ {
+ // Generate name.
+ std::stringstream s;
+ s << "hmm_emission_" << i << "_gaussians";
+ sr.SaveParameter(hmm.Emission()[i].Gaussians(), s.str());
+
+ s.str("");
+ s << "hmm_emission_" << i << "_weights";
+ sr.SaveParameter(hmm.Emission()[i].Weights(), s.str());
+
+ for (size_t g = 0; g < hmm.Emission()[i].Gaussians(); ++g)
+ {
+ s.str("");
+ s << "hmm_emission_" << i << "_gaussian_" << g << "_mean";
+ sr.SaveParameter(hmm.Emission()[i].Means()[g], s.str());
+
+ s.str("");
+ s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
+ sr.SaveParameter(hmm.Emission()[i].Covariances()[g], s.str());
+ }
+ }
+}
+
+template<typename Distribution>
+void LoadHMM(HMM<Distribution>& hmm, util::SaveRestoreUtility& sr)
+{
+ Log::Fatal << "HMM load not implemented for arbitrary distributions."
+ << std::endl;
+}
+
+template<>
+void LoadHMM(HMM<distribution::DiscreteDistribution>& hmm,
+ util::SaveRestoreUtility& sr)
+{
+ std::string type;
+ size_t states;
+
+ sr.LoadParameter(type, "hmm_type");
+ if (type != "discrete")
+ {
+ Log::Fatal << "Cannot load non-discrete HMM (of type " << type << ") as "
+ << "discrete HMM!" << std::endl;
+ }
+
+ sr.LoadParameter(states, "hmm_states");
+
+ // Load transition matrix.
+ sr.LoadParameter(hmm.Transition(), "hmm_transition");
+
+ // Now each emission distribution.
+ hmm.Emission().resize(states);
+ for (size_t i = 0; i < states; ++i)
+ {
+ std::stringstream s;
+ s << "hmm_emission_distribution_" << i;
+ sr.LoadParameter(hmm.Emission()[i].Probabilities(), s.str());
+ }
+
+ hmm.Dimensionality() = 1;
+}
+
+template<>
+void LoadHMM(HMM<distribution::GaussianDistribution>& hmm,
+ util::SaveRestoreUtility& sr)
+{
+ std::string type;
+ size_t states;
+
+ sr.LoadParameter(type, "hmm_type");
+ if (type != "gaussian")
+ {
+ Log::Fatal << "Cannot load non-Gaussian HMM (of type " << type << ") as "
+ << "a Gaussian HMM!" << std::endl;
+ }
+
+ sr.LoadParameter(states, "hmm_states");
+
+ // Load transition matrix.
+ sr.LoadParameter(hmm.Transition(), "hmm_transition");
+
+ // Now each emission distribution.
+ hmm.Emission().resize(states);
+ for (size_t i = 0; i < states; ++i)
+ {
+ std::stringstream s;
+ s << "hmm_emission_mean_" << i;
+ sr.LoadParameter(hmm.Emission()[i].Mean(), s.str());
+
+ s.str("");
+ s << "hmm_emission_covariance_" << i;
+ sr.LoadParameter(hmm.Emission()[i].Covariance(), s.str());
+ }
+
+ hmm.Dimensionality() = hmm.Emission()[0].Mean().n_elem;
+}
+
+template<>
+void LoadHMM(HMM<gmm::GMM<> >& hmm,
+ util::SaveRestoreUtility& sr)
+{
+ std::string type;
+ size_t states;
+
+ sr.LoadParameter(type, "hmm_type");
+ if (type != "gmm")
+ {
+ Log::Fatal << "Cannot load non-GMM HMM (of type " << type << ") as "
+ << "a Gaussian Mixture Model HMM!" << std::endl;
+ }
+
+ sr.LoadParameter(states, "hmm_states");
+
+ // Load transition matrix.
+ sr.LoadParameter(hmm.Transition(), "hmm_transition");
+
+ // Now each emission distribution.
+ hmm.Emission().resize(states, gmm::GMM<>(1, 1));
+ for (size_t i = 0; i < states; ++i)
+ {
+ std::stringstream s;
+ s << "hmm_emission_" << i << "_gaussians";
+ size_t gaussians;
+ sr.LoadParameter(gaussians, s.str());
+
+ s.str("");
+ // Extract dimensionality.
+ arma::vec meanzero;
+ s << "hmm_emission_" << i << "_gaussian_0_mean";
+ sr.LoadParameter(meanzero, s.str());
+ size_t dimensionality = meanzero.n_elem;
+
+ // Initialize GMM correctly.
+ hmm.Emission()[i].Gaussians() = gaussians;
+ hmm.Emission()[i].Dimensionality() = dimensionality;
+
+ for (size_t g = 0; g < gaussians; ++g)
+ {
+ s.str("");
+ s << "hmm_emission_" << i << "_gaussian_" << g << "_mean";
+ sr.LoadParameter(hmm.Emission()[i].Means()[g], s.str());
+
+ s.str("");
+ s << "hmm_emission_" << i << "_gaussian_" << g << "_covariance";
+ sr.LoadParameter(hmm.Emission()[i].Covariances()[g], s.str());
+ }
+
+ s.str("");
+ s << "hmm_emission_" << i << "_weights";
+ sr.LoadParameter(hmm.Emission()[i].Weights(), s.str());
+ }
+
+ hmm.Dimensionality() = hmm.Emission()[0].Dimensionality();
+}
+
+}; // namespace hmm
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_viterbi_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_viterbi_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_viterbi_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,121 +0,0 @@
-/**
- * @file hmm_viterbi_main.cpp
- * @author Ryan Curtin
- *
- * Compute the most probably hidden state sequence of a given observation
- * sequence for a given HMM.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "hmm.hpp"
-#include "hmm_util.hpp"
-
-#include <mlpack/methods/gmm/gmm.hpp>
-
-PROGRAM_INFO("Hidden Markov Model (HMM) Viterbi State Prediction", "This "
- "utility takes an already-trained HMM (--model_file) and evaluates the "
- "most probably hidden state sequence of a given sequence of observations "
- "(--input_file), using the Viterbi algorithm. The computed state sequence "
- "is saved to the specified output file (--output_file).");
-
-PARAM_STRING_REQ("input_file", "File containing observations,", "i");
-PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
-PARAM_STRING("output_file", "File to save predicted state sequence to.", "o",
- "output.csv");
-
-using namespace mlpack;
-using namespace mlpack::hmm;
-using namespace mlpack::distribution;
-using namespace mlpack::util;
-using namespace mlpack::gmm;
-using namespace arma;
-using namespace std;
-
-int main(int argc, char** argv)
-{
- // Parse command line options.
- CLI::ParseCommandLine(argc, argv);
-
- // Load observations.
- const string inputFile = CLI::GetParam<string>("input_file");
- const string modelFile = CLI::GetParam<string>("model_file");
-
- mat dataSeq;
- data::Load(inputFile, dataSeq, true);
-
- // Load model, but first we have to determine its type.
- SaveRestoreUtility sr;
- sr.ReadFile(modelFile);
- string type;
- sr.LoadParameter(type, "hmm_type");
-
- arma::Col<size_t> sequence;
- if (type == "discrete")
- {
- HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
-
- LoadHMM(hmm, sr);
-
- // Verify only one row in observations.
- if (dataSeq.n_cols == 1)
- dataSeq = trans(dataSeq);
-
- if (dataSeq.n_rows > 1)
- Log::Fatal << "Only one-dimensional discrete observations allowed for "
- << "discrete HMMs!" << endl;
-
- hmm.Predict(dataSeq, sequence);
- }
- else if (type == "gaussian")
- {
- HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
-
- LoadHMM(hmm, sr);
-
- // Verify correct dimensionality.
- if (dataSeq.n_rows != hmm.Emission()[0].Mean().n_elem)
- Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
- << "does not match HMM Gaussian dimensionality ("
- << hmm.Emission()[0].Mean().n_elem << ")!" << endl;
-
- hmm.Predict(dataSeq, sequence);
- }
- else if (type == "gmm")
- {
- HMM<GMM<> > hmm(1, GMM<>(1, 1));
-
- LoadHMM(hmm, sr);
-
- // Verify correct dimensionality.
- if (dataSeq.n_rows != hmm.Emission()[0].Dimensionality())
- Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
- << "does not match HMM Gaussian dimensionality ("
- << hmm.Emission()[0].Dimensionality() << ")!" << endl;
-
- hmm.Predict(dataSeq, sequence);
- }
- else
- {
- Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
- << "'!" << endl;
- }
-
- // Save output.
- const string outputFile = CLI::GetParam<string>("output_file");
- data::Save(outputFile, sequence, true);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_viterbi_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/hmm/hmm_viterbi_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_viterbi_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/hmm/hmm_viterbi_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,121 @@
+/**
+ * @file hmm_viterbi_main.cpp
+ * @author Ryan Curtin
+ *
+ * Compute the most probably hidden state sequence of a given observation
+ * sequence for a given HMM.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "hmm.hpp"
+#include "hmm_util.hpp"
+
+#include <mlpack/methods/gmm/gmm.hpp>
+
+PROGRAM_INFO("Hidden Markov Model (HMM) Viterbi State Prediction", "This "
+ "utility takes an already-trained HMM (--model_file) and evaluates the "
+ "most probably hidden state sequence of a given sequence of observations "
+ "(--input_file), using the Viterbi algorithm. The computed state sequence "
+ "is saved to the specified output file (--output_file).");
+
+PARAM_STRING_REQ("input_file", "File containing observations,", "i");
+PARAM_STRING_REQ("model_file", "File containing HMM (XML).", "m");
+PARAM_STRING("output_file", "File to save predicted state sequence to.", "o",
+ "output.csv");
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::util;
+using namespace mlpack::gmm;
+using namespace arma;
+using namespace std;
+
+int main(int argc, char** argv)
+{
+ // Parse command line options.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Load observations.
+ const string inputFile = CLI::GetParam<string>("input_file");
+ const string modelFile = CLI::GetParam<string>("model_file");
+
+ mat dataSeq;
+ data::Load(inputFile, dataSeq, true);
+
+ // Load model, but first we have to determine its type.
+ SaveRestoreUtility sr;
+ sr.ReadFile(modelFile);
+ string type;
+ sr.LoadParameter(type, "hmm_type");
+
+ arma::Col<size_t> sequence;
+ if (type == "discrete")
+ {
+ HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
+
+ LoadHMM(hmm, sr);
+
+ // Verify only one row in observations.
+ if (dataSeq.n_cols == 1)
+ dataSeq = trans(dataSeq);
+
+ if (dataSeq.n_rows > 1)
+ Log::Fatal << "Only one-dimensional discrete observations allowed for "
+ << "discrete HMMs!" << endl;
+
+ hmm.Predict(dataSeq, sequence);
+ }
+ else if (type == "gaussian")
+ {
+ HMM<GaussianDistribution> hmm(1, GaussianDistribution(1));
+
+ LoadHMM(hmm, sr);
+
+ // Verify correct dimensionality.
+ if (dataSeq.n_rows != hmm.Emission()[0].Mean().n_elem)
+ Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
+ << "does not match HMM Gaussian dimensionality ("
+ << hmm.Emission()[0].Mean().n_elem << ")!" << endl;
+
+ hmm.Predict(dataSeq, sequence);
+ }
+ else if (type == "gmm")
+ {
+ HMM<GMM<> > hmm(1, GMM<>(1, 1));
+
+ LoadHMM(hmm, sr);
+
+ // Verify correct dimensionality.
+ if (dataSeq.n_rows != hmm.Emission()[0].Dimensionality())
+ Log::Fatal << "Observation dimensionality (" << dataSeq.n_rows << ") "
+ << "does not match HMM Gaussian dimensionality ("
+ << hmm.Emission()[0].Dimensionality() << ")!" << endl;
+
+ hmm.Predict(dataSeq, sequence);
+ }
+ else
+ {
+ Log::Fatal << "Unknown HMM type '" << type << "' in file '" << modelFile
+ << "'!" << endl;
+ }
+
+ // Save output.
+ const string outputFile = CLI::GetParam<string>("output_file");
+ data::Save(outputFile, sequence, true);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kernel_pca/kernel_pca.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,102 +0,0 @@
-/**
- * @file kernel_pca.hpp
- * @author Ajinkya Kale
- *
- * Defines the KernelPCA class to perform Kernel Principal Components Analysis
- * on the specified data set.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_HPP
-#define __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
-
-namespace mlpack {
-namespace kpca {
-
-template <typename KernelType>
-class KernelPCA
-{
- public:
- KernelPCA(const KernelType kernel = KernelType(),
- const bool scaleData = false);
-
- /**
- * Apply Kernel Principal Component Analysis to the provided data set.
- *
- * @param data - Data matrix
- * @param transformedData - Data with PCA applied
- * @param eigVal - contains eigen values in a column vector
- * @param coeff - PCA Loadings/Coeffs/EigenVectors
- */
- void Apply(const arma::mat& data,
- arma::mat& transformedData,
- arma::vec& eigVal,
- arma::mat& coeff);
-
- /**
- * Apply Kernel Principal Component Analysis to the provided data set.
- *
- * @param data - Data matrix
- * @param transformedData - Data with PCA applied
- * @param eigVal - contains eigen values in a column vector
- */
- void Apply(const arma::mat& data,
- arma::mat& transformedData,
- arma::vec& eigVal);
-
- /**
- * Apply Dimensionality Reduction using Kernel Principal Component Analysis
- * to the provided data set.
- *
- * @param data - M x N Data matrix
- * @param newDimension - matrix consisting of N column vectors,
- * where each vector is the projection of the corresponding data vector
- * from data matrix onto the basis vectors contained in the columns of
- * coeff/eigen vector matrix with only newDimension number of columns chosen.
- */
- void Apply(arma::mat& data, const size_t newDimension);
-
- //! Get the kernel.
- const KernelType& Kernel() const { return kernel; }
- //! Modify the kernel.
- KernelType& Kernel() { return kernel; }
-
- //! Return whether or not this KernelPCA object will scale (by standard
- //! deviation) the data when kernel PCA is performed.
- bool ScaleData() const { return scaleData; }
- //! Modify whether or not this KernelPCA object will scale (by standard
- //! deviation) the data when kernel PCA is performed.
- bool& ScaleData() { return scaleData; }
-
- private:
- //! The instantiated kernel.
- KernelType kernel;
- //! If true, the data will be scaled (by standard deviation) when Apply() is
- //! run.
- bool scaleData;
-
-}; // class KernelPCA
-
-}; // namespace kpca
-}; // namespace mlpack
-
-// Include implementation.
-#include "kernel_pca_impl.hpp"
-
-#endif // __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kernel_pca/kernel_pca.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,102 @@
+/**
+ * @file kernel_pca.hpp
+ * @author Ajinkya Kale
+ *
+ * Defines the KernelPCA class to perform Kernel Principal Components Analysis
+ * on the specified data set.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_HPP
+#define __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/kernels/linear_kernel.hpp>
+
+namespace mlpack {
+namespace kpca {
+
+template <typename KernelType>
+class KernelPCA
+{
+ public:
+ KernelPCA(const KernelType kernel = KernelType(),
+ const bool scaleData = false);
+
+ /**
+ * Apply Kernel Principal Component Analysis to the provided data set.
+ *
+ * @param data - Data matrix
+ * @param transformedData - Data with PCA applied
+ * @param eigVal - contains eigen values in a column vector
+ * @param coeff - PCA Loadings/Coeffs/EigenVectors
+ */
+ void Apply(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigVal,
+ arma::mat& coeff);
+
+ /**
+ * Apply Kernel Principal Component Analysis to the provided data set.
+ *
+ * @param data - Data matrix
+ * @param transformedData - Data with PCA applied
+ * @param eigVal - contains eigen values in a column vector
+ */
+ void Apply(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigVal);
+
+ /**
+ * Apply Dimensionality Reduction using Kernel Principal Component Analysis
+ * to the provided data set.
+ *
+ * @param data - M x N Data matrix
+ * @param newDimension - matrix consisting of N column vectors,
+ * where each vector is the projection of the corresponding data vector
+ * from data matrix onto the basis vectors contained in the columns of
+ * coeff/eigen vector matrix with only newDimension number of columns chosen.
+ */
+ void Apply(arma::mat& data, const size_t newDimension);
+
+ //! Get the kernel.
+ const KernelType& Kernel() const { return kernel; }
+ //! Modify the kernel.
+ KernelType& Kernel() { return kernel; }
+
+ //! Return whether or not this KernelPCA object will scale (by standard
+ //! deviation) the data when kernel PCA is performed.
+ bool ScaleData() const { return scaleData; }
+ //! Modify whether or not this KernelPCA object will scale (by standard
+ //! deviation) the data when kernel PCA is performed.
+ bool& ScaleData() { return scaleData; }
+
+ private:
+ //! The instantiated kernel.
+ KernelType kernel;
+ //! If true, the data will be scaled (by standard deviation) when Apply() is
+ //! run.
+ bool scaleData;
+
+}; // class KernelPCA
+
+}; // namespace kpca
+}; // namespace mlpack
+
+// Include implementation.
+#include "kernel_pca_impl.hpp"
+
+#endif // __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,144 +0,0 @@
-/**
- * @file kernel_pca_impl.hpp
- * @author Ajinkya Kale
- *
- * Implementation of KernelPCA class to perform Kernel Principal Components
- * Analysis on the specified data set.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_IMPL_HPP
-#define __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "kernel_pca.hpp"
-
-#include <iostream>
-
-namespace mlpack {
-namespace kpca {
-
-template <typename KernelType>
-arma::mat GetKernelMatrix(KernelType kernel, arma::mat transData);
-
-template <typename KernelType>
-KernelPCA<KernelType>::KernelPCA(const KernelType kernel,
- const bool scaleData) :
- kernel(kernel),
- scaleData(scaleData)
-{ }
-
-/**
- * Apply Kernel Principal Component Analysis to the provided data set.
- *
- * @param data - Data matrix
- * @param transformedData - Data with KernelPCA applied
- * @param eigVal - contains eigen values in a column vector
- * @param coeff - KernelPCA Loadings/Coeffs/EigenVectors
- */
-template <typename KernelType>
-void KernelPCA<KernelType>::Apply(const arma::mat& data,
- arma::mat& transformedData,
- arma::vec& eigVal,
- arma::mat& coeffs)
-{
- arma::mat transData = ccov(data);
-
- // Center the data if necessary.
-
- // Scale the data if necessary.
- if (scaleData)
- {
- transData = transData / (arma::ones<arma::colvec>(transData.n_rows) *
- stddev(transData, 0, 0));
- }
-
- arma::mat centeredData = trans(transData);
- arma::mat kernelMat = GetKernelMatrix(kernel, centeredData);
- arma::eig_sym(eigVal, coeffs, kernelMat);
-
- int n_eigVal = eigVal.n_elem;
- for(int i = 0; i < floor(n_eigVal / 2.0); i++)
- eigVal.swap_rows(i, (n_eigVal - 1) - i);
-
- coeffs = arma::fliplr(coeffs);
-
- transformedData = trans(coeffs) * data;
- arma::colvec transformedDataMean = arma::mean(transformedData, 1);
- transformedData = transformedData - (transformedDataMean *
- arma::ones<arma::rowvec>(transformedData.n_cols));
-}
-
-/**
- * Apply Kernel Principal Component Analysis to the provided data set.
- *
- * @param data - Data matrix
- * @param transformedData - Data with KernelPCA applied
- * @param eigVal - contains eigen values in a column vector
- */
-template <typename KernelType>
-void KernelPCA<KernelType>::Apply(const arma::mat& data,
- arma::mat& transformedData,
- arma::vec& eigVal)
-{
- arma::mat coeffs;
- Apply(data, transformedData, eigVal, coeffs);
-}
-
-/**
- * Apply Dimensionality Reduction using Kernel Principal Component Analysis
- * to the provided data set.
- *
- * @param data - M x N Data matrix
- * @param newDimension - matrix consisting of N column vectors,
- * where each vector is the projection of the corresponding data vector
- * from data matrix onto the basis vectors contained in the columns of
- * coeff/eigen vector matrix with only newDimension number of columns chosen.
- */
-template <typename KernelType>
-void KernelPCA<KernelType>::Apply(arma::mat& data, const size_t newDimension)
-{
- arma::mat coeffs;
- arma::vec eigVal;
-
- Apply(data, data, eigVal, coeffs);
-
- if (newDimension < coeffs.n_rows && newDimension > 0)
- data.shed_rows(newDimension, data.n_rows - 1);
-}
-
-template <typename KernelType>
-arma::mat GetKernelMatrix(KernelType kernel, arma::mat transData)
-{
- arma::mat kernelMat(transData.n_rows, transData.n_rows);
-
- for (size_t i = 0; i < transData.n_rows; i++)
- {
- for (size_t j = 0; j < transData.n_rows; j++)
- {
- arma::vec v1 = trans(transData.row(i));
- arma::vec v2 = trans(transData.row(j));
- kernelMat(i, j) = kernel.Evaluate(v1, v2);
- }
- }
-
- return kernelMat;
-}
-
-}; // namespace mlpack
-}; // namespace kpca
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,144 @@
+/**
+ * @file kernel_pca_impl.hpp
+ * @author Ajinkya Kale
+ *
+ * Implementation of KernelPCA class to perform Kernel Principal Components
+ * Analysis on the specified data set.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_IMPL_HPP
+#define __MLPACK_METHODS_KERNEL_PCA_KERNEL_PCA_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "kernel_pca.hpp"
+
+#include <iostream>
+
+namespace mlpack {
+namespace kpca {
+
+template <typename KernelType>
+arma::mat GetKernelMatrix(KernelType kernel, arma::mat transData);
+
+template <typename KernelType>
+KernelPCA<KernelType>::KernelPCA(const KernelType kernel,
+ const bool scaleData) :
+ kernel(kernel),
+ scaleData(scaleData)
+{ }
+
+/**
+ * Apply Kernel Principal Component Analysis to the provided data set.
+ *
+ * @param data - Data matrix
+ * @param transformedData - Data with KernelPCA applied
+ * @param eigVal - contains eigen values in a column vector
+ * @param coeff - KernelPCA Loadings/Coeffs/EigenVectors
+ */
+template <typename KernelType>
+void KernelPCA<KernelType>::Apply(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigVal,
+ arma::mat& coeffs)
+{
+ arma::mat transData = ccov(data);
+
+ // Center the data if necessary.
+
+ // Scale the data if necessary.
+ if (scaleData)
+ {
+ transData = transData / (arma::ones<arma::colvec>(transData.n_rows) *
+ stddev(transData, 0, 0));
+ }
+
+ arma::mat centeredData = trans(transData);
+ arma::mat kernelMat = GetKernelMatrix(kernel, centeredData);
+ arma::eig_sym(eigVal, coeffs, kernelMat);
+
+ int n_eigVal = eigVal.n_elem;
+ for(int i = 0; i < floor(n_eigVal / 2.0); i++)
+ eigVal.swap_rows(i, (n_eigVal - 1) - i);
+
+ coeffs = arma::fliplr(coeffs);
+
+ transformedData = trans(coeffs) * data;
+ arma::colvec transformedDataMean = arma::mean(transformedData, 1);
+ transformedData = transformedData - (transformedDataMean *
+ arma::ones<arma::rowvec>(transformedData.n_cols));
+}
+
+/**
+ * Apply Kernel Principal Component Analysis to the provided data set.
+ *
+ * @param data - Data matrix
+ * @param transformedData - Data with KernelPCA applied
+ * @param eigVal - contains eigen values in a column vector
+ */
+template <typename KernelType>
+void KernelPCA<KernelType>::Apply(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigVal)
+{
+ arma::mat coeffs;
+ Apply(data, transformedData, eigVal, coeffs);
+}
+
+/**
+ * Apply Dimensionality Reduction using Kernel Principal Component Analysis
+ * to the provided data set.
+ *
+ * @param data - M x N Data matrix
+ * @param newDimension - matrix consisting of N column vectors,
+ * where each vector is the projection of the corresponding data vector
+ * from data matrix onto the basis vectors contained in the columns of
+ * coeff/eigen vector matrix with only newDimension number of columns chosen.
+ */
+template <typename KernelType>
+void KernelPCA<KernelType>::Apply(arma::mat& data, const size_t newDimension)
+{
+ arma::mat coeffs;
+ arma::vec eigVal;
+
+ Apply(data, data, eigVal, coeffs);
+
+ if (newDimension < coeffs.n_rows && newDimension > 0)
+ data.shed_rows(newDimension, data.n_rows - 1);
+}
+
+template <typename KernelType>
+arma::mat GetKernelMatrix(KernelType kernel, arma::mat transData)
+{
+ arma::mat kernelMat(transData.n_rows, transData.n_rows);
+
+ for (size_t i = 0; i < transData.n_rows; i++)
+ {
+ for (size_t j = 0; j < transData.n_rows; j++)
+ {
+ arma::vec v1 = trans(transData.row(i));
+ arma::vec v2 = trans(transData.row(j));
+ kernelMat(i, j) = kernel.Evaluate(v1, v2);
+ }
+ }
+
+ return kernelMat;
+}
+
+}; // namespace mlpack
+}; // namespace kpca
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kernel_pca/kernel_pca_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,176 +0,0 @@
-/**
- * @file kernel_pca_main.cpp
- * @author Ajinkya Kale <kaleajinkya at gmail.com>
- *
- * Executable for Kernel PCA.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
-#include <mlpack/core/kernels/gaussian_kernel.hpp>
-#include <mlpack/core/kernels/hyperbolic_tangent_kernel.hpp>
-#include <mlpack/core/kernels/laplacian_kernel.hpp>
-#include <mlpack/core/kernels/polynomial_kernel.hpp>
-#include <mlpack/core/kernels/cosine_distance.hpp>
-
-#include "kernel_pca.hpp"
-
-using namespace mlpack;
-using namespace mlpack::kpca;
-using namespace mlpack::kernel;
-using namespace std;
-using namespace arma;
-
-PROGRAM_INFO("Kernel Principal Components Analysis",
- "This program performs Kernel Principal Components Analysis (KPCA) on the "
- "specified dataset with the specified kernel. This will transform the "
- "data onto the kernel principal components, and optionally reduce the "
- "dimensionality by ignoring the kernel principal components with the "
- "smallest eigenvalues."
- "\n\n"
- "For the case where a linear kernel is used, this reduces to regular "
- "PCA."
- "\n\n"
- "The kernels that are supported are listed below:"
- "\n\n"
- " * 'linear': the standard linear dot product (same as normal PCA):\n"
- " K(x, y) = x^T y\n"
- "\n"
- " * 'gaussian': a Gaussian kernel; requires bandwidth:\n"
- " K(x, y) = exp(-(|| x - y || ^ 2) / (2 * (bandwidth ^ 2)))\n"
- "\n"
- " * 'polynomial': polynomial kernel; requires offset and degree:\n"
- " K(x, y) = (x^T y + offset) ^ degree\n"
- "\n"
- " * 'hyptan': hyperbolic tangent kernel; requires scale and offset:\n"
- " K(x, y) = tanh(scale * (x^T y) + offset)\n"
- "\n"
- " * 'laplacian': Laplacian kernel; requires bandwidth:\n"
- " K(x, y) = exp(-(|| x - y ||) / bandwidth)\n"
- "\n"
- " * 'cosine': cosine distance:\n"
- " K(x, y) = 1 - (x^T y) / (|| x || * || y ||)\n"
- "\n"
- "The parameters for each of the kernels should be specified with the "
- "options --bandwidth, --kernel_scale, --offset, or --degree (or a "
- "combination of those options).\n");
-
-PARAM_STRING_REQ("input_file", "Input dataset to perform KPCA on.", "i");
-PARAM_STRING_REQ("output_file", "File to save modified dataset to.", "o");
-PARAM_STRING_REQ("kernel", "The kernel to use; see the above documentation for "
- "the list of usable kernels.", "k");
-
-PARAM_INT("new_dimensionality", "If not 0, reduce the dimensionality of "
- "the output dataset by ignoring the dimensions with the smallest "
- "eigenvalues.", "d", 0);
-
-PARAM_FLAG("scale", "If set, the data will be scaled before performing KPCA "
- "such that the variance of each feature is 1.", "s");
-
-PARAM_DOUBLE("kernel_scale", "Scale, for 'hyptan' kernel.", "S", 1.0);
-PARAM_DOUBLE("offset", "Offset, for 'hyptan' and 'polynomial' kernels.", "O",
- 0.0);
-PARAM_DOUBLE("bandwidth", "Bandwidth, for 'gaussian' and 'laplacian' kernels.",
- "b", 1.0);
-PARAM_DOUBLE("degree", "Degree of polynomial, for 'polynomial' kernel.", "d",
- 1.0);
-
-int main(int argc, char** argv)
-{
- // Parse command line options.
- CLI::ParseCommandLine(argc, argv);
-
- // Load input dataset.
- mat dataset;
- const string inputFile = CLI::GetParam<string>("input_file");
- data::Load(inputFile, dataset, true); // Fatal on failure.
-
- // Get the new dimensionality, if it is necessary.
- size_t newDim = dataset.n_rows;
- if (CLI::GetParam<int>("new_dimensionality") != 0)
- {
- newDim = CLI::GetParam<int>("new_dimensionality");
-
- if (newDim > dataset.n_rows)
- {
- Log::Fatal << "New dimensionality (" << newDim
- << ") cannot be greater than existing dimensionality ("
- << dataset.n_rows << ")!" << endl;
- }
- }
-
- // Get the kernel type and make sure it is valid.
- const string kernelType = CLI::GetParam<string>("kernel");
-
- const bool scaleData = CLI::HasParam("scale");
-
- if (kernelType == "linear")
- {
- KernelPCA<LinearKernel> kpca(LinearKernel(), scaleData);
- kpca.Apply(dataset, newDim);
- }
- else if (kernelType == "gaussian")
- {
- const double bandwidth = CLI::GetParam<double>("bandwidth");
-
- GaussianKernel kernel(bandwidth);
- KernelPCA<GaussianKernel> kpca(kernel, scaleData);
- kpca.Apply(dataset, newDim);
- }
- else if (kernelType == "polynomial")
- {
- const double degree = CLI::GetParam<double>("degree");
- const double offset = CLI::GetParam<double>("offset");
-
- PolynomialKernel kernel(degree, offset);
- KernelPCA<PolynomialKernel> kpca(kernel, scaleData);
- kpca.Apply(dataset, newDim);
- }
- else if (kernelType == "hyptan")
- {
- const double scale = CLI::GetParam<double>("kernel_scale");
- const double offset = CLI::GetParam<double>("offset");
-
- HyperbolicTangentKernel kernel(scale, offset);
- KernelPCA<HyperbolicTangentKernel> kpca(kernel, scaleData);
- kpca.Apply(dataset, newDim);
- }
- else if (kernelType == "laplacian")
- {
- const double bandwidth = CLI::GetParam<double>("bandwidth");
-
- LaplacianKernel kernel(bandwidth);
- KernelPCA<LaplacianKernel> kpca(kernel, scaleData);
- kpca.Apply(dataset, newDim);
- }
- else if (kernelType == "cosine")
- {
- KernelPCA<CosineDistance> kpca(CosineDistance(), scaleData);
- kpca.Apply(dataset, newDim);
- }
- else
- {
- // Invalid kernel type.
- Log::Fatal << "Invalid kernel type ('" << kernelType << "'); valid choices "
- << "are 'linear', 'gaussian', 'polynomial', 'hyptan', 'laplacian', and "
- << "'cosine'." << endl;
- }
-
- // Save the output dataset.
- const string outputFile = CLI::GetParam<string>("output_file");
- data::Save(outputFile, dataset, true); // Fatal on failure.
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kernel_pca/kernel_pca_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kernel_pca/kernel_pca_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,176 @@
+/**
+ * @file kernel_pca_main.cpp
+ * @author Ajinkya Kale <kaleajinkya at gmail.com>
+ *
+ * Executable for Kernel PCA.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/kernels/linear_kernel.hpp>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <mlpack/core/kernels/hyperbolic_tangent_kernel.hpp>
+#include <mlpack/core/kernels/laplacian_kernel.hpp>
+#include <mlpack/core/kernels/polynomial_kernel.hpp>
+#include <mlpack/core/kernels/cosine_distance.hpp>
+
+#include "kernel_pca.hpp"
+
+using namespace mlpack;
+using namespace mlpack::kpca;
+using namespace mlpack::kernel;
+using namespace std;
+using namespace arma;
+
+PROGRAM_INFO("Kernel Principal Components Analysis",
+ "This program performs Kernel Principal Components Analysis (KPCA) on the "
+ "specified dataset with the specified kernel. This will transform the "
+ "data onto the kernel principal components, and optionally reduce the "
+ "dimensionality by ignoring the kernel principal components with the "
+ "smallest eigenvalues."
+ "\n\n"
+ "For the case where a linear kernel is used, this reduces to regular "
+ "PCA."
+ "\n\n"
+ "The kernels that are supported are listed below:"
+ "\n\n"
+ " * 'linear': the standard linear dot product (same as normal PCA):\n"
+ " K(x, y) = x^T y\n"
+ "\n"
+ " * 'gaussian': a Gaussian kernel; requires bandwidth:\n"
+ " K(x, y) = exp(-(|| x - y || ^ 2) / (2 * (bandwidth ^ 2)))\n"
+ "\n"
+ " * 'polynomial': polynomial kernel; requires offset and degree:\n"
+ " K(x, y) = (x^T y + offset) ^ degree\n"
+ "\n"
+ " * 'hyptan': hyperbolic tangent kernel; requires scale and offset:\n"
+ " K(x, y) = tanh(scale * (x^T y) + offset)\n"
+ "\n"
+ " * 'laplacian': Laplacian kernel; requires bandwidth:\n"
+ " K(x, y) = exp(-(|| x - y ||) / bandwidth)\n"
+ "\n"
+ " * 'cosine': cosine distance:\n"
+ " K(x, y) = 1 - (x^T y) / (|| x || * || y ||)\n"
+ "\n"
+ "The parameters for each of the kernels should be specified with the "
+ "options --bandwidth, --kernel_scale, --offset, or --degree (or a "
+ "combination of those options).\n");
+
+PARAM_STRING_REQ("input_file", "Input dataset to perform KPCA on.", "i");
+PARAM_STRING_REQ("output_file", "File to save modified dataset to.", "o");
+PARAM_STRING_REQ("kernel", "The kernel to use; see the above documentation for "
+ "the list of usable kernels.", "k");
+
+PARAM_INT("new_dimensionality", "If not 0, reduce the dimensionality of "
+ "the output dataset by ignoring the dimensions with the smallest "
+ "eigenvalues.", "d", 0);
+
+PARAM_FLAG("scale", "If set, the data will be scaled before performing KPCA "
+ "such that the variance of each feature is 1.", "s");
+
+PARAM_DOUBLE("kernel_scale", "Scale, for 'hyptan' kernel.", "S", 1.0);
+PARAM_DOUBLE("offset", "Offset, for 'hyptan' and 'polynomial' kernels.", "O",
+ 0.0);
+PARAM_DOUBLE("bandwidth", "Bandwidth, for 'gaussian' and 'laplacian' kernels.",
+ "b", 1.0);
+PARAM_DOUBLE("degree", "Degree of polynomial, for 'polynomial' kernel.", "d",
+ 1.0);
+
+int main(int argc, char** argv)
+{
+ // Parse command line options.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Load input dataset.
+ mat dataset;
+ const string inputFile = CLI::GetParam<string>("input_file");
+ data::Load(inputFile, dataset, true); // Fatal on failure.
+
+ // Get the new dimensionality, if it is necessary.
+ size_t newDim = dataset.n_rows;
+ if (CLI::GetParam<int>("new_dimensionality") != 0)
+ {
+ newDim = CLI::GetParam<int>("new_dimensionality");
+
+ if (newDim > dataset.n_rows)
+ {
+ Log::Fatal << "New dimensionality (" << newDim
+ << ") cannot be greater than existing dimensionality ("
+ << dataset.n_rows << ")!" << endl;
+ }
+ }
+
+ // Get the kernel type and make sure it is valid.
+ const string kernelType = CLI::GetParam<string>("kernel");
+
+ const bool scaleData = CLI::HasParam("scale");
+
+ if (kernelType == "linear")
+ {
+ KernelPCA<LinearKernel> kpca(LinearKernel(), scaleData);
+ kpca.Apply(dataset, newDim);
+ }
+ else if (kernelType == "gaussian")
+ {
+ const double bandwidth = CLI::GetParam<double>("bandwidth");
+
+ GaussianKernel kernel(bandwidth);
+ KernelPCA<GaussianKernel> kpca(kernel, scaleData);
+ kpca.Apply(dataset, newDim);
+ }
+ else if (kernelType == "polynomial")
+ {
+ const double degree = CLI::GetParam<double>("degree");
+ const double offset = CLI::GetParam<double>("offset");
+
+ PolynomialKernel kernel(degree, offset);
+ KernelPCA<PolynomialKernel> kpca(kernel, scaleData);
+ kpca.Apply(dataset, newDim);
+ }
+ else if (kernelType == "hyptan")
+ {
+ const double scale = CLI::GetParam<double>("kernel_scale");
+ const double offset = CLI::GetParam<double>("offset");
+
+ HyperbolicTangentKernel kernel(scale, offset);
+ KernelPCA<HyperbolicTangentKernel> kpca(kernel, scaleData);
+ kpca.Apply(dataset, newDim);
+ }
+ else if (kernelType == "laplacian")
+ {
+ const double bandwidth = CLI::GetParam<double>("bandwidth");
+
+ LaplacianKernel kernel(bandwidth);
+ KernelPCA<LaplacianKernel> kpca(kernel, scaleData);
+ kpca.Apply(dataset, newDim);
+ }
+ else if (kernelType == "cosine")
+ {
+ KernelPCA<CosineDistance> kpca(CosineDistance(), scaleData);
+ kpca.Apply(dataset, newDim);
+ }
+ else
+ {
+ // Invalid kernel type.
+ Log::Fatal << "Invalid kernel type ('" << kernelType << "'); valid choices "
+ << "are 'linear', 'gaussian', 'polynomial', 'hyptan', 'laplacian', and "
+ << "'cosine'." << endl;
+ }
+
+ // Save the output dataset.
+ const string outputFile = CLI::GetParam<string>("output_file");
+ data::Save(outputFile, dataset, true); // Fatal on failure.
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/allow_empty_clusters.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/allow_empty_clusters.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/allow_empty_clusters.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,69 +0,0 @@
-/**
- * @file allow_empty_clusters.hpp
- * @author Ryan Curtin
- *
- * This very simple policy is used when K-Means is allowed to return empty
- * clusters.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_KMEANS_ALLOW_EMPTY_CLUSTERS_HPP
-#define __MLPACK_METHODS_KMEANS_ALLOW_EMPTY_CLUSTERS_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kmeans {
-
-/**
- * Policy which allows K-Means to create empty clusters without any error being
- * reported.
- */
-class AllowEmptyClusters
-{
- public:
- //! Default constructor required by EmptyClusterPolicy policy.
- AllowEmptyClusters() { }
-
- /**
- * This function does nothing. It is called by K-Means when K-Means detects
- * an empty cluster.
- *
- * @tparam MatType Type of data (arma::mat or arma::spmat).
- * @param data Dataset on which clustering is being performed.
- * @param emptyCluster Index of cluster which is empty.
- * @param centroids Centroids of each cluster (one per column).
- * @param clusterCounts Number of points in each cluster.
- * @param assignments Cluster assignments of each point.
- *
- * @return Number of points changed (0).
- */
- template<typename MatType>
- static size_t EmptyCluster(const MatType& /* data */,
- const size_t /* emptyCluster */,
- const MatType& /* centroids */,
- arma::Col<size_t>& /* clusterCounts */,
- arma::Col<size_t>& /* assignments */)
- {
- // Empty clusters are okay! Do nothing.
- return 0;
- }
-};
-
-}; // namespace kmeans
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/allow_empty_clusters.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/allow_empty_clusters.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/allow_empty_clusters.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/allow_empty_clusters.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,69 @@
+/**
+ * @file allow_empty_clusters.hpp
+ * @author Ryan Curtin
+ *
+ * This very simple policy is used when K-Means is allowed to return empty
+ * clusters.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_ALLOW_EMPTY_CLUSTERS_HPP
+#define __MLPACK_METHODS_KMEANS_ALLOW_EMPTY_CLUSTERS_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * Policy which allows K-Means to create empty clusters without any error being
+ * reported.
+ */
+class AllowEmptyClusters
+{
+ public:
+ //! Default constructor required by EmptyClusterPolicy policy.
+ AllowEmptyClusters() { }
+
+ /**
+ * This function does nothing. It is called by K-Means when K-Means detects
+ * an empty cluster.
+ *
+ * @tparam MatType Type of data (arma::mat or arma::spmat).
+ * @param data Dataset on which clustering is being performed.
+ * @param emptyCluster Index of cluster which is empty.
+ * @param centroids Centroids of each cluster (one per column).
+ * @param clusterCounts Number of points in each cluster.
+ * @param assignments Cluster assignments of each point.
+ *
+ * @return Number of points changed (0).
+ */
+ template<typename MatType>
+ static size_t EmptyCluster(const MatType& /* data */,
+ const size_t /* emptyCluster */,
+ const MatType& /* centroids */,
+ arma::Col<size_t>& /* clusterCounts */,
+ arma::Col<size_t>& /* assignments */)
+ {
+ // Empty clusters are okay! Do nothing.
+ return 0;
+ }
+};
+
+}; // namespace kmeans
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/kmeans.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,215 +0,0 @@
-/**
- * @file kmeans.hpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * K-Means clustering.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_KMEANS_KMEANS_HPP
-#define __MLPACK_METHODS_KMEANS_KMEANS_HPP
-
-#include <mlpack/core.hpp>
-
-#include <mlpack/core/metrics/lmetric.hpp>
-#include "random_partition.hpp"
-#include "max_variance_new_cluster.hpp"
-
-#include <mlpack/core/tree/binary_space_tree.hpp>
-
-namespace mlpack {
-namespace kmeans /** K-Means clustering. */ {
-
-/**
- * This class implements K-Means clustering. This implementation supports
- * overclustering, which means that more clusters than are requested will be
- * found; then, those clusters will be merged together to produce the desired
- * number of clusters.
- *
- * Two template parameters can (optionally) be supplied: the policy for how to
- * find the initial partition of the data, and the actions to be taken when an
- * empty cluster is encountered, as well as the distance metric to be used.
- *
- * A simple example of how to run K-Means clustering is shown below.
- *
- * @code
- * extern arma::mat data; // Dataset we want to run K-Means on.
- * arma::Col<size_t> assignments; // Cluster assignments.
- *
- * KMeans<> k; // Default options.
- * k.Cluster(data, 3, assignments); // 3 clusters.
- *
- * // Cluster using the Manhattan distance, 100 iterations maximum, and an
- * // overclustering factor of 4.0.
- * KMeans<metric::ManhattanDistance> k(100, 4.0);
- * k.Cluster(data, 6, assignments); // 6 clusters.
- * @endcode
- *
- * @tparam MetricType The distance metric to use for this KMeans; see
- * metric::LMetric for an example.
- * @tparam InitialPartitionPolicy Initial partitioning policy; must implement a
- * default constructor and 'void Cluster(const arma::mat&, const size_t,
- * arma::Col<size_t>&)'.
- * @tparam EmptyClusterPolicy Policy for what to do on an empty cluster; must
- * implement a default constructor and 'void EmptyCluster(const arma::mat&,
- * arma::Col<size_t&)'.
- *
- * @see RandomPartition, RefinedStart, AllowEmptyClusters, MaxVarianceNewCluster
- */
-template<typename MetricType = metric::SquaredEuclideanDistance,
- typename InitialPartitionPolicy = RandomPartition,
- typename EmptyClusterPolicy = MaxVarianceNewCluster>
-class KMeans
-{
- public:
- /**
- * Create a K-Means object and (optionally) set the parameters which K-Means
- * will be run with. This implementation allows a few strategies to improve
- * the performance of K-Means, including "overclustering" and disallowing
- * empty clusters.
- *
- * The overclustering factor controls how many clusters are
- * actually found; for instance, with an overclustering factor of 4, if
- * K-Means is run to find 3 clusters, it will actually find 12, then merge the
- * nearest clusters until only 3 are left.
- *
- * @param maxIterations Maximum number of iterations allowed before giving up
- * (0 is valid, but the algorithm may never terminate).
- * @param overclusteringFactor Factor controlling how many extra clusters are
- * found and then merged to get the desired number of clusters.
- * @param metric Optional MetricType object; for when the metric has state
- * it needs to store.
- * @param partitioner Optional InitialPartitionPolicy object; for when a
- * specially initialized partitioning policy is required.
- * @param emptyClusterAction Optional EmptyClusterPolicy object; for when a
- * specially initialized empty cluster policy is required.
- */
- KMeans(const size_t maxIterations = 1000,
- const double overclusteringFactor = 1.0,
- const MetricType metric = MetricType(),
- const InitialPartitionPolicy partitioner = InitialPartitionPolicy(),
- const EmptyClusterPolicy emptyClusterAction = EmptyClusterPolicy());
-
-
- /**
- * Perform k-means clustering on the data, returning a list of cluster
- * assignments. Optionally, the vector of assignments can be set to an
- * initial guess of the cluster assignments; to do this, set initialGuess to
- * true.
- *
- * @tparam MatType Type of matrix (arma::mat or arma::sp_mat).
- * @param data Dataset to cluster.
- * @param clusters Number of clusters to compute.
- * @param assignments Vector to store cluster assignments in.
- * @param initialGuess If true, then it is assumed that assignments has a list
- * of initial cluster assignments.
- */
- template<typename MatType>
- void Cluster(const MatType& data,
- const size_t clusters,
- arma::Col<size_t>& assignments,
- const bool initialGuess = false) const;
-
- /**
- * Perform k-means clustering on the data, returning a list of cluster
- * assignments and also the centroids of each cluster. Optionally, the vector
- * of assignments can be set to an initial guess of the cluster assignments;
- * to do this, set initialAssignmentGuess to true. Another way to set initial
- * cluster guesses is to fill the centroids matrix with the centroid guesses,
- * and then set initialCentroidGuess to true. initialAssignmentGuess
- * supersedes initialCentroidGuess, so if both are set to true, the
- * assignments vector is used.
- *
- * Note that if the overclustering factor is greater than 1, the centroids
- * matrix will be resized in the method. Regardless of the overclustering
- * factor, the centroid guess matrix (if initialCentroidGuess is set to true)
- * should have the same number of rows as the data matrix, and number of
- * columns equal to 'clusters'.
- *
- * @tparam MatType Type of matrix (arma::mat or arma::sp_mat).
- * @param data Dataset to cluster.
- * @param clusters Number of clusters to compute.
- * @param assignments Vector to store cluster assignments in.
- * @param centroids Matrix in which centroids are stored.
- * @param initialAssignmentGuess If true, then it is assumed that assignments
- * has a list of initial cluster assignments.
- * @param initialCentroidGuess If true, then it is assumed that centroids
- * contains the initial centroids of each cluster.
- */
- template<typename MatType>
- void Cluster(const MatType& data,
- const size_t clusters,
- arma::Col<size_t>& assignments,
- MatType& centroids,
- const bool initialAssignmentGuess = false,
- const bool initialCentroidGuess = false) const;
-
- /**
- * An implementation of k-means using the Pelleg-Moore algorithm; this is
- * known to not work -- do not use it! (Fixing it is TODO, of course; see
- * #251.)
- */
- template<typename MatType>
- void FastCluster(MatType& data,
- const size_t clusters,
- arma::Col<size_t>& assignments) const;
-
- //! Return the overclustering factor.
- double OverclusteringFactor() const { return overclusteringFactor; }
- //! Set the overclustering factor. Must be greater than 1.
- double& OverclusteringFactor() { return overclusteringFactor; }
-
- //! Get the maximum number of iterations.
- size_t MaxIterations() const { return maxIterations; }
- //! Set the maximum number of iterations.
- size_t& MaxIterations() { return maxIterations; }
-
- //! Get the distance metric.
- const MetricType& Metric() const { return metric; }
- //! Modify the distance metric.
- MetricType& Metric() { return metric; }
-
- //! Get the initial partitioning policy.
- const InitialPartitionPolicy& Partitioner() const { return partitioner; }
- //! Modify the initial partitioning policy.
- InitialPartitionPolicy& Partitioner() { return partitioner; }
-
- //! Get the empty cluster policy.
- const EmptyClusterPolicy& EmptyClusterAction() const
- { return emptyClusterAction; }
- //! Modify the empty cluster policy.
- EmptyClusterPolicy& EmptyClusterAction() { return emptyClusterAction; }
-
- private:
- //! Factor controlling how many clusters are actually found.
- double overclusteringFactor;
- //! Maximum number of iterations before giving up.
- size_t maxIterations;
- //! Instantiated distance metric.
- MetricType metric;
- //! Instantiated initial partitioning policy.
- InitialPartitionPolicy partitioner;
- //! Instantiated empty cluster policy.
- EmptyClusterPolicy emptyClusterAction;
-};
-
-}; // namespace kmeans
-}; // namespace mlpack
-
-// Include implementation.
-#include "kmeans_impl.hpp"
-
-#endif // __MLPACK_METHODS_MOG_KMEANS_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/kmeans.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,215 @@
+/**
+ * @file kmeans.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * K-Means clustering.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_KMEANS_HPP
+#define __MLPACK_METHODS_KMEANS_KMEANS_HPP
+
+#include <mlpack/core.hpp>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+#include "random_partition.hpp"
+#include "max_variance_new_cluster.hpp"
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
+namespace mlpack {
+namespace kmeans /** K-Means clustering. */ {
+
+/**
+ * This class implements K-Means clustering. This implementation supports
+ * overclustering, which means that more clusters than are requested will be
+ * found; then, those clusters will be merged together to produce the desired
+ * number of clusters.
+ *
+ * Two template parameters can (optionally) be supplied: the policy for how to
+ * find the initial partition of the data, and the actions to be taken when an
+ * empty cluster is encountered, as well as the distance metric to be used.
+ *
+ * A simple example of how to run K-Means clustering is shown below.
+ *
+ * @code
+ * extern arma::mat data; // Dataset we want to run K-Means on.
+ * arma::Col<size_t> assignments; // Cluster assignments.
+ *
+ * KMeans<> k; // Default options.
+ * k.Cluster(data, 3, assignments); // 3 clusters.
+ *
+ * // Cluster using the Manhattan distance, 100 iterations maximum, and an
+ * // overclustering factor of 4.0.
+ * KMeans<metric::ManhattanDistance> k(100, 4.0);
+ * k.Cluster(data, 6, assignments); // 6 clusters.
+ * @endcode
+ *
+ * @tparam MetricType The distance metric to use for this KMeans; see
+ * metric::LMetric for an example.
+ * @tparam InitialPartitionPolicy Initial partitioning policy; must implement a
+ * default constructor and 'void Cluster(const arma::mat&, const size_t,
+ * arma::Col<size_t>&)'.
+ * @tparam EmptyClusterPolicy Policy for what to do on an empty cluster; must
+ * implement a default constructor and 'void EmptyCluster(const arma::mat&,
+ * arma::Col<size_t&)'.
+ *
+ * @see RandomPartition, RefinedStart, AllowEmptyClusters, MaxVarianceNewCluster
+ */
+template<typename MetricType = metric::SquaredEuclideanDistance,
+ typename InitialPartitionPolicy = RandomPartition,
+ typename EmptyClusterPolicy = MaxVarianceNewCluster>
+class KMeans
+{
+ public:
+ /**
+ * Create a K-Means object and (optionally) set the parameters which K-Means
+ * will be run with. This implementation allows a few strategies to improve
+ * the performance of K-Means, including "overclustering" and disallowing
+ * empty clusters.
+ *
+ * The overclustering factor controls how many clusters are
+ * actually found; for instance, with an overclustering factor of 4, if
+ * K-Means is run to find 3 clusters, it will actually find 12, then merge the
+ * nearest clusters until only 3 are left.
+ *
+ * @param maxIterations Maximum number of iterations allowed before giving up
+ * (0 is valid, but the algorithm may never terminate).
+ * @param overclusteringFactor Factor controlling how many extra clusters are
+ * found and then merged to get the desired number of clusters.
+ * @param metric Optional MetricType object; for when the metric has state
+ * it needs to store.
+ * @param partitioner Optional InitialPartitionPolicy object; for when a
+ * specially initialized partitioning policy is required.
+ * @param emptyClusterAction Optional EmptyClusterPolicy object; for when a
+ * specially initialized empty cluster policy is required.
+ */
+ KMeans(const size_t maxIterations = 1000,
+ const double overclusteringFactor = 1.0,
+ const MetricType metric = MetricType(),
+ const InitialPartitionPolicy partitioner = InitialPartitionPolicy(),
+ const EmptyClusterPolicy emptyClusterAction = EmptyClusterPolicy());
+
+
+ /**
+ * Perform k-means clustering on the data, returning a list of cluster
+ * assignments. Optionally, the vector of assignments can be set to an
+ * initial guess of the cluster assignments; to do this, set initialGuess to
+ * true.
+ *
+ * @tparam MatType Type of matrix (arma::mat or arma::sp_mat).
+ * @param data Dataset to cluster.
+ * @param clusters Number of clusters to compute.
+ * @param assignments Vector to store cluster assignments in.
+ * @param initialGuess If true, then it is assumed that assignments has a list
+ * of initial cluster assignments.
+ */
+ template<typename MatType>
+ void Cluster(const MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments,
+ const bool initialGuess = false) const;
+
+ /**
+ * Perform k-means clustering on the data, returning a list of cluster
+ * assignments and also the centroids of each cluster. Optionally, the vector
+ * of assignments can be set to an initial guess of the cluster assignments;
+ * to do this, set initialAssignmentGuess to true. Another way to set initial
+ * cluster guesses is to fill the centroids matrix with the centroid guesses,
+ * and then set initialCentroidGuess to true. initialAssignmentGuess
+ * supersedes initialCentroidGuess, so if both are set to true, the
+ * assignments vector is used.
+ *
+ * Note that if the overclustering factor is greater than 1, the centroids
+ * matrix will be resized in the method. Regardless of the overclustering
+ * factor, the centroid guess matrix (if initialCentroidGuess is set to true)
+ * should have the same number of rows as the data matrix, and number of
+ * columns equal to 'clusters'.
+ *
+ * @tparam MatType Type of matrix (arma::mat or arma::sp_mat).
+ * @param data Dataset to cluster.
+ * @param clusters Number of clusters to compute.
+ * @param assignments Vector to store cluster assignments in.
+ * @param centroids Matrix in which centroids are stored.
+ * @param initialAssignmentGuess If true, then it is assumed that assignments
+ * has a list of initial cluster assignments.
+ * @param initialCentroidGuess If true, then it is assumed that centroids
+ * contains the initial centroids of each cluster.
+ */
+ template<typename MatType>
+ void Cluster(const MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments,
+ MatType& centroids,
+ const bool initialAssignmentGuess = false,
+ const bool initialCentroidGuess = false) const;
+
+ /**
+ * An implementation of k-means using the Pelleg-Moore algorithm; this is
+ * known to not work -- do not use it! (Fixing it is TODO, of course; see
+ * #251.)
+ */
+ template<typename MatType>
+ void FastCluster(MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments) const;
+
+ //! Return the overclustering factor.
+ double OverclusteringFactor() const { return overclusteringFactor; }
+ //! Set the overclustering factor. Must be greater than 1.
+ double& OverclusteringFactor() { return overclusteringFactor; }
+
+ //! Get the maximum number of iterations.
+ size_t MaxIterations() const { return maxIterations; }
+ //! Set the maximum number of iterations.
+ size_t& MaxIterations() { return maxIterations; }
+
+ //! Get the distance metric.
+ const MetricType& Metric() const { return metric; }
+ //! Modify the distance metric.
+ MetricType& Metric() { return metric; }
+
+ //! Get the initial partitioning policy.
+ const InitialPartitionPolicy& Partitioner() const { return partitioner; }
+ //! Modify the initial partitioning policy.
+ InitialPartitionPolicy& Partitioner() { return partitioner; }
+
+ //! Get the empty cluster policy.
+ const EmptyClusterPolicy& EmptyClusterAction() const
+ { return emptyClusterAction; }
+ //! Modify the empty cluster policy.
+ EmptyClusterPolicy& EmptyClusterAction() { return emptyClusterAction; }
+
+ private:
+ //! Factor controlling how many clusters are actually found.
+ double overclusteringFactor;
+ //! Maximum number of iterations before giving up.
+ size_t maxIterations;
+ //! Instantiated distance metric.
+ MetricType metric;
+ //! Instantiated initial partitioning policy.
+ InitialPartitionPolicy partitioner;
+ //! Instantiated empty cluster policy.
+ EmptyClusterPolicy emptyClusterAction;
+};
+
+}; // namespace kmeans
+}; // namespace mlpack
+
+// Include implementation.
+#include "kmeans_impl.hpp"
+
+#endif // __MLPACK_METHODS_MOG_KMEANS_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/kmeans_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,822 +0,0 @@
-/**
- * @file kmeans_impl.hpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @author Ryan Curtin
- *
- * Implementation for the K-means method for getting an initial point.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "kmeans.hpp"
-
-#include <mlpack/core/tree/mrkd_statistic.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include <stack>
-#include <limits>
-
-namespace mlpack {
-namespace kmeans {
-
-/**
- * Construct the K-Means object.
- */
-template<typename MetricType,
- typename InitialPartitionPolicy,
- typename EmptyClusterPolicy>
-KMeans<
- MetricType,
- InitialPartitionPolicy,
- EmptyClusterPolicy>::
-KMeans(const size_t maxIterations,
- const double overclusteringFactor,
- const MetricType metric,
- const InitialPartitionPolicy partitioner,
- const EmptyClusterPolicy emptyClusterAction) :
- maxIterations(maxIterations),
- metric(metric),
- partitioner(partitioner),
- emptyClusterAction(emptyClusterAction)
-{
- // Validate overclustering factor.
- if (overclusteringFactor < 1.0)
- {
- Log::Warn << "KMeans::KMeans(): overclustering factor must be >= 1.0 ("
- << overclusteringFactor << " given). Setting factor to 1.0.\n";
- this->overclusteringFactor = 1.0;
- }
- else
- {
- this->overclusteringFactor = overclusteringFactor;
- }
-}
-
-template<typename MetricType,
- typename InitialPartitionPolicy,
- typename EmptyClusterPolicy>
-template<typename MatType>
-void KMeans<
- MetricType,
- InitialPartitionPolicy,
- EmptyClusterPolicy>::
-FastCluster(MatType& data,
- const size_t clusters,
- arma::Col<size_t>& assignments) const
-{
- size_t actualClusters = size_t(overclusteringFactor * clusters);
- if (actualClusters > data.n_cols)
- {
- Log::Warn << "KMeans::Cluster(): overclustering factor is too large. No "
- << "overclustering will be done." << std::endl;
- actualClusters = clusters;
- }
-
- size_t dimensionality = data.n_rows;
-
- // Centroids of each cluster. Each column corresponds to a centroid.
- MatType centroids(dimensionality, actualClusters);
- centroids.zeros();
-
- // Counts of points in each cluster.
- arma::Col<size_t> counts(actualClusters);
- counts.zeros();
-
- // Build the mrkd-tree on this dataset.
- tree::BinarySpaceTree<typename bound::HRectBound<2>, tree::MRKDStatistic>
- tree(data, 1);
- Log::Debug << "Tree Built." << std::endl;
- // A pointer for traversing the mrkd-tree.
- tree::BinarySpaceTree<typename bound::HRectBound<2>, tree::MRKDStatistic>*
- node;
-
- // Now, the initial assignments. First determine if they are necessary.
- if (assignments.n_elem != data.n_cols)
- {
- // Use the partitioner to come up with the partition assignments.
- partitioner.Cluster(data, actualClusters, assignments);
- }
-
- // Set counts correctly.
- for (size_t i = 0; i < assignments.n_elem; i++)
- counts[assignments[i]]++;
-
- // Sum the points for each centroid
- for (size_t i = 0; i < data.n_cols; i++)
- centroids.col(assignments[i]) += data.col(i);
-
- // Then divide the sums by the count to get the center of mass for this
- // centroid's assigned points
- for (size_t i = 0; i < actualClusters; i++)
- centroids.col(i) /= counts[i];
-
- // Instead of retraversing the tree after an iteration, we will update
- // centroid positions in this matrix, which also prevents clobbering our
- // centroids from the previous iteration.
- MatType newCentroids(dimensionality, centroids.n_cols);
-
- // Create a stack for traversing the mrkd-tree.
- std::stack<typename tree::BinarySpaceTree<typename bound::HRectBound<2>,
- tree::MRKDStatistic>* > stack;
-
- // A variable to keep track of how many kmeans iterations we have made.
- size_t iteration = 0;
-
- // A variable to keep track of how many nodes assignments have changed in
- // each kmeans iteration.
- size_t changedAssignments = 0;
-
- // A variable to keep track of the number of times something is skipped due
- // to the blacklist.
- size_t skip = 0;
-
- // A variable to keep track of the number of distances calculated.
- size_t comps = 0;
-
- // A variable to keep track of how often we stop at a parent node.
- size_t dominations = 0;
- do
- {
- // Keep track of what iteration we are on.
- ++iteration;
- changedAssignments = 0;
-
- // Reset the newCentroids so that we can store the newly calculated ones
- // here.
- newCentroids.zeros();
-
- // Reset the counts.
- counts.zeros();
-
- // Add the root node of the tree to the stack.
- stack.push(&tree);
- // Set the top level whitelist.
- tree.Stat().Whitelist().resize(centroids.n_cols, true);
-
- // Traverse the tree.
- while (!stack.empty())
- {
- // Get the next node in the tree.
- node = stack.top();
- // Remove the node from the stack.
- stack.pop();
-
- // Get a reference to the mrkd statistic for this hyperrectangle.
- tree::MRKDStatistic& mrkd = node->Stat();
-
- // We use this to store the index of the centroid with the minimum
- // distance from this hyperrectangle or point.
- size_t minIndex = 0;
-
- // If this node is a leaf, then we calculate the distance from
- // the centroids to every point the node contains.
- if (node->IsLeaf())
- {
- for (size_t i = mrkd.Begin(); i < mrkd.Count() + mrkd.Begin(); ++i)
- {
- // Initialize minDistance to be nonzero.
- double minDistance = metric.Evaluate(data.col(i), centroids.col(0));
-
- // Find the minimal distance centroid for this point.
- for (size_t j = 1; j < centroids.n_cols; ++j)
- {
- // If this centroid is not in the whitelist, skip it.
- if (!mrkd.Whitelist()[j])
- {
- ++skip;
- continue;
- }
-
- ++comps;
- double distance = metric.Evaluate(data.col(i), centroids.col(j));
- if (minDistance > distance)
- {
- minIndex = j;
- minDistance = distance;
- }
- }
-
- // Add this point to the undivided center of mass summation for its
- // assigned centroid.
- newCentroids.col(minIndex) += data.col(i);
-
- // Increment the count for the minimum distance centroid.
- ++counts(minIndex);
-
- // If we actually changed assignments, increment changedAssignments
- // and modify the assignment vector for this point.
- if (assignments(i) != minIndex)
- {
- ++changedAssignments;
- assignments(i) = minIndex;
- }
- }
- }
- // If this node is not a leaf, then we continue trying to find dominant
- // centroids.
- else
- {
- bound::HRectBound<2>& bound = node->Bound();
-
- // A flag to keep track of if we find a single centroid that is closer
- // to all points in this hyperrectangle than any other centroid.
- bool noDomination = false;
-
- // Calculate the center of mass of this hyperrectangle.
- arma::vec center = mrkd.CenterOfMass() / mrkd.Count();
-
- // Set the minDistance to the maximum value of a double so any value
- // must be smaller than this.
- double minDistance = std::numeric_limits<double>::max();
-
- // The candidate distance we calculate for each centroid.
- double distance = 0.0;
-
- // How many points are inside this hyperrectangle, we stop if we
- // see more than 1.
- size_t contains = 0;
-
- // Find the "owner" of this hyperrectangle, if one exists.
- for (size_t i = 0; i < centroids.n_cols; ++i)
- {
- // If this centroid is not in the whitelist, skip it.
- if (!mrkd.Whitelist()[i])
- {
- ++skip;
- continue;
- }
-
- // Increment the number of distance calculations for what we are
- // about to do.
- comps += 2;
-
- // Reinitialize the distance so += works right.
- distance = 0.0;
-
- // We keep track of how many dimensions have nonzero distance,
- // if this is 0 then the distance is 0.
- size_t nonZero = 0;
-
- /*
- Compute the distance to the hyperrectangle for this centroid.
- We do this by finding the furthest point from the centroid inside
- the hyperrectangle. This is a corner of the hyperrectangle.
-
- In order to do this faster, we calculate both the distance and the
- furthest point simultaneously.
-
- This following code is equivalent to, but faster than:
-
- arma::vec p;
- p.zeros(dimensionality);
-
- for (size_t j = 0; j < dimensionality; ++j)
- {
- if (centroids(j,i) < bound[j].Lo())
- p(j) = bound[j].Lo();
- else
- p(j) = bound[j].Hi();
- }
-
- distance = metric.Evaluate(p.col(0), centroids.col(i));
- */
- for (size_t j = 0; j < dimensionality; ++j)
- {
- double ij = centroids(j,i);
- double lo = bound[j].Lo();
-
- if (ij < lo)
- {
- // (ij - lo)^2
- ij -= lo;
- ij *= ij;
-
- distance += ij;
- ++nonZero;
- }
- else
- {
- double hi = bound[j].Hi();
- if (ij > hi)
- {
- // (ij - hi)^2
- ij -= hi;
- ij *= ij;
-
- distance += ij;
- ++nonZero;
- }
- }
- }
-
- // The centroid is inside the hyperrectangle.
- if (nonZero == 0)
- {
- ++contains;
- minDistance = 0.0;
- minIndex = i;
-
- // If more than one centroid is within this hyperrectangle, then
- // there can be no dominating centroid, so we should continue
- // to the children nodes.
- if (contains > 1)
- {
- noDomination = true;
- break;
- }
- }
-
- if (fabs(distance - minDistance) <= 1e-10)
- {
- noDomination = true;
- break;
- }
- else if (distance < minDistance)
- {
- minIndex = i;
- minDistance = distance;
- }
- }
-
- distance = minDistance;
- // Determine if the owner dominates this centroid only if there was
- // exactly one owner.
- if (!noDomination)
- {
- for (size_t i = 0; i < centroids.n_cols; ++i)
- {
- if (i == minIndex)
- continue;
- // If this centroid is blacklisted for this hyperrectangle, then
- // we skip it.
- if (!mrkd.Whitelist()[i])
- {
- ++skip;
- continue;
- }
- /*
- Compute the dominating centroid for this hyperrectangle, if one
- exists. We do this by calculating the point which is furthest
- from the min'th centroid in the direction of c_k - c_min. We do
- this as outlined in the Pelleg and Moore paper.
-
- This following code is equivalent to, but faster than:
-
- arma::vec p;
- p.zeros(dimensionality);
-
- for (size_t k = 0; k < dimensionality; ++k)
- {
- p(k) = (centroids(k,i) > centroids(k,minIndex)) ?
- bound[k].Hi() : bound[k].Lo();
- }
-
- double distancei = metric.Evaluate(p.col(0), centroids.col(i));
- double distanceMin = metric.Evaluate(p.col(0),
- centroids.col(minIndex));
- */
-
- comps += 1;
- double distancei = 0.0;
- double distanceMin = 0.0;
- for (size_t k = 0; k < dimensionality; ++k)
- {
- double ci = centroids(k, i);
- double cm = centroids(k, minIndex);
- if (ci > cm)
- {
- double hi = bound[k].Hi();
-
- ci -= hi;
- cm -= hi;
-
- ci *= ci;
- cm *= cm;
-
- distancei += ci;
- distanceMin += cm;
- }
- else
- {
- double lo = bound[k].Lo();
-
- ci -= lo;
- cm -= lo;
-
- ci *= ci;
- cm *= cm;
-
- distancei += ci;
- distanceMin += cm;
- }
- }
-
- if (distanceMin >= distancei)
- {
- noDomination = true;
- break;
- }
- else
- {
- mrkd.Whitelist()[i] = false;
- }
- }
- }
-
- // If we did find a centroid that was closer to every point in the
- // hyperrectangle than every other centroid, then update that centroid.
- if (!noDomination)
- {
- // Adjust the new centroid sum for the min distance point to this
- // hyperrectangle by the center of mass of this hyperrectangle.
- newCentroids.col(minIndex) += mrkd.CenterOfMass();
-
- // Increment the counts for this centroid.
- counts(minIndex) += mrkd.Count();
-
- // Update all assignments for this node.
- const size_t begin = node->Begin();
- const size_t end = node->End();
-
- // TODO: Do this outside of the kmeans iterations.
- for (size_t j = begin; j < end; ++j)
- {
- if (assignments(j) != minIndex)
- {
- ++changedAssignments;
- assignments(j) = minIndex;
- }
- }
- mrkd.DominatingCentroid() = minIndex;
-
- // Keep track of the number of times we found a dominating centroid.
- ++dominations;
- }
-
- // If we did not find a dominating centroid then we fall through to the
- // default case, where we add the children of this node to the stack.
- else
- {
- // Add this hyperrectangle's children to our stack.
- stack.push(node->Left());
- stack.push(node->Right());
-
- // (Re)Initialize the whiteList for the children.
- node->Left()->Stat().Whitelist() = mrkd.Whitelist();
- node->Right()->Stat().Whitelist() = mrkd.Whitelist();
- }
- }
-
- }
-
- // Divide by the number of points assigned to the centroids so that we
- // have the actual center of mass and update centroids' positions.
- for (size_t i = 0; i < centroids.n_cols; ++i)
- if (counts(i))
- centroids.col(i) = newCentroids.col(i) / counts(i);
-
- // Stop when we reach the maximum number of iterations or when no
- // assignments changed.
- } while (changedAssignments > 0 && iteration != maxIterations);
-
- Log::Info << "Iterations: " << iteration << std::endl
- << "Skips: " << skip << std::endl
- << "Comparisons: " << comps << std::endl
- << "Dominations: " << dominations << std::endl;
-}
-
-/**
- * Perform k-means clustering on the data, returning a list of cluster
- * assignments. This just forwards to the other function, which returns the
- * centroids too. If this is properly inlined, there shouldn't be any
- * performance penalty whatsoever.
- */
-template<typename MetricType,
- typename InitialPartitionPolicy,
- typename EmptyClusterPolicy>
-template<typename MatType>
-inline void KMeans<
- MetricType,
- InitialPartitionPolicy,
- EmptyClusterPolicy>::
-Cluster(const MatType& data,
- const size_t clusters,
- arma::Col<size_t>& assignments,
- const bool initialGuess) const
-{
- MatType centroids(data.n_rows, clusters);
- Cluster(data, clusters, assignments, centroids, initialGuess);
-}
-
-/**
- * Perform k-means clustering on the data, returning a list of cluster
- * assignments and the centroids of each cluster.
- */
-template<typename MetricType,
- typename InitialPartitionPolicy,
- typename EmptyClusterPolicy>
-template<typename MatType>
-void KMeans<
- MetricType,
- InitialPartitionPolicy,
- EmptyClusterPolicy>::
-Cluster(const MatType& data,
- const size_t clusters,
- arma::Col<size_t>& assignments,
- MatType& centroids,
- const bool initialAssignmentGuess,
- const bool initialCentroidGuess) const
-{
- // Make sure we have more points than clusters.
- if (clusters > data.n_cols)
- Log::Warn << "KMeans::Cluster(): more clusters requested than points given."
- << std::endl;
-
- // Make sure our overclustering factor is valid.
- size_t actualClusters = size_t(overclusteringFactor * clusters);
- if (actualClusters > data.n_cols && overclusteringFactor != 1.0)
- {
- Log::Warn << "KMeans::Cluster(): overclustering factor is too large. No "
- << "overclustering will be done." << std::endl;
- actualClusters = clusters;
- }
-
- // Now, the initial assignments. First determine if they are necessary.
- if (initialAssignmentGuess)
- {
- if (assignments.n_elem != data.n_cols)
- Log::Fatal << "KMeans::Cluster(): initial cluster assignments (length "
- << assignments.n_elem << ") not the same size as the dataset (size "
- << data.n_cols << ")!" << std::endl;
- }
- else if (initialCentroidGuess)
- {
- if (centroids.n_cols != clusters)
- Log::Fatal << "KMeans::Cluster(): wrong number of initial cluster "
- << "centroids (" << centroids.n_cols << ", should be " << clusters
- << ")!" << std::endl;
-
- if (centroids.n_rows != data.n_rows)
- Log::Fatal << "KMeans::Cluster(): initial cluster centroids have wrong "
- << " dimensionality (" << centroids.n_rows << ", should be "
- << data.n_rows << ")!" << std::endl;
-
- // If there were no problems, construct the initial assignments from the
- // given centroids.
- assignments.set_size(data.n_cols);
- for (size_t i = 0; i < data.n_cols; ++i)
- {
- // Find the closest centroid to this point.
- double minDistance = std::numeric_limits<double>::infinity();
- size_t closestCluster = clusters; // Invalid value.
-
- for (size_t j = 0; j < clusters; j++)
- {
- double distance = metric.Evaluate(data.col(i), centroids.col(j));
-
- if (distance < minDistance)
- {
- minDistance = distance;
- closestCluster = j;
- }
- }
-
- // Assign the point to the closest cluster that we found.
- assignments[i] = closestCluster;
- }
- }
- else
- {
- // Use the partitioner to come up with the partition assignments.
- partitioner.Cluster(data, actualClusters, assignments);
- }
-
- // Counts of points in each cluster.
- arma::Col<size_t> counts(actualClusters);
- counts.zeros();
-
- // Resize to correct size.
- centroids.set_size(data.n_rows, actualClusters);
-
- // Set counts correctly.
- for (size_t i = 0; i < assignments.n_elem; i++)
- counts[assignments[i]]++;
-
- size_t changedAssignments = 0;
- size_t iteration = 0;
- do
- {
- // Update step.
- // Calculate centroids based on given assignments.
- centroids.zeros();
-
- for (size_t i = 0; i < data.n_cols; i++)
- centroids.col(assignments[i]) += data.col(i);
-
- for (size_t i = 0; i < actualClusters; i++)
- centroids.col(i) /= counts[i];
-
- // Assignment step.
- // Find the closest centroid to each point. We will keep track of how many
- // assignments change. When no assignments change, we are done.
- changedAssignments = 0;
- for (size_t i = 0; i < data.n_cols; i++)
- {
- // Find the closest centroid to this point.
- double minDistance = std::numeric_limits<double>::infinity();
- size_t closestCluster = actualClusters; // Invalid value.
-
- for (size_t j = 0; j < actualClusters; j++)
- {
- double distance = metric.Evaluate(data.col(i), centroids.col(j));
-
- if (distance < minDistance)
- {
- minDistance = distance;
- closestCluster = j;
- }
- }
-
- // Reassign this point to the closest cluster.
- if (assignments[i] != closestCluster)
- {
- // Update counts.
- counts[assignments[i]]--;
- counts[closestCluster]++;
- // Update assignment.
- assignments[i] = closestCluster;
- changedAssignments++;
- }
- }
-
- // If we are not allowing empty clusters, then check that all of our
- // clusters have points.
- for (size_t i = 0; i < actualClusters; i++)
- if (counts[i] == 0)
- changedAssignments += emptyClusterAction.EmptyCluster(data, i,
- centroids, counts, assignments);
-
- iteration++;
-
- } while (changedAssignments > 0 && iteration != maxIterations);
-
- if (iteration != maxIterations)
- {
- Log::Debug << "KMeans::Cluster(): converged after " << iteration
- << " iterations." << std::endl;
- }
- else
- {
- Log::Debug << "KMeans::Cluster(): terminated after limit of " << iteration
- << " iterations." << std::endl;
-
- // Recalculate final clusters.
- centroids.zeros();
-
- for (size_t i = 0; i < data.n_cols; i++)
- centroids.col(assignments[i]) += data.col(i);
-
- for (size_t i = 0; i < actualClusters; i++)
- centroids.col(i) /= counts[i];
- }
-
- // If we have overclustered, we need to merge the nearest clusters.
- if (actualClusters != clusters)
- {
- // Generate a list of all the clusters' distances from each other. This
- // list will become mangled and unused as the number of clusters decreases.
- size_t numDistances = ((actualClusters - 1) * actualClusters) / 2;
- size_t clustersLeft = actualClusters;
- arma::vec distances(numDistances);
- arma::Col<size_t> firstCluster(numDistances);
- arma::Col<size_t> secondCluster(numDistances);
-
- // Keep the mappings of clusters that we are changing.
- arma::Col<size_t> mappings = arma::linspace<arma::Col<size_t> >(0,
- actualClusters - 1, actualClusters);
-
- size_t i = 0;
- for (size_t first = 0; first < actualClusters; first++)
- {
- for (size_t second = first + 1; second < actualClusters; second++)
- {
- distances(i) = metric.Evaluate(centroids.col(first),
- centroids.col(second));
- firstCluster(i) = first;
- secondCluster(i) = second;
- i++;
- }
- }
-
- while (clustersLeft != clusters)
- {
- arma::uword minIndex;
- distances.min(minIndex);
-
- // Now we merge the clusters which that distance belongs to.
- size_t first = firstCluster(minIndex);
- size_t second = secondCluster(minIndex);
- for (size_t j = 0; j < assignments.n_elem; j++)
- if (assignments(j) == second)
- assignments(j) = first;
-
- // Now merge the centroids.
- centroids.col(first) *= counts[first];
- centroids.col(first) += (counts[second] * centroids.col(second));
- centroids.col(first) /= (counts[first] + counts[second]);
-
- // Update the counts.
- counts[first] += counts[second];
- counts[second] = 0;
-
- // Now update all the relevant distances.
- // First the distances where either cluster is the second cluster.
- for (size_t cluster = 0; cluster < second; cluster++)
- {
- // The offset is sum^n i - sum^(n - m) i, where n is actualClusters and
- // m is the cluster we are trying to offset to.
- size_t offset = (size_t) (((actualClusters - 1) * cluster)
- + (cluster - pow(cluster, 2.0)) / 2) - 1;
-
- // See if we need to update the distance from this cluster to the first
- // cluster.
- if (cluster < first)
- {
- // Make sure it isn't already DBL_MAX.
- if (distances(offset + (first - cluster)) != DBL_MAX)
- distances(offset + (first - cluster)) = metric.Evaluate(
- centroids.col(first), centroids.col(cluster));
- }
-
- distances(offset + (second - cluster)) = DBL_MAX;
- }
-
- // Now the distances where the first cluster is the first cluster.
- size_t offset = (size_t) (((actualClusters - 1) * first)
- + (first - pow(first, 2.0)) / 2) - 1;
- for (size_t cluster = first + 1; cluster < actualClusters; cluster++)
- {
- // Make sure it isn't already DBL_MAX.
- if (distances(offset + (cluster - first)) != DBL_MAX)
- {
- distances(offset + (cluster - first)) = metric.Evaluate(
- centroids.col(first), centroids.col(cluster));
- }
- }
-
- // Max the distance between the first and second clusters.
- distances(offset + (second - first)) = DBL_MAX;
-
- // Now max the distances for the second cluster (which no longer has
- // anything in it).
- offset = (size_t) (((actualClusters - 1) * second)
- + (second - pow(second, 2.0)) / 2) - 1;
- for (size_t cluster = second + 1; cluster < actualClusters; cluster++)
- distances(offset + (cluster - second)) = DBL_MAX;
-
- clustersLeft--;
-
- // Update the cluster mappings.
- mappings(second) = first;
- // Also update any mappings that were pointed at the previous cluster.
- for (size_t cluster = 0; cluster < actualClusters; cluster++)
- if (mappings(cluster) == second)
- mappings(cluster) = first;
- }
-
- // Now remap the mappings down to the smallest possible numbers.
- // Could this process be sped up?
- arma::Col<size_t> remappings(actualClusters);
- remappings.fill(actualClusters);
- size_t remap = 0; // Counter variable.
- for (size_t cluster = 0; cluster < actualClusters; cluster++)
- {
- // If the mapping of the current cluster has not been assigned a value
- // yet, we will assign it a cluster number.
- if (remappings(mappings(cluster)) == actualClusters)
- {
- remappings(mappings(cluster)) = remap;
- remap++;
- }
- }
-
- // Fix the assignments using the mappings we created.
- for (size_t j = 0; j < assignments.n_elem; j++)
- assignments(j) = remappings(mappings(assignments(j)));
- }
-}
-
-}; // namespace kmeans
-}; // namespace mlpack
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/kmeans_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,822 @@
+/**
+ * @file kmeans_impl.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @author Ryan Curtin
+ *
+ * Implementation for the K-means method for getting an initial point.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "kmeans.hpp"
+
+#include <mlpack/core/tree/mrkd_statistic.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include <stack>
+#include <limits>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * Construct the K-Means object.
+ */
+template<typename MetricType,
+ typename InitialPartitionPolicy,
+ typename EmptyClusterPolicy>
+KMeans<
+ MetricType,
+ InitialPartitionPolicy,
+ EmptyClusterPolicy>::
+KMeans(const size_t maxIterations,
+ const double overclusteringFactor,
+ const MetricType metric,
+ const InitialPartitionPolicy partitioner,
+ const EmptyClusterPolicy emptyClusterAction) :
+ maxIterations(maxIterations),
+ metric(metric),
+ partitioner(partitioner),
+ emptyClusterAction(emptyClusterAction)
+{
+ // Validate overclustering factor.
+ if (overclusteringFactor < 1.0)
+ {
+ Log::Warn << "KMeans::KMeans(): overclustering factor must be >= 1.0 ("
+ << overclusteringFactor << " given). Setting factor to 1.0.\n";
+ this->overclusteringFactor = 1.0;
+ }
+ else
+ {
+ this->overclusteringFactor = overclusteringFactor;
+ }
+}
+
+template<typename MetricType,
+ typename InitialPartitionPolicy,
+ typename EmptyClusterPolicy>
+template<typename MatType>
+void KMeans<
+ MetricType,
+ InitialPartitionPolicy,
+ EmptyClusterPolicy>::
+FastCluster(MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments) const
+{
+ size_t actualClusters = size_t(overclusteringFactor * clusters);
+ if (actualClusters > data.n_cols)
+ {
+ Log::Warn << "KMeans::Cluster(): overclustering factor is too large. No "
+ << "overclustering will be done." << std::endl;
+ actualClusters = clusters;
+ }
+
+ size_t dimensionality = data.n_rows;
+
+ // Centroids of each cluster. Each column corresponds to a centroid.
+ MatType centroids(dimensionality, actualClusters);
+ centroids.zeros();
+
+ // Counts of points in each cluster.
+ arma::Col<size_t> counts(actualClusters);
+ counts.zeros();
+
+ // Build the mrkd-tree on this dataset.
+ tree::BinarySpaceTree<typename bound::HRectBound<2>, tree::MRKDStatistic>
+ tree(data, 1);
+ Log::Debug << "Tree Built." << std::endl;
+ // A pointer for traversing the mrkd-tree.
+ tree::BinarySpaceTree<typename bound::HRectBound<2>, tree::MRKDStatistic>*
+ node;
+
+ // Now, the initial assignments. First determine if they are necessary.
+ if (assignments.n_elem != data.n_cols)
+ {
+ // Use the partitioner to come up with the partition assignments.
+ partitioner.Cluster(data, actualClusters, assignments);
+ }
+
+ // Set counts correctly.
+ for (size_t i = 0; i < assignments.n_elem; i++)
+ counts[assignments[i]]++;
+
+ // Sum the points for each centroid
+ for (size_t i = 0; i < data.n_cols; i++)
+ centroids.col(assignments[i]) += data.col(i);
+
+ // Then divide the sums by the count to get the center of mass for this
+ // centroid's assigned points
+ for (size_t i = 0; i < actualClusters; i++)
+ centroids.col(i) /= counts[i];
+
+ // Instead of retraversing the tree after an iteration, we will update
+ // centroid positions in this matrix, which also prevents clobbering our
+ // centroids from the previous iteration.
+ MatType newCentroids(dimensionality, centroids.n_cols);
+
+ // Create a stack for traversing the mrkd-tree.
+ std::stack<typename tree::BinarySpaceTree<typename bound::HRectBound<2>,
+ tree::MRKDStatistic>* > stack;
+
+ // A variable to keep track of how many kmeans iterations we have made.
+ size_t iteration = 0;
+
+ // A variable to keep track of how many nodes assignments have changed in
+ // each kmeans iteration.
+ size_t changedAssignments = 0;
+
+ // A variable to keep track of the number of times something is skipped due
+ // to the blacklist.
+ size_t skip = 0;
+
+ // A variable to keep track of the number of distances calculated.
+ size_t comps = 0;
+
+ // A variable to keep track of how often we stop at a parent node.
+ size_t dominations = 0;
+ do
+ {
+ // Keep track of what iteration we are on.
+ ++iteration;
+ changedAssignments = 0;
+
+ // Reset the newCentroids so that we can store the newly calculated ones
+ // here.
+ newCentroids.zeros();
+
+ // Reset the counts.
+ counts.zeros();
+
+ // Add the root node of the tree to the stack.
+ stack.push(&tree);
+ // Set the top level whitelist.
+ tree.Stat().Whitelist().resize(centroids.n_cols, true);
+
+ // Traverse the tree.
+ while (!stack.empty())
+ {
+ // Get the next node in the tree.
+ node = stack.top();
+ // Remove the node from the stack.
+ stack.pop();
+
+ // Get a reference to the mrkd statistic for this hyperrectangle.
+ tree::MRKDStatistic& mrkd = node->Stat();
+
+ // We use this to store the index of the centroid with the minimum
+ // distance from this hyperrectangle or point.
+ size_t minIndex = 0;
+
+ // If this node is a leaf, then we calculate the distance from
+ // the centroids to every point the node contains.
+ if (node->IsLeaf())
+ {
+ for (size_t i = mrkd.Begin(); i < mrkd.Count() + mrkd.Begin(); ++i)
+ {
+ // Initialize minDistance to be nonzero.
+ double minDistance = metric.Evaluate(data.col(i), centroids.col(0));
+
+ // Find the minimal distance centroid for this point.
+ for (size_t j = 1; j < centroids.n_cols; ++j)
+ {
+ // If this centroid is not in the whitelist, skip it.
+ if (!mrkd.Whitelist()[j])
+ {
+ ++skip;
+ continue;
+ }
+
+ ++comps;
+ double distance = metric.Evaluate(data.col(i), centroids.col(j));
+ if (minDistance > distance)
+ {
+ minIndex = j;
+ minDistance = distance;
+ }
+ }
+
+ // Add this point to the undivided center of mass summation for its
+ // assigned centroid.
+ newCentroids.col(minIndex) += data.col(i);
+
+ // Increment the count for the minimum distance centroid.
+ ++counts(minIndex);
+
+ // If we actually changed assignments, increment changedAssignments
+ // and modify the assignment vector for this point.
+ if (assignments(i) != minIndex)
+ {
+ ++changedAssignments;
+ assignments(i) = minIndex;
+ }
+ }
+ }
+ // If this node is not a leaf, then we continue trying to find dominant
+ // centroids.
+ else
+ {
+ bound::HRectBound<2>& bound = node->Bound();
+
+ // A flag to keep track of if we find a single centroid that is closer
+ // to all points in this hyperrectangle than any other centroid.
+ bool noDomination = false;
+
+ // Calculate the center of mass of this hyperrectangle.
+ arma::vec center = mrkd.CenterOfMass() / mrkd.Count();
+
+ // Set the minDistance to the maximum value of a double so any value
+ // must be smaller than this.
+ double minDistance = std::numeric_limits<double>::max();
+
+ // The candidate distance we calculate for each centroid.
+ double distance = 0.0;
+
+ // How many points are inside this hyperrectangle, we stop if we
+ // see more than 1.
+ size_t contains = 0;
+
+ // Find the "owner" of this hyperrectangle, if one exists.
+ for (size_t i = 0; i < centroids.n_cols; ++i)
+ {
+ // If this centroid is not in the whitelist, skip it.
+ if (!mrkd.Whitelist()[i])
+ {
+ ++skip;
+ continue;
+ }
+
+ // Increment the number of distance calculations for what we are
+ // about to do.
+ comps += 2;
+
+ // Reinitialize the distance so += works right.
+ distance = 0.0;
+
+ // We keep track of how many dimensions have nonzero distance,
+ // if this is 0 then the distance is 0.
+ size_t nonZero = 0;
+
+ /*
+ Compute the distance to the hyperrectangle for this centroid.
+ We do this by finding the furthest point from the centroid inside
+ the hyperrectangle. This is a corner of the hyperrectangle.
+
+ In order to do this faster, we calculate both the distance and the
+ furthest point simultaneously.
+
+ This following code is equivalent to, but faster than:
+
+ arma::vec p;
+ p.zeros(dimensionality);
+
+ for (size_t j = 0; j < dimensionality; ++j)
+ {
+ if (centroids(j,i) < bound[j].Lo())
+ p(j) = bound[j].Lo();
+ else
+ p(j) = bound[j].Hi();
+ }
+
+ distance = metric.Evaluate(p.col(0), centroids.col(i));
+ */
+ for (size_t j = 0; j < dimensionality; ++j)
+ {
+ double ij = centroids(j,i);
+ double lo = bound[j].Lo();
+
+ if (ij < lo)
+ {
+ // (ij - lo)^2
+ ij -= lo;
+ ij *= ij;
+
+ distance += ij;
+ ++nonZero;
+ }
+ else
+ {
+ double hi = bound[j].Hi();
+ if (ij > hi)
+ {
+ // (ij - hi)^2
+ ij -= hi;
+ ij *= ij;
+
+ distance += ij;
+ ++nonZero;
+ }
+ }
+ }
+
+ // The centroid is inside the hyperrectangle.
+ if (nonZero == 0)
+ {
+ ++contains;
+ minDistance = 0.0;
+ minIndex = i;
+
+ // If more than one centroid lies within this hyperrectangle, then
+ // there can be no dominating centroid, so we should continue
+ // to the children nodes.
+ if (contains > 1)
+ {
+ noDomination = true;
+ break;
+ }
+ }
+
+ if (fabs(distance - minDistance) <= 1e-10)
+ {
+ noDomination = true;
+ break;
+ }
+ else if (distance < minDistance)
+ {
+ minIndex = i;
+ minDistance = distance;
+ }
+ }
+
+ distance = minDistance;
+ // Determine if the owner dominates this hyperrectangle only if there
+ // was exactly one owner.
+ if (!noDomination)
+ {
+ for (size_t i = 0; i < centroids.n_cols; ++i)
+ {
+ if (i == minIndex)
+ continue;
+ // If this centroid is not in the whitelist for this hyperrectangle,
+ // then we skip it.
+ if (!mrkd.Whitelist()[i])
+ {
+ ++skip;
+ continue;
+ }
+ /*
+ Compute the dominating centroid for this hyperrectangle, if one
+ exists. We do this by calculating the point which is furthest
+ from the min'th centroid in the direction of c_k - c_min. We do
+ this as outlined in the Pelleg and Moore paper.
+
+ This following code is equivalent to, but faster than:
+
+ arma::vec p;
+ p.zeros(dimensionality);
+
+ for (size_t k = 0; k < dimensionality; ++k)
+ {
+ p(k) = (centroids(k,i) > centroids(k,minIndex)) ?
+ bound[k].Hi() : bound[k].Lo();
+ }
+
+ double distancei = metric.Evaluate(p.col(0), centroids.col(i));
+ double distanceMin = metric.Evaluate(p.col(0),
+ centroids.col(minIndex));
+ */
+
+ comps += 1;
+ double distancei = 0.0;
+ double distanceMin = 0.0;
+ for (size_t k = 0; k < dimensionality; ++k)
+ {
+ double ci = centroids(k, i);
+ double cm = centroids(k, minIndex);
+ if (ci > cm)
+ {
+ double hi = bound[k].Hi();
+
+ ci -= hi;
+ cm -= hi;
+
+ ci *= ci;
+ cm *= cm;
+
+ distancei += ci;
+ distanceMin += cm;
+ }
+ else
+ {
+ double lo = bound[k].Lo();
+
+ ci -= lo;
+ cm -= lo;
+
+ ci *= ci;
+ cm *= cm;
+
+ distancei += ci;
+ distanceMin += cm;
+ }
+ }
+
+ if (distanceMin >= distancei)
+ {
+ noDomination = true;
+ break;
+ }
+ else
+ {
+ mrkd.Whitelist()[i] = false;
+ }
+ }
+ }
+
+ // If we did find a centroid that was closer to every point in the
+ // hyperrectangle than every other centroid, then update that centroid.
+ if (!noDomination)
+ {
+ // Adjust the new centroid sum for the min distance point to this
+ // hyperrectangle by the center of mass of this hyperrectangle.
+ newCentroids.col(minIndex) += mrkd.CenterOfMass();
+
+ // Increment the counts for this centroid.
+ counts(minIndex) += mrkd.Count();
+
+ // Update all assignments for this node.
+ const size_t begin = node->Begin();
+ const size_t end = node->End();
+
+ // TODO: Do this outside of the kmeans iterations.
+ for (size_t j = begin; j < end; ++j)
+ {
+ if (assignments(j) != minIndex)
+ {
+ ++changedAssignments;
+ assignments(j) = minIndex;
+ }
+ }
+ mrkd.DominatingCentroid() = minIndex;
+
+ // Keep track of the number of times we found a dominating centroid.
+ ++dominations;
+ }
+
+ // If we did not find a dominating centroid then we fall through to the
+ // default case, where we add the children of this node to the stack.
+ else
+ {
+ // Add this hyperrectangle's children to our stack.
+ stack.push(node->Left());
+ stack.push(node->Right());
+
+ // (Re)Initialize the whiteList for the children.
+ node->Left()->Stat().Whitelist() = mrkd.Whitelist();
+ node->Right()->Stat().Whitelist() = mrkd.Whitelist();
+ }
+ }
+
+ }
+
+ // Divide by the number of points assigned to the centroids so that we
+ // have the actual center of mass and update centroids' positions.
+ for (size_t i = 0; i < centroids.n_cols; ++i)
+ if (counts(i))
+ centroids.col(i) = newCentroids.col(i) / counts(i);
+
+ // Stop when we reach the maximum number of iterations or when no
+ // assignments have changed.
+ } while (changedAssignments > 0 && iteration != maxIterations);
+
+ Log::Info << "Iterations: " << iteration << std::endl
+ << "Skips: " << skip << std::endl
+ << "Comparisons: " << comps << std::endl
+ << "Dominations: " << dominations << std::endl;
+}
+
+/**
+ * Perform k-means clustering on the data, returning a list of cluster
+ * assignments. This just forwards to the other function, which returns the
+ * centroids too. If this is properly inlined, there shouldn't be any
+ * performance penalty whatsoever.
+ */
+template<typename MetricType,
+ typename InitialPartitionPolicy,
+ typename EmptyClusterPolicy>
+template<typename MatType>
+inline void KMeans<
+ MetricType,
+ InitialPartitionPolicy,
+ EmptyClusterPolicy>::
+Cluster(const MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments,
+ const bool initialGuess) const
+{
+ MatType centroids(data.n_rows, clusters);
+ Cluster(data, clusters, assignments, centroids, initialGuess);
+}
+
+/**
+ * Perform k-means clustering on the data, returning a list of cluster
+ * assignments and the centroids of each cluster.
+ */
+template<typename MetricType,
+ typename InitialPartitionPolicy,
+ typename EmptyClusterPolicy>
+template<typename MatType>
+void KMeans<
+ MetricType,
+ InitialPartitionPolicy,
+ EmptyClusterPolicy>::
+Cluster(const MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments,
+ MatType& centroids,
+ const bool initialAssignmentGuess,
+ const bool initialCentroidGuess) const
+{
+ // Make sure we have more points than clusters.
+ if (clusters > data.n_cols)
+ Log::Warn << "KMeans::Cluster(): more clusters requested than points given."
+ << std::endl;
+
+ // Make sure our overclustering factor is valid.
+ size_t actualClusters = size_t(overclusteringFactor * clusters);
+ if (actualClusters > data.n_cols && overclusteringFactor != 1.0)
+ {
+ Log::Warn << "KMeans::Cluster(): overclustering factor is too large. No "
+ << "overclustering will be done." << std::endl;
+ actualClusters = clusters;
+ }
+
+ // Now, the initial assignments. First determine if they are necessary.
+ if (initialAssignmentGuess)
+ {
+ if (assignments.n_elem != data.n_cols)
+ Log::Fatal << "KMeans::Cluster(): initial cluster assignments (length "
+ << assignments.n_elem << ") not the same size as the dataset (size "
+ << data.n_cols << ")!" << std::endl;
+ }
+ else if (initialCentroidGuess)
+ {
+ if (centroids.n_cols != clusters)
+ Log::Fatal << "KMeans::Cluster(): wrong number of initial cluster "
+ << "centroids (" << centroids.n_cols << ", should be " << clusters
+ << ")!" << std::endl;
+
+ if (centroids.n_rows != data.n_rows)
+ Log::Fatal << "KMeans::Cluster(): initial cluster centroids have wrong "
+ << " dimensionality (" << centroids.n_rows << ", should be "
+ << data.n_rows << ")!" << std::endl;
+
+ // If there were no problems, construct the initial assignments from the
+ // given centroids.
+ assignments.set_size(data.n_cols);
+ for (size_t i = 0; i < data.n_cols; ++i)
+ {
+ // Find the closest centroid to this point.
+ double minDistance = std::numeric_limits<double>::infinity();
+ size_t closestCluster = clusters; // Invalid value.
+
+ for (size_t j = 0; j < clusters; j++)
+ {
+ double distance = metric.Evaluate(data.col(i), centroids.col(j));
+
+ if (distance < minDistance)
+ {
+ minDistance = distance;
+ closestCluster = j;
+ }
+ }
+
+ // Assign the point to the closest cluster that we found.
+ assignments[i] = closestCluster;
+ }
+ }
+ else
+ {
+ // Use the partitioner to come up with the partition assignments.
+ partitioner.Cluster(data, actualClusters, assignments);
+ }
+
+ // Counts of points in each cluster.
+ arma::Col<size_t> counts(actualClusters);
+ counts.zeros();
+
+ // Resize to correct size.
+ centroids.set_size(data.n_rows, actualClusters);
+
+ // Set counts correctly.
+ for (size_t i = 0; i < assignments.n_elem; i++)
+ counts[assignments[i]]++;
+
+ size_t changedAssignments = 0;
+ size_t iteration = 0;
+ do
+ {
+ // Update step.
+ // Calculate centroids based on given assignments.
+ centroids.zeros();
+
+ for (size_t i = 0; i < data.n_cols; i++)
+ centroids.col(assignments[i]) += data.col(i);
+
+ for (size_t i = 0; i < actualClusters; i++)
+ centroids.col(i) /= counts[i];
+
+ // Assignment step.
+ // Find the closest centroid to each point. We will keep track of how many
+ // assignments change. When no assignments change, we are done.
+ changedAssignments = 0;
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ // Find the closest centroid to this point.
+ double minDistance = std::numeric_limits<double>::infinity();
+ size_t closestCluster = actualClusters; // Invalid value.
+
+ for (size_t j = 0; j < actualClusters; j++)
+ {
+ double distance = metric.Evaluate(data.col(i), centroids.col(j));
+
+ if (distance < minDistance)
+ {
+ minDistance = distance;
+ closestCluster = j;
+ }
+ }
+
+ // Reassign this point to the closest cluster.
+ if (assignments[i] != closestCluster)
+ {
+ // Update counts.
+ counts[assignments[i]]--;
+ counts[closestCluster]++;
+ // Update assignment.
+ assignments[i] = closestCluster;
+ changedAssignments++;
+ }
+ }
+
+ // If we are not allowing empty clusters, then check that all of our
+ // clusters have points.
+ for (size_t i = 0; i < actualClusters; i++)
+ if (counts[i] == 0)
+ changedAssignments += emptyClusterAction.EmptyCluster(data, i,
+ centroids, counts, assignments);
+
+ iteration++;
+
+ } while (changedAssignments > 0 && iteration != maxIterations);
+
+ if (iteration != maxIterations)
+ {
+ Log::Debug << "KMeans::Cluster(): converged after " << iteration
+ << " iterations." << std::endl;
+ }
+ else
+ {
+ Log::Debug << "KMeans::Cluster(): terminated after limit of " << iteration
+ << " iterations." << std::endl;
+
+ // Recalculate final clusters.
+ centroids.zeros();
+
+ for (size_t i = 0; i < data.n_cols; i++)
+ centroids.col(assignments[i]) += data.col(i);
+
+ for (size_t i = 0; i < actualClusters; i++)
+ centroids.col(i) /= counts[i];
+ }
+
+ // If we have overclustered, we need to merge the nearest clusters.
+ if (actualClusters != clusters)
+ {
+ // Generate a list of all the clusters' distances from each other. This
+ // list will become mangled and unused as the number of clusters decreases.
+ size_t numDistances = ((actualClusters - 1) * actualClusters) / 2;
+ size_t clustersLeft = actualClusters;
+ arma::vec distances(numDistances);
+ arma::Col<size_t> firstCluster(numDistances);
+ arma::Col<size_t> secondCluster(numDistances);
+
+ // Keep the mappings of clusters that we are changing.
+ arma::Col<size_t> mappings = arma::linspace<arma::Col<size_t> >(0,
+ actualClusters - 1, actualClusters);
+
+ size_t i = 0;
+ for (size_t first = 0; first < actualClusters; first++)
+ {
+ for (size_t second = first + 1; second < actualClusters; second++)
+ {
+ distances(i) = metric.Evaluate(centroids.col(first),
+ centroids.col(second));
+ firstCluster(i) = first;
+ secondCluster(i) = second;
+ i++;
+ }
+ }
+
+ while (clustersLeft != clusters)
+ {
+ arma::uword minIndex;
+ distances.min(minIndex);
+
+ // Now we merge the clusters to which that distance belongs.
+ size_t first = firstCluster(minIndex);
+ size_t second = secondCluster(minIndex);
+ for (size_t j = 0; j < assignments.n_elem; j++)
+ if (assignments(j) == second)
+ assignments(j) = first;
+
+ // Now merge the centroids.
+ centroids.col(first) *= counts[first];
+ centroids.col(first) += (counts[second] * centroids.col(second));
+ centroids.col(first) /= (counts[first] + counts[second]);
+
+ // Update the counts.
+ counts[first] += counts[second];
+ counts[second] = 0;
+
+ // Now update all the relevant distances.
+ // First the distances where either cluster is the second cluster.
+ for (size_t cluster = 0; cluster < second; cluster++)
+ {
+ // The offset is sum^n i - sum^(n - m) i, where n is actualClusters and
+ // m is the cluster we are trying to offset to.
+ size_t offset = (size_t) (((actualClusters - 1) * cluster)
+ + (cluster - pow(cluster, 2.0)) / 2) - 1;
+
+ // See if we need to update the distance from this cluster to the first
+ // cluster.
+ if (cluster < first)
+ {
+ // Make sure it isn't already DBL_MAX.
+ if (distances(offset + (first - cluster)) != DBL_MAX)
+ distances(offset + (first - cluster)) = metric.Evaluate(
+ centroids.col(first), centroids.col(cluster));
+ }
+
+ distances(offset + (second - cluster)) = DBL_MAX;
+ }
+
+ // Now the distances where the first cluster is the first cluster.
+ size_t offset = (size_t) (((actualClusters - 1) * first)
+ + (first - pow(first, 2.0)) / 2) - 1;
+ for (size_t cluster = first + 1; cluster < actualClusters; cluster++)
+ {
+ // Make sure it isn't already DBL_MAX.
+ if (distances(offset + (cluster - first)) != DBL_MAX)
+ {
+ distances(offset + (cluster - first)) = metric.Evaluate(
+ centroids.col(first), centroids.col(cluster));
+ }
+ }
+
+ // Max the distance between the first and second clusters.
+ distances(offset + (second - first)) = DBL_MAX;
+
+ // Now max the distances for the second cluster (which no longer has
+ // anything in it).
+ offset = (size_t) (((actualClusters - 1) * second)
+ + (second - pow(second, 2.0)) / 2) - 1;
+ for (size_t cluster = second + 1; cluster < actualClusters; cluster++)
+ distances(offset + (cluster - second)) = DBL_MAX;
+
+ clustersLeft--;
+
+ // Update the cluster mappings.
+ mappings(second) = first;
+ // Also update any mappings that were pointed at the previous cluster.
+ for (size_t cluster = 0; cluster < actualClusters; cluster++)
+ if (mappings(cluster) == second)
+ mappings(cluster) = first;
+ }
+
+ // Now remap the mappings down to the smallest possible numbers.
+ // Could this process be sped up?
+ arma::Col<size_t> remappings(actualClusters);
+ remappings.fill(actualClusters);
+ size_t remap = 0; // Counter variable.
+ for (size_t cluster = 0; cluster < actualClusters; cluster++)
+ {
+ // If the mapping of the current cluster has not been assigned a value
+ // yet, we will assign it a cluster number.
+ if (remappings(mappings(cluster)) == actualClusters)
+ {
+ remappings(mappings(cluster)) = remap;
+ remap++;
+ }
+ }
+
+ // Fix the assignments using the mappings we created.
+ for (size_t j = 0; j < assignments.n_elem; j++)
+ assignments(j) = remappings(mappings(assignments(j)));
+ }
+}
+
+}; // namespace kmeans
+}; // namespace mlpack
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/kmeans_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,256 +0,0 @@
-/**
- * @file kmeans_main.cpp
- * @author Ryan Curtin
- *
- * Executable for running K-Means.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "kmeans.hpp"
-#include "allow_empty_clusters.hpp"
-#include "refined_start.hpp"
-
-using namespace mlpack;
-using namespace mlpack::kmeans;
-using namespace std;
-
-// Define parameters for the executable.
-PROGRAM_INFO("K-Means Clustering", "This program performs K-Means clustering "
- "on the given dataset, storing the learned cluster assignments either as "
- "a column of labels in the file containing the input dataset or in a "
- "separate file. Empty clusters are not allowed by default; when a cluster "
- "becomes empty, the point furthest from the centroid of the cluster with "
- "maximum variance is taken to fill that cluster."
- "\n\n"
- "Optionally, the Bradley and Fayyad approach (\"Refining initial points for"
- " k-means clustering\", 1998) can be used to select initial points by "
- "specifying the --refined_start (-r) option. This approach works by taking"
- " random samples of the dataset; to specify the number of samples, the "
- "--samples parameter is used, and to specify the percentage of the dataset "
- "to be used in each sample, the --percentage parameter is used (it should "
- "be a value between 0.0 and 1.0)."
- "\n\n"
- "If you want to specify your own initial cluster assignments or initial "
- "cluster centroids, this functionality is available in the C++ interface. "
- "Alternately, file a bug (well, a feature request) on the mlpack bug "
- "tracker.");
-
-// Required options.
-PARAM_STRING_REQ("inputFile", "Input dataset to perform clustering on.", "i");
-PARAM_INT_REQ("clusters", "Number of clusters to find.", "c");
-
-// Output options.
-PARAM_FLAG("in_place", "If specified, a column of the learned cluster "
- "assignments will be added to the input dataset file. In this case, "
- "--outputFile is not necessary.", "p");
-PARAM_STRING("output_file", "File to write output labels or labeled data to.",
- "o", "output.csv");
-PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
- " be written to the given file.", "C", "");
-
-// k-means configuration options.
-PARAM_FLAG("allow_empty_clusters", "Allow empty clusters to be created.", "e");
-PARAM_FLAG("labels_only", "Only output labels into output file.", "l");
-PARAM_DOUBLE("overclustering", "Finds (overclustering * clusters) clusters, "
- "then merges them together until only the desired number of clusters are "
- "left.", "O", 1.0);
-PARAM_INT("max_iterations", "Maximum number of iterations before K-Means "
- "terminates.", "m", 1000);
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-
-// This is known to not work (#251).
-//PARAM_FLAG("fast_kmeans", "Use the experimental fast k-means algorithm by "
-// "Pelleg and Moore.", "f");
-
-// Parameters for "refined start" k-means.
-PARAM_FLAG("refined_start", "Use the refined initial point strategy by Bradley "
- "and Fayyad to choose initial points.", "r");
-PARAM_INT("samplings", "Number of samplings to perform for refined start (use "
- "when --refined_start is specified).", "S", 100);
-PARAM_DOUBLE("percentage", "Percentage of dataset to use for each refined start"
- " sampling (use when --refined_start is specified).", "p", 0.02);
-
-
-int main(int argc, char** argv)
-{
- CLI::ParseCommandLine(argc, argv);
-
- // Initialize random seed.
- if (CLI::GetParam<int>("seed") != 0)
- math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- math::RandomSeed((size_t) std::time(NULL));
-
- // Now do validation of options.
- string inputFile = CLI::GetParam<string>("inputFile");
- int clusters = CLI::GetParam<int>("clusters");
- if (clusters < 1)
- {
- Log::Fatal << "Invalid number of clusters requested (" << clusters << ")! "
- << "Must be greater than or equal to 1." << std::endl;
- }
-
- int maxIterations = CLI::GetParam<int>("max_iterations");
- if (maxIterations < 0)
- {
- Log::Fatal << "Invalid value for maximum iterations (" << maxIterations <<
- ")! Must be greater than or equal to 0." << std::endl;
- }
-
- double overclustering = CLI::GetParam<double>("overclustering");
- if (overclustering < 1)
- {
- Log::Fatal << "Invalid value for overclustering (" << overclustering <<
- ")! Must be greater than or equal to 1." << std::endl;
- }
-
- // Make sure we have an output file if we're not doing the work in-place.
- if (!CLI::HasParam("in_place") && !CLI::HasParam("output_file"))
- {
- Log::Fatal << "--outputFile not specified (and --in_place not set)."
- << std::endl;
- }
-
- // Load our dataset.
- arma::mat dataset;
- data::Load(inputFile.c_str(), dataset, true); // Fatal upon failure.
-
- // Now create the KMeans object. Because we could be using different types,
- // it gets a little weird...
- arma::Col<size_t> assignments;
- arma::mat centroids;
-
- if (CLI::HasParam("allow_empty_clusters"))
- {
- if (CLI::HasParam("refined_start"))
- {
- const int samplings = CLI::GetParam<int>("samplings");
- const double percentage = CLI::GetParam<int>("percentage");
-
- if (samplings < 0)
- Log::Fatal << "Number of samplings (" << samplings << ") must be "
- << "greater than 0!" << std::endl;
- if (percentage <= 0.0 || percentage > 1.0)
- Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
- << "greater than 0.0 and less than or equal to 1.0!" << std::endl;
-
- KMeans<metric::SquaredEuclideanDistance, RefinedStart, AllowEmptyClusters>
- k(maxIterations, overclustering, metric::SquaredEuclideanDistance(),
- RefinedStart(samplings, percentage));
-
- Timer::Start("clustering");
- if (CLI::HasParam("fast_kmeans"))
- k.FastCluster(dataset, clusters, assignments);
- else
- k.Cluster(dataset, clusters, assignments, centroids);
- Timer::Stop("clustering");
- }
- else
- {
- KMeans<metric::SquaredEuclideanDistance, RandomPartition,
- AllowEmptyClusters> k(maxIterations, overclustering);
-
- Timer::Start("clustering");
- if (CLI::HasParam("fast_kmeans"))
- k.FastCluster(dataset, clusters, assignments);
- else
- k.Cluster(dataset, clusters, assignments, centroids);
- Timer::Stop("clustering");
- }
- }
- else
- {
- if (CLI::HasParam("refined_start"))
- {
- const int samplings = CLI::GetParam<int>("samplings");
- const double percentage = CLI::GetParam<int>("percentage");
-
- if (samplings < 0)
- Log::Fatal << "Number of samplings (" << samplings << ") must be "
- << "greater than 0!" << std::endl;
- if (percentage <= 0.0 || percentage > 1.0)
- Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
- << "greater than 0.0 and less than or equal to 1.0!" << std::endl;
-
- KMeans<metric::SquaredEuclideanDistance, RefinedStart, AllowEmptyClusters>
- k(maxIterations, overclustering, metric::SquaredEuclideanDistance(),
- RefinedStart(samplings, percentage));
-
- Timer::Start("clustering");
- if (CLI::HasParam("fast_kmeans"))
- k.FastCluster(dataset, clusters, assignments);
- else
- k.Cluster(dataset, clusters, assignments, centroids);
- Timer::Stop("clustering");
- }
- else
- {
- KMeans<> k(maxIterations, overclustering);
-
- Timer::Start("clustering");
- if (CLI::HasParam("fast_kmeans"))
- k.FastCluster(dataset, clusters, assignments);
- else
- k.Cluster(dataset, clusters, assignments, centroids);
- Timer::Stop("clustering");
- }
- }
-
- // Now figure out what to do with our results.
- if (CLI::HasParam("in_place"))
- {
- // Add the column of assignments to the dataset; but we have to convert them
- // to type double first.
- arma::vec converted(assignments.n_elem);
- for (size_t i = 0; i < assignments.n_elem; i++)
- converted(i) = (double) assignments(i);
-
- dataset.insert_rows(dataset.n_rows, trans(converted));
-
- // Save the dataset.
- data::Save(inputFile, dataset);
- }
- else
- {
- if (CLI::HasParam("labels_only"))
- {
- // Save only the labels.
- string outputFile = CLI::GetParam<string>("output_file");
- arma::Mat<size_t> output = trans(assignments);
- data::Save(outputFile, output);
- }
- else
- {
- // Convert the assignments to doubles.
- arma::vec converted(assignments.n_elem);
- for (size_t i = 0; i < assignments.n_elem; i++)
- converted(i) = (double) assignments(i);
-
- dataset.insert_rows(dataset.n_rows, trans(converted));
-
- // Now save, in the different file.
- string outputFile = CLI::GetParam<string>("output_file");
- data::Save(outputFile, dataset);
- }
- }
-
- // Should we write the centroids to a file?
- if (CLI::HasParam("centroid_file"))
- data::Save(CLI::GetParam<std::string>("centroid_file"), centroids);
-}
-
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/kmeans_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/kmeans_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,256 @@
+/**
+ * @file kmeans_main.cpp
+ * @author Ryan Curtin
+ *
+ * Executable for running K-Means.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "kmeans.hpp"
+#include "allow_empty_clusters.hpp"
+#include "refined_start.hpp"
+
+using namespace mlpack;
+using namespace mlpack::kmeans;
+using namespace std;
+
+// Define parameters for the executable.
+PROGRAM_INFO("K-Means Clustering", "This program performs K-Means clustering "
+ "on the given dataset, storing the learned cluster assignments either as "
+ "a column of labels in the file containing the input dataset or in a "
+ "separate file. Empty clusters are not allowed by default; when a cluster "
+ "becomes empty, the point furthest from the centroid of the cluster with "
+ "maximum variance is taken to fill that cluster."
+ "\n\n"
+ "Optionally, the Bradley and Fayyad approach (\"Refining initial points for"
+ " k-means clustering\", 1998) can be used to select initial points by "
+ "specifying the --refined_start (-r) option. This approach works by taking"
+ " random samples of the dataset; to specify the number of samples, the "
+ "--samples parameter is used, and to specify the percentage of the dataset "
+ "to be used in each sample, the --percentage parameter is used (it should "
+ "be a value between 0.0 and 1.0)."
+ "\n\n"
+ "If you want to specify your own initial cluster assignments or initial "
+ "cluster centroids, this functionality is available in the C++ interface. "
+ "Alternately, file a bug (well, a feature request) on the mlpack bug "
+ "tracker.");
+
+// Required options.
+PARAM_STRING_REQ("inputFile", "Input dataset to perform clustering on.", "i");
+PARAM_INT_REQ("clusters", "Number of clusters to find.", "c");
+
+// Output options.
+PARAM_FLAG("in_place", "If specified, a column of the learned cluster "
+ "assignments will be added to the input dataset file. In this case, "
+ "--outputFile is not necessary.", "p");
+PARAM_STRING("output_file", "File to write output labels or labeled data to.",
+ "o", "output.csv");
+PARAM_STRING("centroid_file", "If specified, the centroids of each cluster will"
+ " be written to the given file.", "C", "");
+
+// k-means configuration options.
+PARAM_FLAG("allow_empty_clusters", "Allow empty clusters to be created.", "e");
+PARAM_FLAG("labels_only", "Only output labels into output file.", "l");
+PARAM_DOUBLE("overclustering", "Finds (overclustering * clusters) clusters, "
+ "then merges them together until only the desired number of clusters are "
+ "left.", "O", 1.0);
+PARAM_INT("max_iterations", "Maximum number of iterations before K-Means "
+ "terminates.", "m", 1000);
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+
+// This is known to not work (#251).
+//PARAM_FLAG("fast_kmeans", "Use the experimental fast k-means algorithm by "
+// "Pelleg and Moore.", "f");
+
+// Parameters for "refined start" k-means.
+PARAM_FLAG("refined_start", "Use the refined initial point strategy by Bradley "
+ "and Fayyad to choose initial points.", "r");
+PARAM_INT("samplings", "Number of samplings to perform for refined start (use "
+ "when --refined_start is specified).", "S", 100);
+PARAM_DOUBLE("percentage", "Percentage of dataset to use for each refined start"
+ " sampling (use when --refined_start is specified).", "p", 0.02);
+
+
+int main(int argc, char** argv)
+{
+ CLI::ParseCommandLine(argc, argv);
+
+ // Initialize random seed.
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) std::time(NULL));
+
+ // Now do validation of options.
+ string inputFile = CLI::GetParam<string>("inputFile");
+ int clusters = CLI::GetParam<int>("clusters");
+ if (clusters < 1)
+ {
+ Log::Fatal << "Invalid number of clusters requested (" << clusters << ")! "
+ << "Must be greater than or equal to 1." << std::endl;
+ }
+
+ int maxIterations = CLI::GetParam<int>("max_iterations");
+ if (maxIterations < 0)
+ {
+ Log::Fatal << "Invalid value for maximum iterations (" << maxIterations <<
+ ")! Must be greater than or equal to 0." << std::endl;
+ }
+
+ double overclustering = CLI::GetParam<double>("overclustering");
+ if (overclustering < 1)
+ {
+ Log::Fatal << "Invalid value for overclustering (" << overclustering <<
+ ")! Must be greater than or equal to 1." << std::endl;
+ }
+
+ // Make sure we have an output file if we're not doing the work in-place.
+ if (!CLI::HasParam("in_place") && !CLI::HasParam("output_file"))
+ {
+ Log::Fatal << "--outputFile not specified (and --in_place not set)."
+ << std::endl;
+ }
+
+ // Load our dataset.
+ arma::mat dataset;
+ data::Load(inputFile.c_str(), dataset, true); // Fatal upon failure.
+
+ // Now create the KMeans object. Because we could be using different types,
+ // it gets a little weird...
+ arma::Col<size_t> assignments;
+ arma::mat centroids;
+
+ if (CLI::HasParam("allow_empty_clusters"))
+ {
+ if (CLI::HasParam("refined_start"))
+ {
+ const int samplings = CLI::GetParam<int>("samplings");
+ const double percentage = CLI::GetParam<int>("percentage");
+
+ if (samplings < 0)
+ Log::Fatal << "Number of samplings (" << samplings << ") must be "
+ << "greater than 0!" << std::endl;
+ if (percentage <= 0.0 || percentage > 1.0)
+ Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
+ << "greater than 0.0 and less than or equal to 1.0!" << std::endl;
+
+ KMeans<metric::SquaredEuclideanDistance, RefinedStart, AllowEmptyClusters>
+ k(maxIterations, overclustering, metric::SquaredEuclideanDistance(),
+ RefinedStart(samplings, percentage));
+
+ Timer::Start("clustering");
+ if (CLI::HasParam("fast_kmeans"))
+ k.FastCluster(dataset, clusters, assignments);
+ else
+ k.Cluster(dataset, clusters, assignments, centroids);
+ Timer::Stop("clustering");
+ }
+ else
+ {
+ KMeans<metric::SquaredEuclideanDistance, RandomPartition,
+ AllowEmptyClusters> k(maxIterations, overclustering);
+
+ Timer::Start("clustering");
+ if (CLI::HasParam("fast_kmeans"))
+ k.FastCluster(dataset, clusters, assignments);
+ else
+ k.Cluster(dataset, clusters, assignments, centroids);
+ Timer::Stop("clustering");
+ }
+ }
+ else
+ {
+ if (CLI::HasParam("refined_start"))
+ {
+ const int samplings = CLI::GetParam<int>("samplings");
+ const double percentage = CLI::GetParam<int>("percentage");
+
+ if (samplings < 0)
+ Log::Fatal << "Number of samplings (" << samplings << ") must be "
+ << "greater than 0!" << std::endl;
+ if (percentage <= 0.0 || percentage > 1.0)
+ Log::Fatal << "Percentage for sampling (" << percentage << ") must be "
+ << "greater than 0.0 and less than or equal to 1.0!" << std::endl;
+
+ KMeans<metric::SquaredEuclideanDistance, RefinedStart, AllowEmptyClusters>
+ k(maxIterations, overclustering, metric::SquaredEuclideanDistance(),
+ RefinedStart(samplings, percentage));
+
+ Timer::Start("clustering");
+ if (CLI::HasParam("fast_kmeans"))
+ k.FastCluster(dataset, clusters, assignments);
+ else
+ k.Cluster(dataset, clusters, assignments, centroids);
+ Timer::Stop("clustering");
+ }
+ else
+ {
+ KMeans<> k(maxIterations, overclustering);
+
+ Timer::Start("clustering");
+ if (CLI::HasParam("fast_kmeans"))
+ k.FastCluster(dataset, clusters, assignments);
+ else
+ k.Cluster(dataset, clusters, assignments, centroids);
+ Timer::Stop("clustering");
+ }
+ }
+
+ // Now figure out what to do with our results.
+ if (CLI::HasParam("in_place"))
+ {
+ // Add the column of assignments to the dataset; but we have to convert them
+ // to type double first.
+ arma::vec converted(assignments.n_elem);
+ for (size_t i = 0; i < assignments.n_elem; i++)
+ converted(i) = (double) assignments(i);
+
+ dataset.insert_rows(dataset.n_rows, trans(converted));
+
+ // Save the dataset.
+ data::Save(inputFile, dataset);
+ }
+ else
+ {
+ if (CLI::HasParam("labels_only"))
+ {
+ // Save only the labels.
+ string outputFile = CLI::GetParam<string>("output_file");
+ arma::Mat<size_t> output = trans(assignments);
+ data::Save(outputFile, output);
+ }
+ else
+ {
+ // Convert the assignments to doubles.
+ arma::vec converted(assignments.n_elem);
+ for (size_t i = 0; i < assignments.n_elem; i++)
+ converted(i) = (double) assignments(i);
+
+ dataset.insert_rows(dataset.n_rows, trans(converted));
+
+ // Now save, in the different file.
+ string outputFile = CLI::GetParam<string>("output_file");
+ data::Save(outputFile, dataset);
+ }
+ }
+
+ // Should we write the centroids to a file?
+ if (CLI::HasParam("centroid_file"))
+ data::Save(CLI::GetParam<std::string>("centroid_file"), centroids);
+}
+
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,69 +0,0 @@
-/**
- * @file max_variance_new_cluster.hpp
- * @author Ryan Curtin
- *
- * An implementation of the EmptyClusterPolicy policy class for K-Means. When
- * an empty cluster is detected, the point furthest from the centroid of the
- * cluster with maximum variance is taken to be a new cluster.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_HPP
-#define __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kmeans {
-
-/**
- * When an empty cluster is detected, this class takes the point furthest from
- * the centroid of the cluster with maximum variance as a new cluster.
- */
-class MaxVarianceNewCluster
-{
- public:
- //! Default constructor required by EmptyClusterPolicy.
- MaxVarianceNewCluster() { }
-
- /**
- * Take the point furthest from the centroid of the cluster with maximum
- * variance to be a new cluster.
- *
- * @tparam MatType Type of data (arma::mat or arma::spmat).
- * @param data Dataset on which clustering is being performed.
- * @param emptyCluster Index of cluster which is empty.
- * @param centroids Centroids of each cluster (one per column).
- * @param clusterCounts Number of points in each cluster.
- * @param assignments Cluster assignments of each point.
- *
- * @return Number of points changed.
- */
- template<typename MatType>
- static size_t EmptyCluster(const MatType& data,
- const size_t emptyCluster,
- const MatType& centroids,
- arma::Col<size_t>& clusterCounts,
- arma::Col<size_t>& assignments);
-};
-
-}; // namespace kmeans
-}; // namespace mlpack
-
-// Include implementation.
-#include "max_variance_new_cluster_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,69 @@
+/**
+ * @file max_variance_new_cluster.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of the EmptyClusterPolicy policy class for K-Means. When
+ * an empty cluster is detected, the point furthest from the centroid of the
+ * cluster with maximum variance is taken to be a new cluster.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_HPP
+#define __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * When an empty cluster is detected, this class takes the point furthest from
+ * the centroid of the cluster with maximum variance as a new cluster.
+ */
+class MaxVarianceNewCluster
+{
+ public:
+ //! Default constructor required by EmptyClusterPolicy.
+ MaxVarianceNewCluster() { }
+
+ /**
+ * Take the point furthest from the centroid of the cluster with maximum
+ * variance to be a new cluster.
+ *
+ * @tparam MatType Type of data (arma::mat or arma::spmat).
+ * @param data Dataset on which clustering is being performed.
+ * @param emptyCluster Index of cluster which is empty.
+ * @param centroids Centroids of each cluster (one per column).
+ * @param clusterCounts Number of points in each cluster.
+ * @param assignments Cluster assignments of each point.
+ *
+ * @return Number of points changed.
+ */
+ template<typename MatType>
+ static size_t EmptyCluster(const MatType& data,
+ const size_t emptyCluster,
+ const MatType& centroids,
+ arma::Col<size_t>& clusterCounts,
+ arma::Col<size_t>& assignments);
+};
+
+}; // namespace kmeans
+}; // namespace mlpack
+
+// Include implementation.
+#include "max_variance_new_cluster_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,91 +0,0 @@
-/**
- * @file max_variance_new_cluster_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of MaxVarianceNewCluster class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_IMPL_HPP
-#define __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_IMPL_HPP
-
-// Just in case it has not been included.
-#include "max_variance_new_cluster.hpp"
-
-namespace mlpack {
-namespace kmeans {
-
-/**
- * Take action about an empty cluster.
- */
-template<typename MatType>
-size_t MaxVarianceNewCluster::EmptyCluster(const MatType& data,
- const size_t emptyCluster,
- const MatType& centroids,
- arma::Col<size_t>& clusterCounts,
- arma::Col<size_t>& assignments)
-{
- // First, we need to find the cluster with maximum variance (by which I mean
- // the sum of the covariance matrices).
- arma::vec variances;
- variances.zeros(clusterCounts.n_elem); // Start with 0.
-
- // Add the variance of each point's distance away from the cluster. I think
- // this is the sensible thing to do.
- for (size_t i = 0; i < data.n_cols; i++)
- {
- variances[assignments[i]] += arma::as_scalar(
- arma::var(data.col(i) - centroids.col(assignments[i])));
- }
-
- // Now find the cluster with maximum variance.
- arma::uword maxVarCluster;
- variances.max(maxVarCluster);
-
- // Now, inside this cluster, find the point which is furthest away.
- size_t furthestPoint = data.n_cols;
- double maxDistance = 0;
- for (size_t i = 0; i < data.n_cols; i++)
- {
- if (assignments[i] == maxVarCluster)
- {
- double distance = arma::as_scalar(
- arma::var(data.col(i) - centroids.col(maxVarCluster)));
-
- if (distance > maxDistance)
- {
- maxDistance = distance;
- furthestPoint = i;
- }
- }
- }
-
- // Take that point and add it to the empty cluster.
- clusterCounts[maxVarCluster]--;
- clusterCounts[emptyCluster]++;
- assignments[furthestPoint] = emptyCluster;
-
- // Output some debugging information.
- Log::Debug << "Point " << furthestPoint << " assigned to empty cluster " <<
- emptyCluster << ".\n";
-
- return 1; // We only changed one point.
-}
-
-}; // namespace kmeans
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,91 @@
+/**
+ * @file max_variance_new_cluster_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of MaxVarianceNewCluster class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_MAX_VARIANCE_NEW_CLUSTER_IMPL_HPP
+
+// Just in case it has not been included.
+#include "max_variance_new_cluster.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * Take action about an empty cluster.
+ */
+template<typename MatType>
+size_t MaxVarianceNewCluster::EmptyCluster(const MatType& data,
+ const size_t emptyCluster,
+ const MatType& centroids,
+ arma::Col<size_t>& clusterCounts,
+ arma::Col<size_t>& assignments)
+{
+ // First, we need to find the cluster with maximum variance (by which I mean
+ // the sum of the covariance matrices).
+ arma::vec variances;
+ variances.zeros(clusterCounts.n_elem); // Start with 0.
+
+ // Add the variance of each point's distance away from the cluster. I think
+ // this is the sensible thing to do.
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ variances[assignments[i]] += arma::as_scalar(
+ arma::var(data.col(i) - centroids.col(assignments[i])));
+ }
+
+ // Now find the cluster with maximum variance.
+ arma::uword maxVarCluster;
+ variances.max(maxVarCluster);
+
+ // Now, inside this cluster, find the point which is furthest away.
+ size_t furthestPoint = data.n_cols;
+ double maxDistance = 0;
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ if (assignments[i] == maxVarCluster)
+ {
+ double distance = arma::as_scalar(
+ arma::var(data.col(i) - centroids.col(maxVarCluster)));
+
+ if (distance > maxDistance)
+ {
+ maxDistance = distance;
+ furthestPoint = i;
+ }
+ }
+ }
+
+ // Take that point and add it to the empty cluster.
+ clusterCounts[maxVarCluster]--;
+ clusterCounts[emptyCluster]++;
+ assignments[furthestPoint] = emptyCluster;
+
+ // Output some debugging information.
+ Log::Debug << "Point " << furthestPoint << " assigned to empty cluster " <<
+ emptyCluster << ".\n";
+
+ return 1; // We only changed one point.
+}
+
+}; // namespace kmeans
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/random_partition.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/random_partition.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/random_partition.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,67 +0,0 @@
-/**
- * @file random_partition.hpp
- * @author Ryan Curtin
- *
- * Very simple partitioner which partitions the data randomly into the number of
- * desired clusters. Used as the default InitialPartitionPolicy for KMeans.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_KMEANS_RANDOM_PARTITION_HPP
-#define __MLPACK_METHODS_KMEANS_RANDOM_PARTITION_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kmeans {
-
-/**
- * A very simple partitioner which partitions the data randomly into the number
- * of desired clusters. It has no parameters, and so an instance of the class
- * is not even necessary.
- */
-class RandomPartition
-{
- public:
- //! Empty constructor, required by the InitialPartitionPolicy policy.
- RandomPartition() { }
-
- /**
- * Partition the given dataset into the given number of clusters. Assignments
- * are random, and the number of points in each cluster should be equal (or
- * approximately equal).
- *
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @param data Dataset to partition.
- * @param clusters Number of clusters to split dataset into.
- * @param assignments Vector to store cluster assignments into. Values will
- * be between 0 and (clusters - 1).
- */
- template<typename MatType>
- inline static void Cluster(const MatType& data,
- const size_t clusters,
- arma::Col<size_t>& assignments)
- {
- // Implementation is so simple we'll put it here in the header file.
- assignments = arma::shuffle(arma::linspace<arma::Col<size_t> >(0,
- (clusters - 1), data.n_cols));
- }
-};
-
-};
-};
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/random_partition.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/random_partition.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/random_partition.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/random_partition.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,67 @@
+/**
+ * @file random_partition.hpp
+ * @author Ryan Curtin
+ *
+ * Very simple partitioner which partitions the data randomly into the number of
+ * desired clusters. Used as the default InitialPartitionPolicy for KMeans.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_RANDOM_PARTITION_HPP
+#define __MLPACK_METHODS_KMEANS_RANDOM_PARTITION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * A very simple partitioner which partitions the data randomly into the number
+ * of desired clusters. It has no parameters, and so an instance of the class
+ * is not even necessary.
+ */
+class RandomPartition
+{
+ public:
+ //! Empty constructor, required by the InitialPartitionPolicy policy.
+ RandomPartition() { }
+
+ /**
+ * Partition the given dataset into the given number of clusters. Assignments
+ * are random, and the number of points in each cluster should be equal (or
+ * approximately equal).
+ *
+ * @tparam MatType Type of data (arma::mat or arma::sp_mat).
+ * @param data Dataset to partition.
+ * @param clusters Number of clusters to split dataset into.
+ * @param assignments Vector to store cluster assignments into. Values will
+ * be between 0 and (clusters - 1).
+ */
+ template<typename MatType>
+ inline static void Cluster(const MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments)
+ {
+ // Implementation is so simple we'll put it here in the header file.
+ assignments = arma::shuffle(arma::linspace<arma::Col<size_t> >(0,
+ (clusters - 1), data.n_cols));
+ }
+};
+
+};
+};
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/refined_start.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,97 +0,0 @@
-/**
- * @file refined_start.hpp
- * @author Ryan Curtin
- *
- * An implementation of Bradley and Fayyad's "Refining Initial Points for
- * K-Means clustering". This class is meant to provide better initial points
- * for the k-means algorithm.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_KMEANS_REFINED_START_HPP
-#define __MLPACK_METHODS_KMEANS_REFINED_START_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace kmeans {
-
-/**
- * A refined approach for choosing initial points for k-means clustering. This
- * approach runs k-means several times on random subsets of the data, and then
- * clusters those solutions to select refined initial cluster assignments. It
- * is an implementation of the following paper:
- *
- * @inproceedings{bradley1998refining,
- * title={Refining initial points for k-means clustering},
- * author={Bradley, Paul S and Fayyad, Usama M},
- * booktitle={Proceedings of the Fifteenth International Conference on Machine
- * Learning (ICML 1998)},
- * volume={66},
- * year={1998}
- * }
- */
-class RefinedStart
-{
- public:
- /**
- * Create the RefinedStart object, optionally specifying parameters for the
- * number of samplings to perform and the percentage of the dataset to use in
- * each sampling.
- */
- RefinedStart(const size_t samplings = 100,
- const double percentage = 0.02) :
- samplings(samplings), percentage(percentage) { }
-
- /**
- * Partition the given dataset into the given number of clusters according to
- * the random sampling scheme outlined in Bradley and Fayyad's paper.
- *
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @param data Dataset to partition.
- * @param clusters Number of clusters to split dataset into.
- * @param assignments Vector to store cluster assignments into. Values will
- * be between 0 and (clusters - 1).
- */
- template<typename MatType>
- void Cluster(const MatType& data,
- const size_t clusters,
- arma::Col<size_t>& assignments) const;
-
- //! Get the number of samplings that will be performed.
- size_t Samplings() const { return samplings; }
- //! Modify the number of samplings that will be performed.
- size_t& Samplings() { return samplings; }
-
- //! Get the percentage of the data used by each subsampling.
- double Percentage() const { return percentage; }
- //! Modify the percentage of the data used by each subsampling.
- double& Percentage() { return percentage; }
-
- private:
- //! The number of samplings to perform.
- size_t samplings;
- //! The percentage of the data to use for each subsampling.
- double percentage;
-};
-
-}; // namespace kmeans
-}; // namespace mlpack
-
-// Include implementation.
-#include "refined_start_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/refined_start.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,97 @@
+/**
+ * @file refined_start.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of Bradley and Fayyad's "Refining Initial Points for
+ * K-Means clustering". This class is meant to provide better initial points
+ * for the k-means algorithm.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_REFINED_START_HPP
+#define __MLPACK_METHODS_KMEANS_REFINED_START_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * A refined approach for choosing initial points for k-means clustering. This
+ * approach runs k-means several times on random subsets of the data, and then
+ * clusters those solutions to select refined initial cluster assignments. It
+ * is an implementation of the following paper:
+ *
+ * @inproceedings{bradley1998refining,
+ * title={Refining initial points for k-means clustering},
+ * author={Bradley, Paul S and Fayyad, Usama M},
+ * booktitle={Proceedings of the Fifteenth International Conference on Machine
+ * Learning (ICML 1998)},
+ * volume={66},
+ * year={1998}
+ * }
+ */
+class RefinedStart
+{
+ public:
+ /**
+ * Create the RefinedStart object, optionally specifying parameters for the
+ * number of samplings to perform and the percentage of the dataset to use in
+ * each sampling.
+ */
+ RefinedStart(const size_t samplings = 100,
+ const double percentage = 0.02) :
+ samplings(samplings), percentage(percentage) { }
+
+ /**
+ * Partition the given dataset into the given number of clusters according to
+ * the random sampling scheme outlined in Bradley and Fayyad's paper.
+ *
+ * @tparam MatType Type of data (arma::mat or arma::sp_mat).
+ * @param data Dataset to partition.
+ * @param clusters Number of clusters to split dataset into.
+ * @param assignments Vector to store cluster assignments into. Values will
+ * be between 0 and (clusters - 1).
+ */
+ template<typename MatType>
+ void Cluster(const MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments) const;
+
+ //! Get the number of samplings that will be performed.
+ size_t Samplings() const { return samplings; }
+ //! Modify the number of samplings that will be performed.
+ size_t& Samplings() { return samplings; }
+
+ //! Get the percentage of the data used by each subsampling.
+ double Percentage() const { return percentage; }
+ //! Modify the percentage of the data used by each subsampling.
+ double& Percentage() { return percentage; }
+
+ private:
+ //! The number of samplings to perform.
+ size_t samplings;
+ //! The percentage of the data to use for each subsampling.
+ double percentage;
+};
+
+}; // namespace kmeans
+}; // namespace mlpack
+
+// Include implementation.
+#include "refined_start_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/refined_start_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,114 +0,0 @@
-/**
- * @file refined_start_impl.hpp
- * @author Ryan Curtin
- *
- * An implementation of Bradley and Fayyad's "Refining Initial Points for
- * K-Means clustering". This class is meant to provide better initial points
- * for the k-means algorithm.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_KMEANS_REFINED_START_IMPL_HPP
-#define __MLPACK_METHODS_KMEANS_REFINED_START_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "refined_start.hpp"
-
-namespace mlpack {
-namespace kmeans {
-
-//! Partition the given dataset according to Bradley and Fayyad's algorithm.
-template<typename MatType>
-void RefinedStart::Cluster(const MatType& data,
- const size_t clusters,
- arma::Col<size_t>& assignments) const
-{
- math::RandomSeed(std::time(NULL));
-
- // This will hold the sampled datasets.
- const size_t numPoints = size_t(percentage * data.n_cols);
- MatType sampledData(data.n_rows, numPoints);
- // vector<bool> is packed so each bool is 1 bit.
- std::vector<bool> pointsUsed(data.n_cols, false);
- arma::mat sampledCentroids(data.n_rows, samplings * clusters);
-
- // We will use these objects repeatedly for clustering.
- arma::Col<size_t> sampledAssignments;
- arma::mat centroids;
- KMeans<> kmeans;
-
- for (size_t i = 0; i < samplings; ++i)
- {
- // First, assemble the sampled dataset.
- size_t curSample = 0;
- while (curSample < numPoints)
- {
- // Pick a random point in [0, numPoints).
- size_t sample = (size_t) math::RandInt(data.n_cols);
-
- if (!pointsUsed[sample])
- {
- // This point isn't used yet. So we'll put it in our sample.
- pointsUsed[sample] = true;
- sampledData.col(curSample) = data.col(sample);
- ++curSample;
- }
- }
-
- // Now, using the sampled dataset, run k-means. In the case of an empty
- // cluster, we re-initialize that cluster as the point furthest away from
- // the cluster with maximum variance. This is not *exactly* what the paper
- // implements, but it is quite similar, and we'll call it "good enough".
- kmeans.Cluster(sampledData, clusters, sampledAssignments, centroids);
-
- // Store the sampled centroids.
- sampledCentroids.cols(i * clusters, (i + 1) * clusters - 1) = centroids;
-
- pointsUsed.assign(data.n_cols, false);
- }
-
- // Now, we run k-means on the sampled centroids to get our final clusters.
- kmeans.Cluster(sampledCentroids, clusters, sampledAssignments, centroids);
-
- // Turn the final centroids into assignments.
- assignments.set_size(data.n_cols);
- for (size_t i = 0; i < data.n_cols; ++i)
- {
- // Find the closest centroid to this point.
- double minDistance = std::numeric_limits<double>::infinity();
- size_t closestCluster = clusters;
-
- for (size_t j = 0; j < clusters; ++j)
- {
- const double distance = kmeans.Metric().Evaluate(data.col(i),
- centroids.col(j));
-
- if (distance < minDistance)
- {
- minDistance = distance;
- closestCluster = j;
- }
- }
-
- // Assign the point to its closest cluster.
- assignments[i] = closestCluster;
- }
-}
-
-}; // namespace kmeans
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/kmeans/refined_start_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/kmeans/refined_start_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,114 @@
+/**
+ * @file refined_start_impl.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of Bradley and Fayyad's "Refining Initial Points for
+ * K-Means clustering". This class is meant to provide better initial points
+ * for the k-means algorithm.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_REFINED_START_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_REFINED_START_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "refined_start.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+//! Partition the given dataset according to Bradley and Fayyad's algorithm.
+template<typename MatType>
+void RefinedStart::Cluster(const MatType& data,
+ const size_t clusters,
+ arma::Col<size_t>& assignments) const
+{
+ math::RandomSeed(std::time(NULL));
+
+ // This will hold the sampled datasets.
+ const size_t numPoints = size_t(percentage * data.n_cols);
+ MatType sampledData(data.n_rows, numPoints);
+ // vector<bool> is packed so each bool is 1 bit.
+ std::vector<bool> pointsUsed(data.n_cols, false);
+ arma::mat sampledCentroids(data.n_rows, samplings * clusters);
+
+ // We will use these objects repeatedly for clustering.
+ arma::Col<size_t> sampledAssignments;
+ arma::mat centroids;
+ KMeans<> kmeans;
+
+ for (size_t i = 0; i < samplings; ++i)
+ {
+ // First, assemble the sampled dataset.
+ size_t curSample = 0;
+ while (curSample < numPoints)
+ {
+ // Pick a random point in [0, numPoints).
+ size_t sample = (size_t) math::RandInt(data.n_cols);
+
+ if (!pointsUsed[sample])
+ {
+ // This point isn't used yet. So we'll put it in our sample.
+ pointsUsed[sample] = true;
+ sampledData.col(curSample) = data.col(sample);
+ ++curSample;
+ }
+ }
+
+ // Now, using the sampled dataset, run k-means. In the case of an empty
+ // cluster, we re-initialize that cluster as the point furthest away from
+ // the cluster with maximum variance. This is not *exactly* what the paper
+ // implements, but it is quite similar, and we'll call it "good enough".
+ kmeans.Cluster(sampledData, clusters, sampledAssignments, centroids);
+
+ // Store the sampled centroids.
+ sampledCentroids.cols(i * clusters, (i + 1) * clusters - 1) = centroids;
+
+ pointsUsed.assign(data.n_cols, false);
+ }
+
+ // Now, we run k-means on the sampled centroids to get our final clusters.
+ kmeans.Cluster(sampledCentroids, clusters, sampledAssignments, centroids);
+
+ // Turn the final centroids into assignments.
+ assignments.set_size(data.n_cols);
+ for (size_t i = 0; i < data.n_cols; ++i)
+ {
+ // Find the closest centroid to this point.
+ double minDistance = std::numeric_limits<double>::infinity();
+ size_t closestCluster = clusters;
+
+ for (size_t j = 0; j < clusters; ++j)
+ {
+ const double distance = kmeans.Metric().Evaluate(data.col(i),
+ centroids.col(j));
+
+ if (distance < minDistance)
+ {
+ minDistance = distance;
+ closestCluster = j;
+ }
+ }
+
+ // Assign the point to its closest cluster.
+ assignments[i] = closestCluster;
+ }
+}
+
+}; // namespace kmeans
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/lars/lars.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,448 +0,0 @@
-/**
- * @file lars.cpp
- * @author Nishant Mehta (niche)
- *
- * Implementation of LARS and LASSO.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "lars.hpp"
-
-using namespace mlpack;
-using namespace mlpack::regression;
-
-LARS::LARS(const bool useCholesky,
- const double lambda1,
- const double lambda2,
- const double tolerance) :
- matGram(matGramInternal),
- useCholesky(useCholesky),
- lasso((lambda1 != 0)),
- lambda1(lambda1),
- elasticNet((lambda1 != 0) && (lambda2 != 0)),
- lambda2(lambda2),
- tolerance(tolerance)
-{ /* Nothing left to do. */ }
-
-LARS::LARS(const bool useCholesky,
- const arma::mat& gramMatrix,
- const double lambda1,
- const double lambda2,
- const double tolerance) :
- matGram(gramMatrix),
- useCholesky(useCholesky),
- lasso((lambda1 != 0)),
- lambda1(lambda1),
- elasticNet((lambda1 != 0) && (lambda2 != 0)),
- lambda2(lambda2),
- tolerance(tolerance)
-{ /* Nothing left to do */ }
-
-void LARS::Regress(const arma::mat& matX,
- const arma::vec& y,
- arma::vec& beta,
- const bool transposeData)
-{
- Timer::Start("lars_regression");
-
- // This matrix may end up holding the transpose -- if necessary.
- arma::mat dataTrans;
- // dataRef is row-major.
- const arma::mat& dataRef = (transposeData ? dataTrans : matX);
- if (transposeData)
- dataTrans = trans(matX);
-
- // Compute X' * y.
- arma::vec vecXTy = trans(dataRef) * y;
-
- // Set up active set variables. In the beginning, the active set has size 0
- // (all dimensions are inactive).
- isActive.resize(dataRef.n_cols, false);
-
- // Initialize yHat and beta.
- beta = arma::zeros(dataRef.n_cols);
- arma::vec yHat = arma::zeros(dataRef.n_rows);
- arma::vec yHatDirection = arma::vec(dataRef.n_rows);
-
- bool lassocond = false;
-
- // Compute the initial maximum correlation among all dimensions.
- arma::vec corr = vecXTy;
- double maxCorr = 0;
- size_t changeInd = 0;
- for (size_t i = 0; i < vecXTy.n_elem; ++i)
- {
- if (fabs(corr(i)) > maxCorr)
- {
- maxCorr = fabs(corr(i));
- changeInd = i;
- }
- }
-
- betaPath.push_back(beta);
- lambdaPath.push_back(maxCorr);
-
- // If the maximum correlation is too small, there is no reason to continue.
- if (maxCorr < lambda1)
- {
- lambdaPath[0] = lambda1;
- Timer::Stop("lars_regression");
- return;
- }
-
- // Compute the Gram matrix. If this is the elastic net problem, we will add
- // lambda2 * I_n to the matrix.
- if (matGram.n_elem == 0)
- {
- // In this case, matGram should reference matGramInternal.
- matGramInternal = trans(dataRef) * dataRef;
-
- if (elasticNet && !useCholesky)
- matGramInternal += lambda2 * arma::eye(dataRef.n_cols, dataRef.n_cols);
- }
-
- // Main loop.
- while ((activeSet.size() < dataRef.n_cols) && (maxCorr > tolerance))
- {
- // Compute the maximum correlation among inactive dimensions.
- maxCorr = 0;
- for (size_t i = 0; i < dataRef.n_cols; i++)
- {
- if ((!isActive[i]) && (fabs(corr(i)) > maxCorr))
- {
- maxCorr = fabs(corr(i));
- changeInd = i;
- }
- }
-
- if (!lassocond)
- {
- if (useCholesky)
- {
- // vec newGramCol = vec(activeSet.size());
- // for (size_t i = 0; i < activeSet.size(); i++)
- // {
- // newGramCol[i] = dot(matX.col(activeSet[i]), matX.col(changeInd));
- // }
- // This is equivalent to the above 5 lines.
- arma::vec newGramCol = matGram.elem(changeInd * dataRef.n_cols +
- arma::conv_to<arma::uvec>::from(activeSet));
-
- CholeskyInsert(matGram(changeInd, changeInd), newGramCol);
- }
-
- // Add variable to active set.
- Activate(changeInd);
- }
-
- // Compute signs of correlations.
- arma::vec s = arma::vec(activeSet.size());
- for (size_t i = 0; i < activeSet.size(); i++)
- s(i) = corr(activeSet[i]) / fabs(corr(activeSet[i]));
-
- // Compute the "equiangular" direction in parameter space (betaDirection).
- // We use quotes because in the case of non-unit norm variables, this need
- // not be equiangular.
- arma::vec unnormalizedBetaDirection;
- double normalization;
- arma::vec betaDirection;
- if (useCholesky)
- {
- /**
- * Note that:
- * R^T R % S^T % S = (R % S)^T (R % S)
- * Now, for 1 the ones vector:
- * inv( (R % S)^T (R % S) ) 1
- * = inv(R % S) inv((R % S)^T) 1
- * = inv(R % S) Solve((R % S)^T, 1)
- * = inv(R % S) Solve(R^T, s)
- * = Solve(R % S, Solve(R^T, s)
- * = s % Solve(R, Solve(R^T, s))
- */
- unnormalizedBetaDirection = solve(trimatu(matUtriCholFactor),
- solve(trimatl(trans(matUtriCholFactor)), s));
-
- normalization = 1.0 / sqrt(dot(s, unnormalizedBetaDirection));
- betaDirection = normalization * unnormalizedBetaDirection;
- }
- else
- {
- arma::mat matGramActive = arma::mat(activeSet.size(), activeSet.size());
- for (size_t i = 0; i < activeSet.size(); i++)
- for (size_t j = 0; j < activeSet.size(); j++)
- matGramActive(i, j) = matGram(activeSet[i], activeSet[j]);
-
- arma::mat matS = s * arma::ones<arma::mat>(1, activeSet.size());
- unnormalizedBetaDirection = solve(matGramActive % trans(matS) % matS,
- arma::ones<arma::mat>(activeSet.size(), 1));
- normalization = 1.0 / sqrt(sum(unnormalizedBetaDirection));
- betaDirection = normalization * unnormalizedBetaDirection % s;
- }
-
- // compute "equiangular" direction in output space
- ComputeYHatDirection(dataRef, betaDirection, yHatDirection);
-
- double gamma = maxCorr / normalization;
-
- // If not all variables are active.
- if (activeSet.size() < dataRef.n_cols)
- {
- // Compute correlations with direction.
- for (size_t ind = 0; ind < dataRef.n_cols; ind++)
- {
- if (isActive[ind])
- continue;
-
- double dirCorr = dot(dataRef.col(ind), yHatDirection);
- double val1 = (maxCorr - corr(ind)) / (normalization - dirCorr);
- double val2 = (maxCorr + corr(ind)) / (normalization + dirCorr);
- if ((val1 > 0) && (val1 < gamma))
- gamma = val1;
- if ((val2 > 0) && (val2 < gamma))
- gamma = val2;
- }
- }
-
- // Bound gamma according to LASSO.
- if (lasso)
- {
- lassocond = false;
- double lassoboundOnGamma = DBL_MAX;
- size_t activeIndToKickOut = -1;
-
- for (size_t i = 0; i < activeSet.size(); i++)
- {
- double val = -beta(activeSet[i]) / betaDirection(i);
- if ((val > 0) && (val < lassoboundOnGamma))
- {
- lassoboundOnGamma = val;
- activeIndToKickOut = i;
- }
- }
-
- if (lassoboundOnGamma < gamma)
- {
- gamma = lassoboundOnGamma;
- lassocond = true;
- changeInd = activeIndToKickOut;
- }
- }
-
- // Update the prediction.
- yHat += gamma * yHatDirection;
-
- // Update the estimator.
- for (size_t i = 0; i < activeSet.size(); i++)
- {
- beta(activeSet[i]) += gamma * betaDirection(i);
- }
-
- // Sanity check to make sure the kicked out dimension is actually zero.
- if (lassocond)
- {
- if (beta(activeSet[changeInd]) != 0)
- beta(activeSet[changeInd]) = 0;
- }
-
- betaPath.push_back(beta);
-
- if (lassocond)
- {
- // Index is in position changeInd in activeSet.
- if (useCholesky)
- CholeskyDelete(changeInd);
-
- Deactivate(changeInd);
- }
-
- corr = vecXTy - trans(dataRef) * yHat;
- if (elasticNet)
- corr -= lambda2 * beta;
-
- double curLambda = 0;
- for (size_t i = 0; i < activeSet.size(); i++)
- curLambda += fabs(corr(activeSet[i]));
-
- curLambda /= ((double) activeSet.size());
-
- lambdaPath.push_back(curLambda);
-
- // Time to stop for LASSO?
- if (lasso)
- {
- if (curLambda <= lambda1)
- {
- InterpolateBeta();
- break;
- }
- }
- }
-
- // Unfortunate copy...
- beta = betaPath.back();
-
- Timer::Stop("lars_regression");
-}
-
-// Private functions.
-void LARS::Deactivate(const size_t activeVarInd)
-{
- isActive[activeSet[activeVarInd]] = false;
- activeSet.erase(activeSet.begin() + activeVarInd);
-}
-
-void LARS::Activate(const size_t varInd)
-{
- isActive[varInd] = true;
- activeSet.push_back(varInd);
-}
-
-void LARS::ComputeYHatDirection(const arma::mat& matX,
- const arma::vec& betaDirection,
- arma::vec& yHatDirection)
-{
- yHatDirection.fill(0);
- for (size_t i = 0; i < activeSet.size(); i++)
- yHatDirection += betaDirection(i) * matX.col(activeSet[i]);
-}
-
-void LARS::InterpolateBeta()
-{
- int pathLength = betaPath.size();
-
- // interpolate beta and stop
- double ultimateLambda = lambdaPath[pathLength - 1];
- double penultimateLambda = lambdaPath[pathLength - 2];
- double interp = (penultimateLambda - lambda1)
- / (penultimateLambda - ultimateLambda);
-
- betaPath[pathLength - 1] = (1 - interp) * (betaPath[pathLength - 2])
- + interp * betaPath[pathLength - 1];
-
- lambdaPath[pathLength - 1] = lambda1;
-}
-
-void LARS::CholeskyInsert(const arma::vec& newX, const arma::mat& X)
-{
- if (matUtriCholFactor.n_rows == 0)
- {
- matUtriCholFactor = arma::mat(1, 1);
-
- if (elasticNet)
- matUtriCholFactor(0, 0) = sqrt(dot(newX, newX) + lambda2);
- else
- matUtriCholFactor(0, 0) = norm(newX, 2);
- }
- else
- {
- arma::vec newGramCol = trans(X) * newX;
- CholeskyInsert(dot(newX, newX), newGramCol);
- }
-}
-
-void LARS::CholeskyInsert(double sqNormNewX, const arma::vec& newGramCol)
-{
- int n = matUtriCholFactor.n_rows;
-
- if (n == 0)
- {
- matUtriCholFactor = arma::mat(1, 1);
-
- if (elasticNet)
- matUtriCholFactor(0, 0) = sqrt(sqNormNewX + lambda2);
- else
- matUtriCholFactor(0, 0) = sqrt(sqNormNewX);
- }
- else
- {
- arma::mat matNewR = arma::mat(n + 1, n + 1);
-
- if (elasticNet)
- sqNormNewX += lambda2;
-
- arma::vec matUtriCholFactork = solve(trimatl(trans(matUtriCholFactor)),
- newGramCol);
-
- matNewR(arma::span(0, n - 1), arma::span(0, n - 1)) = matUtriCholFactor;
- matNewR(arma::span(0, n - 1), n) = matUtriCholFactork;
- matNewR(n, arma::span(0, n - 1)).fill(0.0);
- matNewR(n, n) = sqrt(sqNormNewX - dot(matUtriCholFactork,
- matUtriCholFactork));
-
- matUtriCholFactor = matNewR;
- }
-}
-
-void LARS::GivensRotate(const arma::vec::fixed<2>& x,
- arma::vec::fixed<2>& rotatedX,
- arma::mat& matG)
-{
- if (x(1) == 0)
- {
- matG = arma::eye(2, 2);
- rotatedX = x;
- }
- else
- {
- double r = norm(x, 2);
- matG = arma::mat(2, 2);
-
- double scaledX1 = x(0) / r;
- double scaledX2 = x(1) / r;
-
- matG(0, 0) = scaledX1;
- matG(1, 0) = -scaledX2;
- matG(0, 1) = scaledX2;
- matG(1, 1) = scaledX1;
-
- rotatedX = arma::vec(2);
- rotatedX(0) = r;
- rotatedX(1) = 0;
- }
-}
-
-void LARS::CholeskyDelete(const size_t colToKill)
-{
- size_t n = matUtriCholFactor.n_rows;
-
- if (colToKill == (n - 1))
- {
- matUtriCholFactor = matUtriCholFactor(arma::span(0, n - 2),
- arma::span(0, n - 2));
- }
- else
- {
- matUtriCholFactor.shed_col(colToKill); // remove column colToKill
- n--;
-
- for (size_t k = colToKill; k < n; k++)
- {
- arma::mat matG;
- arma::vec::fixed<2> rotatedVec;
- GivensRotate(matUtriCholFactor(arma::span(k, k + 1), k), rotatedVec,
- matG);
- matUtriCholFactor(arma::span(k, k + 1), k) = rotatedVec;
- if (k < n - 1)
- {
- matUtriCholFactor(arma::span(k, k + 1), arma::span(k + 1, n - 1)) =
- matG * matUtriCholFactor(arma::span(k, k + 1),
- arma::span(k + 1, n - 1));
- }
- }
-
- matUtriCholFactor.shed_row(n);
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/lars/lars.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,448 @@
+/**
+ * @file lars.cpp
+ * @author Nishant Mehta (niche)
+ *
+ * Implementation of LARS and LASSO.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "lars.hpp"
+
+using namespace mlpack;
+using namespace mlpack::regression;
+
+LARS::LARS(const bool useCholesky,
+ const double lambda1,
+ const double lambda2,
+ const double tolerance) :
+ matGram(matGramInternal),
+ useCholesky(useCholesky),
+ lasso((lambda1 != 0)),
+ lambda1(lambda1),
+ elasticNet((lambda1 != 0) && (lambda2 != 0)),
+ lambda2(lambda2),
+ tolerance(tolerance)
+{ /* Nothing left to do. */ }
+
+LARS::LARS(const bool useCholesky,
+ const arma::mat& gramMatrix,
+ const double lambda1,
+ const double lambda2,
+ const double tolerance) :
+ matGram(gramMatrix),
+ useCholesky(useCholesky),
+ lasso((lambda1 != 0)),
+ lambda1(lambda1),
+ elasticNet((lambda1 != 0) && (lambda2 != 0)),
+ lambda2(lambda2),
+ tolerance(tolerance)
+{ /* Nothing left to do */ }
+
+void LARS::Regress(const arma::mat& matX,
+ const arma::vec& y,
+ arma::vec& beta,
+ const bool transposeData)
+{
+ Timer::Start("lars_regression");
+
+ // This matrix may end up holding the transpose -- if necessary.
+ arma::mat dataTrans;
+ // dataRef is row-major.
+ const arma::mat& dataRef = (transposeData ? dataTrans : matX);
+ if (transposeData)
+ dataTrans = trans(matX);
+
+ // Compute X' * y.
+ arma::vec vecXTy = trans(dataRef) * y;
+
+ // Set up active set variables. In the beginning, the active set has size 0
+ // (all dimensions are inactive).
+ isActive.resize(dataRef.n_cols, false);
+
+ // Initialize yHat and beta.
+ beta = arma::zeros(dataRef.n_cols);
+ arma::vec yHat = arma::zeros(dataRef.n_rows);
+ arma::vec yHatDirection = arma::vec(dataRef.n_rows);
+
+ bool lassocond = false;
+
+ // Compute the initial maximum correlation among all dimensions.
+ arma::vec corr = vecXTy;
+ double maxCorr = 0;
+ size_t changeInd = 0;
+ for (size_t i = 0; i < vecXTy.n_elem; ++i)
+ {
+ if (fabs(corr(i)) > maxCorr)
+ {
+ maxCorr = fabs(corr(i));
+ changeInd = i;
+ }
+ }
+
+ betaPath.push_back(beta);
+ lambdaPath.push_back(maxCorr);
+
+ // If the maximum correlation is too small, there is no reason to continue.
+ if (maxCorr < lambda1)
+ {
+ lambdaPath[0] = lambda1;
+ Timer::Stop("lars_regression");
+ return;
+ }
+
+ // Compute the Gram matrix. If this is the elastic net problem, we will add
+ // lambda2 * I_n to the matrix.
+ if (matGram.n_elem == 0)
+ {
+ // In this case, matGram should reference matGramInternal.
+ matGramInternal = trans(dataRef) * dataRef;
+
+ if (elasticNet && !useCholesky)
+ matGramInternal += lambda2 * arma::eye(dataRef.n_cols, dataRef.n_cols);
+ }
+
+ // Main loop.
+ while ((activeSet.size() < dataRef.n_cols) && (maxCorr > tolerance))
+ {
+ // Compute the maximum correlation among inactive dimensions.
+ maxCorr = 0;
+ for (size_t i = 0; i < dataRef.n_cols; i++)
+ {
+ if ((!isActive[i]) && (fabs(corr(i)) > maxCorr))
+ {
+ maxCorr = fabs(corr(i));
+ changeInd = i;
+ }
+ }
+
+ if (!lassocond)
+ {
+ if (useCholesky)
+ {
+ // vec newGramCol = vec(activeSet.size());
+ // for (size_t i = 0; i < activeSet.size(); i++)
+ // {
+ // newGramCol[i] = dot(matX.col(activeSet[i]), matX.col(changeInd));
+ // }
+ // This is equivalent to the above 5 lines.
+ arma::vec newGramCol = matGram.elem(changeInd * dataRef.n_cols +
+ arma::conv_to<arma::uvec>::from(activeSet));
+
+ CholeskyInsert(matGram(changeInd, changeInd), newGramCol);
+ }
+
+ // Add variable to active set.
+ Activate(changeInd);
+ }
+
+ // Compute signs of correlations.
+ arma::vec s = arma::vec(activeSet.size());
+ for (size_t i = 0; i < activeSet.size(); i++)
+ s(i) = corr(activeSet[i]) / fabs(corr(activeSet[i]));
+
+ // Compute the "equiangular" direction in parameter space (betaDirection).
+ // We use quotes because in the case of non-unit norm variables, this need
+ // not be equiangular.
+ arma::vec unnormalizedBetaDirection;
+ double normalization;
+ arma::vec betaDirection;
+ if (useCholesky)
+ {
+ /**
+ * Note that:
+ * R^T R % S^T % S = (R % S)^T (R % S)
+ * Now, for 1 the ones vector:
+ * inv( (R % S)^T (R % S) ) 1
+ * = inv(R % S) inv((R % S)^T) 1
+ * = inv(R % S) Solve((R % S)^T, 1)
+ * = inv(R % S) Solve(R^T, s)
+ * = Solve(R % S, Solve(R^T, s)
+ * = s % Solve(R, Solve(R^T, s))
+ */
+ unnormalizedBetaDirection = solve(trimatu(matUtriCholFactor),
+ solve(trimatl(trans(matUtriCholFactor)), s));
+
+ normalization = 1.0 / sqrt(dot(s, unnormalizedBetaDirection));
+ betaDirection = normalization * unnormalizedBetaDirection;
+ }
+ else
+ {
+ arma::mat matGramActive = arma::mat(activeSet.size(), activeSet.size());
+ for (size_t i = 0; i < activeSet.size(); i++)
+ for (size_t j = 0; j < activeSet.size(); j++)
+ matGramActive(i, j) = matGram(activeSet[i], activeSet[j]);
+
+ arma::mat matS = s * arma::ones<arma::mat>(1, activeSet.size());
+ unnormalizedBetaDirection = solve(matGramActive % trans(matS) % matS,
+ arma::ones<arma::mat>(activeSet.size(), 1));
+ normalization = 1.0 / sqrt(sum(unnormalizedBetaDirection));
+ betaDirection = normalization * unnormalizedBetaDirection % s;
+ }
+
+ // compute "equiangular" direction in output space
+ ComputeYHatDirection(dataRef, betaDirection, yHatDirection);
+
+ double gamma = maxCorr / normalization;
+
+ // If not all variables are active.
+ if (activeSet.size() < dataRef.n_cols)
+ {
+ // Compute correlations with direction.
+ for (size_t ind = 0; ind < dataRef.n_cols; ind++)
+ {
+ if (isActive[ind])
+ continue;
+
+ double dirCorr = dot(dataRef.col(ind), yHatDirection);
+ double val1 = (maxCorr - corr(ind)) / (normalization - dirCorr);
+ double val2 = (maxCorr + corr(ind)) / (normalization + dirCorr);
+ if ((val1 > 0) && (val1 < gamma))
+ gamma = val1;
+ if ((val2 > 0) && (val2 < gamma))
+ gamma = val2;
+ }
+ }
+
+ // Bound gamma according to LASSO.
+ if (lasso)
+ {
+ lassocond = false;
+ double lassoboundOnGamma = DBL_MAX;
+ size_t activeIndToKickOut = -1;
+
+ for (size_t i = 0; i < activeSet.size(); i++)
+ {
+ double val = -beta(activeSet[i]) / betaDirection(i);
+ if ((val > 0) && (val < lassoboundOnGamma))
+ {
+ lassoboundOnGamma = val;
+ activeIndToKickOut = i;
+ }
+ }
+
+ if (lassoboundOnGamma < gamma)
+ {
+ gamma = lassoboundOnGamma;
+ lassocond = true;
+ changeInd = activeIndToKickOut;
+ }
+ }
+
+ // Update the prediction.
+ yHat += gamma * yHatDirection;
+
+ // Update the estimator.
+ for (size_t i = 0; i < activeSet.size(); i++)
+ {
+ beta(activeSet[i]) += gamma * betaDirection(i);
+ }
+
+ // Sanity check to make sure the kicked out dimension is actually zero.
+ if (lassocond)
+ {
+ if (beta(activeSet[changeInd]) != 0)
+ beta(activeSet[changeInd]) = 0;
+ }
+
+ betaPath.push_back(beta);
+
+ if (lassocond)
+ {
+ // Index is in position changeInd in activeSet.
+ if (useCholesky)
+ CholeskyDelete(changeInd);
+
+ Deactivate(changeInd);
+ }
+
+ corr = vecXTy - trans(dataRef) * yHat;
+ if (elasticNet)
+ corr -= lambda2 * beta;
+
+ double curLambda = 0;
+ for (size_t i = 0; i < activeSet.size(); i++)
+ curLambda += fabs(corr(activeSet[i]));
+
+ curLambda /= ((double) activeSet.size());
+
+ lambdaPath.push_back(curLambda);
+
+ // Time to stop for LASSO?
+ if (lasso)
+ {
+ if (curLambda <= lambda1)
+ {
+ InterpolateBeta();
+ break;
+ }
+ }
+ }
+
+ // Unfortunate copy...
+ beta = betaPath.back();
+
+ Timer::Stop("lars_regression");
+}
+
+// Private functions.
+void LARS::Deactivate(const size_t activeVarInd)
+{
+ isActive[activeSet[activeVarInd]] = false;
+ activeSet.erase(activeSet.begin() + activeVarInd);
+}
+
+void LARS::Activate(const size_t varInd)
+{
+ isActive[varInd] = true;
+ activeSet.push_back(varInd);
+}
+
+void LARS::ComputeYHatDirection(const arma::mat& matX,
+ const arma::vec& betaDirection,
+ arma::vec& yHatDirection)
+{
+ yHatDirection.fill(0);
+ for (size_t i = 0; i < activeSet.size(); i++)
+ yHatDirection += betaDirection(i) * matX.col(activeSet[i]);
+}
+
+void LARS::InterpolateBeta()
+{
+ int pathLength = betaPath.size();
+
+ // interpolate beta and stop
+ double ultimateLambda = lambdaPath[pathLength - 1];
+ double penultimateLambda = lambdaPath[pathLength - 2];
+ double interp = (penultimateLambda - lambda1)
+ / (penultimateLambda - ultimateLambda);
+
+ betaPath[pathLength - 1] = (1 - interp) * (betaPath[pathLength - 2])
+ + interp * betaPath[pathLength - 1];
+
+ lambdaPath[pathLength - 1] = lambda1;
+}
+
+void LARS::CholeskyInsert(const arma::vec& newX, const arma::mat& X)
+{
+ if (matUtriCholFactor.n_rows == 0)
+ {
+ matUtriCholFactor = arma::mat(1, 1);
+
+ if (elasticNet)
+ matUtriCholFactor(0, 0) = sqrt(dot(newX, newX) + lambda2);
+ else
+ matUtriCholFactor(0, 0) = norm(newX, 2);
+ }
+ else
+ {
+ arma::vec newGramCol = trans(X) * newX;
+ CholeskyInsert(dot(newX, newX), newGramCol);
+ }
+}
+
+void LARS::CholeskyInsert(double sqNormNewX, const arma::vec& newGramCol)
+{
+ int n = matUtriCholFactor.n_rows;
+
+ if (n == 0)
+ {
+ matUtriCholFactor = arma::mat(1, 1);
+
+ if (elasticNet)
+ matUtriCholFactor(0, 0) = sqrt(sqNormNewX + lambda2);
+ else
+ matUtriCholFactor(0, 0) = sqrt(sqNormNewX);
+ }
+ else
+ {
+ arma::mat matNewR = arma::mat(n + 1, n + 1);
+
+ if (elasticNet)
+ sqNormNewX += lambda2;
+
+ arma::vec matUtriCholFactork = solve(trimatl(trans(matUtriCholFactor)),
+ newGramCol);
+
+ matNewR(arma::span(0, n - 1), arma::span(0, n - 1)) = matUtriCholFactor;
+ matNewR(arma::span(0, n - 1), n) = matUtriCholFactork;
+ matNewR(n, arma::span(0, n - 1)).fill(0.0);
+ matNewR(n, n) = sqrt(sqNormNewX - dot(matUtriCholFactork,
+ matUtriCholFactork));
+
+ matUtriCholFactor = matNewR;
+ }
+}
+
+void LARS::GivensRotate(const arma::vec::fixed<2>& x,
+ arma::vec::fixed<2>& rotatedX,
+ arma::mat& matG)
+{
+ if (x(1) == 0)
+ {
+ matG = arma::eye(2, 2);
+ rotatedX = x;
+ }
+ else
+ {
+ double r = norm(x, 2);
+ matG = arma::mat(2, 2);
+
+ double scaledX1 = x(0) / r;
+ double scaledX2 = x(1) / r;
+
+ matG(0, 0) = scaledX1;
+ matG(1, 0) = -scaledX2;
+ matG(0, 1) = scaledX2;
+ matG(1, 1) = scaledX1;
+
+ rotatedX = arma::vec(2);
+ rotatedX(0) = r;
+ rotatedX(1) = 0;
+ }
+}
+
+void LARS::CholeskyDelete(const size_t colToKill)
+{
+ size_t n = matUtriCholFactor.n_rows;
+
+ if (colToKill == (n - 1))
+ {
+ matUtriCholFactor = matUtriCholFactor(arma::span(0, n - 2),
+ arma::span(0, n - 2));
+ }
+ else
+ {
+ matUtriCholFactor.shed_col(colToKill); // remove column colToKill
+ n--;
+
+ for (size_t k = colToKill; k < n; k++)
+ {
+ arma::mat matG;
+ arma::vec::fixed<2> rotatedVec;
+ GivensRotate(matUtriCholFactor(arma::span(k, k + 1), k), rotatedVec,
+ matG);
+ matUtriCholFactor(arma::span(k, k + 1), k) = rotatedVec;
+ if (k < n - 1)
+ {
+ matUtriCholFactor(arma::span(k, k + 1), arma::span(k + 1, n - 1)) =
+ matG * matUtriCholFactor(arma::span(k, k + 1),
+ arma::span(k + 1, n - 1));
+ }
+ }
+
+ matUtriCholFactor.shed_row(n);
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/lars/lars.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,243 +0,0 @@
-/**
- * @file lars.hpp
- * @author Nishant Mehta (niche)
- *
- * Definition of the LARS class, which performs Least Angle Regression and the
- * LASSO.
- *
- * Only minor modifications of LARS are necessary to handle the constrained
- * version of the problem:
- *
- * \f[
- * \min_{\beta} 0.5 || X \beta - y ||_2^2 + 0.5 \lambda_2 || \beta ||_2^2
- * \f]
- * subject to \f$ ||\beta||_1 <= \tau \f$
- *
- * Although this option currently is not implemented, it will be implemented
- * very soon.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_LARS_LARS_HPP
-#define __MLPACK_METHODS_LARS_LARS_HPP
-
-#include <armadillo>
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace regression {
-
-// beta is the estimator
-// yHat is the prediction from the current estimator
-
-/**
- * An implementation of LARS, a stage-wise homotopy-based algorithm for
- * l1-regularized linear regression (LASSO) and l1+l2 regularized linear
- * regression (Elastic Net).
- *
- * Let \f$ X \f$ be a matrix where each row is a point and each column is a
- * dimension and let \f$ y \f$ be a vector of responses.
- *
- * The Elastic Net problem is to solve
- *
- * \f[ \min_{\beta} 0.5 || X \beta - y ||_2^2 + \lambda_1 || \beta ||_1 +
- * 0.5 \lambda_2 || \beta ||_2^2 \f]
- *
- * where \f$ \beta \f$ is the vector of regression coefficients.
- *
- * If \f$ \lambda_1 > 0 \f$ and \f$ \lambda_2 = 0 \f$, the problem is the LASSO.
- * If \f$ \lambda_1 > 0 \f$ and \f$ \lambda_2 > 0 \f$, the problem is the
- * elastic net.
- * If \f$ \lambda_1 = 0 \f$ and \f$ \lambda_2 > 0 \f$, the problem is ridge
- * regression.
- * If \f$ \lambda_1 = 0 \f$ and \f$ \lambda_2 = 0 \f$, the problem is
- * unregularized linear regression.
- *
- * Note: This algorithm is not recommended for use (in terms of efficiency)
- * when \f$ \lambda_1 \f$ = 0.
- *
- * For more details, see the following papers:
- *
- * @code
- * @article{efron2004least,
- * title={Least angle regression},
- * author={Efron, B. and Hastie, T. and Johnstone, I. and Tibshirani, R.},
- * journal={The Annals of statistics},
- * volume={32},
- * number={2},
- * pages={407--499},
- * year={2004},
- * publisher={Institute of Mathematical Statistics}
- * }
- * @endcode
- *
- * @code
- * @article{zou2005regularization,
- * title={Regularization and variable selection via the elastic net},
- * author={Zou, H. and Hastie, T.},
- * journal={Journal of the Royal Statistical Society Series B},
- * volume={67},
- * number={2},
- * pages={301--320},
- * year={2005},
- * publisher={Royal Statistical Society}
- * }
- * @endcode
- */
-class LARS
-{
- public:
- /**
- * Set the parameters to LARS. Both lambda1 and lambda2 default to 0.
- *
- * @param useCholesky Whether or not to use Cholesky decomposition when
- * solving linear system (as opposed to using the full Gram matrix).
- * @param lambda1 Regularization parameter for l1-norm penalty.
- * @param lambda2 Regularization parameter for l2-norm penalty.
- * @param tolerance Run until the maximum correlation of elements in (X^T y)
- * is less than this.
- */
- LARS(const bool useCholesky,
- const double lambda1 = 0.0,
- const double lambda2 = 0.0,
- const double tolerance = 1e-16);
-
- /**
- * Set the parameters to LARS, and pass in a precalculated Gram matrix. Both
- * lambda1 and lambda2 default to 0.
- *
- * @param useCholesky Whether or not to use Cholesky decomposition when
- * solving linear system (as opposed to using the full Gram matrix).
- * @param gramMatrix Gram matrix.
- * @param lambda1 Regularization parameter for l1-norm penalty.
- * @param lambda2 Regularization parameter for l2-norm penalty.
- * @param tolerance Run until the maximum correlation of elements in (X^T y)
- * is less than this.
- */
- LARS(const bool useCholesky,
- const arma::mat& gramMatrix,
- const double lambda1 = 0.0,
- const double lambda2 = 0.0,
- const double tolerance = 1e-16);
-
- /**
- * Run LARS. The input matrix (like all MLPACK matrices) should be
- * column-major -- each column is an observation and each row is a dimension.
- * However, because LARS is more efficient on a row-major matrix, this method
- * will (internally) transpose the matrix. If this transposition is not
- * necessary (i.e., you want to pass in a row-major matrix), pass 'false' for
- * the transposeData parameter.
- *
- * @param data Column-major input data (or row-major input data if rowMajor =
- * true).
- * @param responses A vector of targets.
- * @param beta Vector to store the solution (the coefficients) in.
- * @param rowMajor Set to false if the data is row-major.
- */
- void Regress(const arma::mat& data,
- const arma::vec& responses,
- arma::vec& beta,
- const bool transposeData = true);
-
- //! Access the set of active dimensions.
- const std::vector<size_t>& ActiveSet() const { return activeSet; }
-
- //! Access the set of coefficients after each iteration; the solution is the
- //! last element.
- const std::vector<arma::vec>& BetaPath() const { return betaPath; }
-
- //! Access the set of values for lambda1 after each iteration; the solution is
- //! the last element.
- const std::vector<double>& LambdaPath() const { return lambdaPath; }
-
- //! Access the upper triangular cholesky factor
- const arma::mat& MatUtriCholFactor() const { return matUtriCholFactor; }
-
-private:
- //! Gram matrix.
- arma::mat matGramInternal;
-
- //! Reference to the Gram matrix we will use.
- const arma::mat& matGram;
-
- //! Upper triangular cholesky factor; initially 0x0 matrix.
- arma::mat matUtriCholFactor;
-
- //! Whether or not to use Cholesky decomposition when solving linear system.
- bool useCholesky;
-
- //! True if this is the LASSO problem.
- bool lasso;
- //! Regularization parameter for l1 penalty.
- double lambda1;
-
- //! True if this is the elastic net problem.
- bool elasticNet;
- //! Regularization parameter for l2 penalty.
- double lambda2;
-
- //! Tolerance for main loop.
- double tolerance;
-
- //! Solution path.
- std::vector<arma::vec> betaPath;
-
- //! Value of lambda_1 for each solution in solution path.
- std::vector<double> lambdaPath;
-
- //! Active set of dimensions.
- std::vector<size_t> activeSet;
-
- //! Active set membership indicator (for each dimension).
- std::vector<bool> isActive;
-
- /**
- * Remove activeVarInd'th element from active set.
- *
- * @param activeVarInd Index of element to remove from active set.
- */
- void Deactivate(const size_t activeVarInd);
-
- /**
- * Add dimension varInd to active set.
- *
- * @param varInd Dimension to add to active set.
- */
- void Activate(const size_t varInd);
-
- // compute "equiangular" direction in output space
- void ComputeYHatDirection(const arma::mat& matX,
- const arma::vec& betaDirection,
- arma::vec& yHatDirection);
-
- // interpolate to compute last solution vector
- void InterpolateBeta();
-
- void CholeskyInsert(const arma::vec& newX, const arma::mat& X);
-
- void CholeskyInsert(double sqNormNewX, const arma::vec& newGramCol);
-
- void GivensRotate(const arma::vec::fixed<2>& x,
- arma::vec::fixed<2>& rotatedX,
- arma::mat& G);
-
- void CholeskyDelete(const size_t colToKill);
-};
-
-}; // namespace regression
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/lars/lars.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,243 @@
+/**
+ * @file lars.hpp
+ * @author Nishant Mehta (niche)
+ *
+ * Definition of the LARS class, which performs Least Angle Regression and the
+ * LASSO.
+ *
+ * Only minor modifications of LARS are necessary to handle the constrained
+ * version of the problem:
+ *
+ * \f[
+ * \min_{\beta} 0.5 || X \beta - y ||_2^2 + 0.5 \lambda_2 || \beta ||_2^2
+ * \f]
+ * subject to \f$ ||\beta||_1 <= \tau \f$
+ *
+ * Although this option currently is not implemented, it will be implemented
+ * very soon.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_LARS_LARS_HPP
+#define __MLPACK_METHODS_LARS_LARS_HPP
+
+#include <armadillo>
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace regression {
+
+// beta is the estimator
+// yHat is the prediction from the current estimator
+
+/**
+ * An implementation of LARS, a stage-wise homotopy-based algorithm for
+ * l1-regularized linear regression (LASSO) and l1+l2 regularized linear
+ * regression (Elastic Net).
+ *
+ * Let \f$ X \f$ be a matrix where each row is a point and each column is a
+ * dimension and let \f$ y \f$ be a vector of responses.
+ *
+ * The Elastic Net problem is to solve
+ *
+ * \f[ \min_{\beta} 0.5 || X \beta - y ||_2^2 + \lambda_1 || \beta ||_1 +
+ * 0.5 \lambda_2 || \beta ||_2^2 \f]
+ *
+ * where \f$ \beta \f$ is the vector of regression coefficients.
+ *
+ * If \f$ \lambda_1 > 0 \f$ and \f$ \lambda_2 = 0 \f$, the problem is the LASSO.
+ * If \f$ \lambda_1 > 0 \f$ and \f$ \lambda_2 > 0 \f$, the problem is the
+ * elastic net.
+ * If \f$ \lambda_1 = 0 \f$ and \f$ \lambda_2 > 0 \f$, the problem is ridge
+ * regression.
+ * If \f$ \lambda_1 = 0 \f$ and \f$ \lambda_2 = 0 \f$, the problem is
+ * unregularized linear regression.
+ *
+ * Note: This algorithm is not recommended for use (in terms of efficiency)
+ * when \f$ \lambda_1 \f$ = 0.
+ *
+ * For more details, see the following papers:
+ *
+ * @code
+ * @article{efron2004least,
+ * title={Least angle regression},
+ * author={Efron, B. and Hastie, T. and Johnstone, I. and Tibshirani, R.},
+ * journal={The Annals of statistics},
+ * volume={32},
+ * number={2},
+ * pages={407--499},
+ * year={2004},
+ * publisher={Institute of Mathematical Statistics}
+ * }
+ * @endcode
+ *
+ * @code
+ * @article{zou2005regularization,
+ * title={Regularization and variable selection via the elastic net},
+ * author={Zou, H. and Hastie, T.},
+ * journal={Journal of the Royal Statistical Society Series B},
+ * volume={67},
+ * number={2},
+ * pages={301--320},
+ * year={2005},
+ * publisher={Royal Statistical Society}
+ * }
+ * @endcode
+ */
+class LARS
+{
+ public:
+ /**
+ * Set the parameters to LARS. Both lambda1 and lambda2 default to 0.
+ *
+ * @param useCholesky Whether or not to use Cholesky decomposition when
+ * solving linear system (as opposed to using the full Gram matrix).
+ * @param lambda1 Regularization parameter for l1-norm penalty.
+ * @param lambda2 Regularization parameter for l2-norm penalty.
+ * @param tolerance Run until the maximum correlation of elements in (X^T y)
+ * is less than this.
+ */
+ LARS(const bool useCholesky,
+ const double lambda1 = 0.0,
+ const double lambda2 = 0.0,
+ const double tolerance = 1e-16);
+
+ /**
+ * Set the parameters to LARS, and pass in a precalculated Gram matrix. Both
+ * lambda1 and lambda2 default to 0.
+ *
+ * @param useCholesky Whether or not to use Cholesky decomposition when
+ * solving linear system (as opposed to using the full Gram matrix).
+ * @param gramMatrix Gram matrix.
+ * @param lambda1 Regularization parameter for l1-norm penalty.
+ * @param lambda2 Regularization parameter for l2-norm penalty.
+ * @param tolerance Run until the maximum correlation of elements in (X^T y)
+ * is less than this.
+ */
+ LARS(const bool useCholesky,
+ const arma::mat& gramMatrix,
+ const double lambda1 = 0.0,
+ const double lambda2 = 0.0,
+ const double tolerance = 1e-16);
+
+ /**
+ * Run LARS. The input matrix (like all MLPACK matrices) should be
+ * column-major -- each column is an observation and each row is a dimension.
+ * However, because LARS is more efficient on a row-major matrix, this method
+ * will (internally) transpose the matrix. If this transposition is not
+ * necessary (i.e., you want to pass in a row-major matrix), pass 'false' for
+ * the transposeData parameter.
+ *
+ * @param data Column-major input data (or row-major input data if rowMajor =
+ * true).
+ * @param responses A vector of targets.
+ * @param beta Vector to store the solution (the coefficients) in.
+ * @param rowMajor Set to false if the data is row-major.
+ */
+ void Regress(const arma::mat& data,
+ const arma::vec& responses,
+ arma::vec& beta,
+ const bool transposeData = true);
+
+ //! Access the set of active dimensions.
+ const std::vector<size_t>& ActiveSet() const { return activeSet; }
+
+ //! Access the set of coefficients after each iteration; the solution is the
+ //! last element.
+ const std::vector<arma::vec>& BetaPath() const { return betaPath; }
+
+ //! Access the set of values for lambda1 after each iteration; the solution is
+ //! the last element.
+ const std::vector<double>& LambdaPath() const { return lambdaPath; }
+
+ //! Access the upper triangular cholesky factor
+ const arma::mat& MatUtriCholFactor() const { return matUtriCholFactor; }
+
+private:
+ //! Gram matrix.
+ arma::mat matGramInternal;
+
+ //! Reference to the Gram matrix we will use.
+ const arma::mat& matGram;
+
+ //! Upper triangular cholesky factor; initially 0x0 matrix.
+ arma::mat matUtriCholFactor;
+
+ //! Whether or not to use Cholesky decomposition when solving linear system.
+ bool useCholesky;
+
+ //! True if this is the LASSO problem.
+ bool lasso;
+ //! Regularization parameter for l1 penalty.
+ double lambda1;
+
+ //! True if this is the elastic net problem.
+ bool elasticNet;
+ //! Regularization parameter for l2 penalty.
+ double lambda2;
+
+ //! Tolerance for main loop.
+ double tolerance;
+
+ //! Solution path.
+ std::vector<arma::vec> betaPath;
+
+ //! Value of lambda_1 for each solution in solution path.
+ std::vector<double> lambdaPath;
+
+ //! Active set of dimensions.
+ std::vector<size_t> activeSet;
+
+ //! Active set membership indicator (for each dimension).
+ std::vector<bool> isActive;
+
+ /**
+ * Remove activeVarInd'th element from active set.
+ *
+ * @param activeVarInd Index of element to remove from active set.
+ */
+ void Deactivate(const size_t activeVarInd);
+
+ /**
+ * Add dimension varInd to active set.
+ *
+ * @param varInd Dimension to add to active set.
+ */
+ void Activate(const size_t varInd);
+
+ // compute "equiangular" direction in output space
+ void ComputeYHatDirection(const arma::mat& matX,
+ const arma::vec& betaDirection,
+ arma::vec& yHatDirection);
+
+ // interpolate to compute last solution vector
+ void InterpolateBeta();
+
+ void CholeskyInsert(const arma::vec& newX, const arma::mat& X);
+
+ void CholeskyInsert(double sqNormNewX, const arma::vec& newGramCol);
+
+ void GivensRotate(const arma::vec::fixed<2>& x,
+ arma::vec::fixed<2>& rotatedX,
+ arma::mat& G);
+
+ void CholeskyDelete(const size_t colToKill);
+};
+
+}; // namespace regression
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/lars/lars_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,105 +0,0 @@
-/**
- * @file lars_main.cpp
- * @author Nishant Mehta
- *
- * Executable for LARS.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "lars.hpp"
-
-PROGRAM_INFO("LARS", "An implementation of LARS: Least Angle Regression "
- "(Stagewise/laSso). This is a stage-wise homotopy-based algorithm for "
- "L1-regularized linear regression (LASSO) and L1+L2-regularized linear "
- "regression (Elastic Net).\n"
- "\n"
- "Let X be a matrix where each row is a point and each column is a "
- "dimension, and let y be a vector of targets.\n"
- "\n"
- "The Elastic Net problem is to solve\n\n"
- " min_beta 0.5 || X * beta - y ||_2^2 + lambda_1 ||beta||_1 +\n"
- " 0.5 lambda_2 ||beta||_2^2\n\n"
- "If lambda_1 > 0 and lambda_2 = 0, the problem is the LASSO.\n"
- "If lambda_1 > 0 and lambda_2 > 0, the problem is the Elastic Net.\n"
- "If lambda_1 = 0 and lambda_2 > 0, the problem is Ridge Regression.\n"
- "If lambda_1 = 0 and lambda_2 = 0, the problem is unregularized linear "
- "regression.\n"
- "\n"
- "For efficiency reasons, it is not recommended to use this algorithm with "
- "lambda_1 = 0.\n");
-
-PARAM_STRING_REQ("input_file", "File containing covariates (X).",
- "i");
-PARAM_STRING_REQ("responses_file", "File containing y "
- "(responses/observations).", "r");
-
-PARAM_STRING("output_file", "File to save beta (linear estimator) to.", "o",
- "output.csv");
-
-PARAM_DOUBLE("lambda1", "Regularization parameter for l1-norm penalty.", "l",
- 0);
-PARAM_DOUBLE("lambda2", "Regularization parameter for l2-norm penalty.", "L",
- 0);
-PARAM_FLAG("use_cholesky", "Use Cholesky decomposition during computation "
- "rather than explicitly computing the full Gram matrix.", "c");
-
-using namespace arma;
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::regression;
-
-int main(int argc, char* argv[])
-{
- // Handle parameters,
- CLI::ParseCommandLine(argc, argv);
-
- double lambda1 = CLI::GetParam<double>("lambda1");
- double lambda2 = CLI::GetParam<double>("lambda2");
- bool useCholesky = CLI::HasParam("use_cholesky");
-
- // Load covariates. We can avoid LARS transposing our data by choosing to not
- // transpose this data.
- const string matXFilename = CLI::GetParam<string>("input_file");
- mat matX;
- data::Load(matXFilename.c_str(), matX, true, false);
-
- // Load responses. The responses should be a one-dimensional vector, and it
- // seems more likely that these will be stored with one response per line (one
- // per row). So we should not transpose upon loading.
- const string yFilename = CLI::GetParam<string>("responses_file");
- mat matY; // Will be a vector.
- data::Load(yFilename.c_str(), matY, true, false);
-
- // Make sure y is oriented the right way.
- if (matY.n_rows == 1)
- matY = trans(matY);
- if (matY.n_cols > 1)
- Log::Fatal << "Only one column or row allowed in responses file!" << endl;
-
- if (matY.n_elem != matX.n_rows)
- Log::Fatal << "Number of responses must be equal to number of rows of X!"
- << endl;
-
- // Do LARS.
- LARS lars(useCholesky, lambda1, lambda2);
- vec beta;
- lars.Regress(matX, matY.unsafe_col(0), beta, false /* do not transpose */);
-
- const string betaFilename = CLI::GetParam<string>("output_file");
- beta.save(betaFilename, raw_ascii);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/lars/lars_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lars/lars_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,105 @@
+/**
+ * @file lars_main.cpp
+ * @author Nishant Mehta
+ *
+ * Executable for LARS.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "lars.hpp"
+
+PROGRAM_INFO("LARS", "An implementation of LARS: Least Angle Regression "
+ "(Stagewise/laSso). This is a stage-wise homotopy-based algorithm for "
+ "L1-regularized linear regression (LASSO) and L1+L2-regularized linear "
+ "regression (Elastic Net).\n"
+ "\n"
+ "Let X be a matrix where each row is a point and each column is a "
+ "dimension, and let y be a vector of targets.\n"
+ "\n"
+ "The Elastic Net problem is to solve\n\n"
+ " min_beta 0.5 || X * beta - y ||_2^2 + lambda_1 ||beta||_1 +\n"
+ " 0.5 lambda_2 ||beta||_2^2\n\n"
+ "If lambda_1 > 0 and lambda_2 = 0, the problem is the LASSO.\n"
+ "If lambda_1 > 0 and lambda_2 > 0, the problem is the Elastic Net.\n"
+ "If lambda_1 = 0 and lambda_2 > 0, the problem is Ridge Regression.\n"
+ "If lambda_1 = 0 and lambda_2 = 0, the problem is unregularized linear "
+ "regression.\n"
+ "\n"
+ "For efficiency reasons, it is not recommended to use this algorithm with "
+ "lambda_1 = 0.\n");
+
+PARAM_STRING_REQ("input_file", "File containing covariates (X).",
+ "i");
+PARAM_STRING_REQ("responses_file", "File containing y "
+ "(responses/observations).", "r");
+
+PARAM_STRING("output_file", "File to save beta (linear estimator) to.", "o",
+ "output.csv");
+
+PARAM_DOUBLE("lambda1", "Regularization parameter for l1-norm penalty.", "l",
+ 0);
+PARAM_DOUBLE("lambda2", "Regularization parameter for l2-norm penalty.", "L",
+ 0);
+PARAM_FLAG("use_cholesky", "Use Cholesky decomposition during computation "
+ "rather than explicitly computing the full Gram matrix.", "c");
+
+using namespace arma;
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::regression;
+
+int main(int argc, char* argv[])
+{
+ // Handle parameters,
+ CLI::ParseCommandLine(argc, argv);
+
+ double lambda1 = CLI::GetParam<double>("lambda1");
+ double lambda2 = CLI::GetParam<double>("lambda2");
+ bool useCholesky = CLI::HasParam("use_cholesky");
+
+ // Load covariates. We can avoid LARS transposing our data by choosing to not
+ // transpose this data.
+ const string matXFilename = CLI::GetParam<string>("input_file");
+ mat matX;
+ data::Load(matXFilename.c_str(), matX, true, false);
+
+ // Load responses. The responses should be a one-dimensional vector, and it
+ // seems more likely that these will be stored with one response per line (one
+ // per row). So we should not transpose upon loading.
+ const string yFilename = CLI::GetParam<string>("responses_file");
+ mat matY; // Will be a vector.
+ data::Load(yFilename.c_str(), matY, true, false);
+
+ // Make sure y is oriented the right way.
+ if (matY.n_rows == 1)
+ matY = trans(matY);
+ if (matY.n_cols > 1)
+ Log::Fatal << "Only one column or row allowed in responses file!" << endl;
+
+ if (matY.n_elem != matX.n_rows)
+ Log::Fatal << "Number of responses must be equal to number of rows of X!"
+ << endl;
+
+ // Do LARS.
+ LARS lars(useCholesky, lambda1, lambda2);
+ vec beta;
+ lars.Regress(matX, matY.unsafe_col(0), beta, false /* do not transpose */);
+
+ const string betaFilename = CLI::GetParam<string>("output_file");
+ beta.save(betaFilename, raw_ascii);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/linear_regression/linear_regression.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,105 +0,0 @@
-/**
- * @file linear_regression.cpp
- * @author James Cline
- *
- * Implementation of simple linear regression.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "linear_regression.hpp"
-
-using namespace mlpack;
-using namespace mlpack::regression;
-
-LinearRegression::LinearRegression(arma::mat& predictors,
- const arma::colvec& responses)
-{
- /*
- * We want to calculate the a_i coefficients of:
- * \sum_{i=0}^n (a_i * x_i^i)
- * In order to get the intercept value, we will add a row of ones.
- */
-
- // We store the number of rows of the predictors.
- // Reminder: Armadillo stores the data transposed from how we think of it,
- // that is, columns are actually rows (see: column major order).
- size_t nCols = predictors.n_cols;
-
- // Here we add the row of ones to the predictors.
- arma::rowvec ones;
- ones.ones(nCols);
- predictors.insert_rows(0, ones);
-
- // We set the parameters to the correct size and initialize them to zero.
- parameters.zeros(nCols);
-
- // We compute the QR decomposition of the predictors.
- // We transpose the predictors because they are in column major order.
- arma::mat Q, R;
- arma::qr(Q, R, arma::trans(predictors));
-
- // We compute the parameters, B, like so:
- // R * B = Q^T * responses
- // B = Q^T * responses * R^-1
- arma::solve(parameters, R, arma::trans(Q) * responses);
-
- // We now remove the row of ones we added so the user's data is unmodified.
- predictors.shed_row(0);
-}
-
-LinearRegression::LinearRegression(const std::string& filename)
-{
- data::Load(filename, parameters, true);
-}
-
-LinearRegression::LinearRegression(const LinearRegression& linearRegression)
-{
- parameters = linearRegression.parameters;
-}
-
-LinearRegression::~LinearRegression()
-{ }
-
-void LinearRegression::Predict(const arma::mat& points, arma::vec& predictions)
-{
- // We get the number of columns and rows of the dataset.
- const size_t nCols = points.n_cols;
- const size_t nRows = points.n_rows;
-
- // We want to be sure we have the correct number of dimensions in the dataset.
- Log::Assert(nRows == parameters.n_rows - 1);
- if (nRows != parameters.n_rows -1)
- {
- Log::Fatal << "The test data must have the same number of columns as the "
- "training file.\n";
- }
-
- predictions.zeros(nCols);
- // We set all the predictions to the intercept value initially.
- predictions += parameters(0);
-
- // Now we iterate through the dimensions of the data and parameters.
- for (size_t i = 1; i < nRows + 1; ++i)
- {
- // Now we iterate through each row, or point, of the data.
- for (size_t j = 0; j < nCols; ++j)
- {
- // Increment each prediction value by x_i * a_i, or the next dimensional
- // coefficient and x value.
- predictions(j) += parameters(i) * points(i - 1, j);
- }
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/linear_regression/linear_regression.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,105 @@
+/**
+ * @file linear_regression.cpp
+ * @author James Cline
+ *
+ * Implementation of simple linear regression.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "linear_regression.hpp"
+
+using namespace mlpack;
+using namespace mlpack::regression;
+
+LinearRegression::LinearRegression(arma::mat& predictors,
+ const arma::colvec& responses)
+{
+ /*
+ * We want to calculate the a_i coefficients of:
+ * \sum_{i=0}^n (a_i * x_i^i)
+ * In order to get the intercept value, we will add a row of ones.
+ */
+
+ // We store the number of rows of the predictors.
+ // Reminder: Armadillo stores the data transposed from how we think of it,
+ // that is, columns are actually rows (see: column major order).
+ size_t nCols = predictors.n_cols;
+
+ // Here we add the row of ones to the predictors.
+ arma::rowvec ones;
+ ones.ones(nCols);
+ predictors.insert_rows(0, ones);
+
+ // We set the parameters to the correct size and initialize them to zero.
+ parameters.zeros(nCols);
+
+ // We compute the QR decomposition of the predictors.
+ // We transpose the predictors because they are in column major order.
+ arma::mat Q, R;
+ arma::qr(Q, R, arma::trans(predictors));
+
+ // We compute the parameters, B, like so:
+ // R * B = Q^T * responses
+ // B = Q^T * responses * R^-1
+ arma::solve(parameters, R, arma::trans(Q) * responses);
+
+ // We now remove the row of ones we added so the user's data is unmodified.
+ predictors.shed_row(0);
+}
+
+LinearRegression::LinearRegression(const std::string& filename)
+{
+ data::Load(filename, parameters, true);
+}
+
+LinearRegression::LinearRegression(const LinearRegression& linearRegression)
+{
+ parameters = linearRegression.parameters;
+}
+
+LinearRegression::~LinearRegression()
+{ }
+
+void LinearRegression::Predict(const arma::mat& points, arma::vec& predictions)
+{
+ // We get the number of columns and rows of the dataset.
+ const size_t nCols = points.n_cols;
+ const size_t nRows = points.n_rows;
+
+ // We want to be sure we have the correct number of dimensions in the dataset.
+ Log::Assert(nRows == parameters.n_rows - 1);
+ if (nRows != parameters.n_rows -1)
+ {
+ Log::Fatal << "The test data must have the same number of columns as the "
+ "training file.\n";
+ }
+
+ predictions.zeros(nCols);
+ // We set all the predictions to the intercept value initially.
+ predictions += parameters(0);
+
+ // Now we iterate through the dimensions of the data and parameters.
+ for (size_t i = 1; i < nRows + 1; ++i)
+ {
+ // Now we iterate through each row, or point, of the data.
+ for (size_t j = 0; j < nCols; ++j)
+ {
+ // Increment each prediction value by x_i * a_i, or the next dimensional
+ // coefficient and x value.
+ predictions(j) += parameters(i) * points(i - 1, j);
+ }
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/linear_regression/linear_regression.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,93 +0,0 @@
-/**
- * @file linear_regression.hpp
- * @author James Cline
- *
- * Simple least-squares linear regression.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
-#define __MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace regression /** Regression methods. */ {
-
-/**
- * A simple linear regression algorithm using ordinary least squares.
- */
-class LinearRegression
-{
- public:
- /**
- * Creates the model.
- *
- * @param predictors X, matrix of data points to create B with.
- * @param responses y, the measured data for each point in X
- */
- LinearRegression(arma::mat& predictors, const arma::vec& responses);
-
- /**
- * Initialize the model from a file.
- *
- * @param filename the name of the file to load the model from.
- */
- LinearRegression(const std::string& filename);
-
- /**
- * Copy constructor.
- *
- * @param linearRegression the other instance to copy parameters from.
- */
- LinearRegression(const LinearRegression& linearRegression);
-
- /**
- * Default constructor.
- */
- LinearRegression() {}
-
-
- /**
- * Destructor - no work done.
- */
- ~LinearRegression();
-
- /**
- * Calculate y_i for each data point in points.
- *
- * @param points the data points to calculate with.
- * @param predictions y, will contain calculated values on completion.
- */
- void Predict(const arma::mat& points, arma::vec& predictions);
-
- //! Return the parameters (the b vector).
- const arma::vec& Parameters() const { return parameters; }
- //! Modify the parameters (the b vector).
- arma::vec& Parameters() { return parameters; }
-
- private:
- /**
- * The calculated B.
- * Initialized and filled by constructor to hold the least squares solution.
- */
- arma::vec parameters;
-};
-
-}; // namespace linear_regression
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_LINEAR_REGRESSCLIN_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/linear_regression/linear_regression.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,93 @@
+/**
+ * @file linear_regression.hpp
+ * @author James Cline
+ *
+ * Simple least-squares linear regression.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
+#define __MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace regression /** Regression methods. */ {
+
+/**
+ * A simple linear regression algorithm using ordinary least squares.
+ */
+class LinearRegression
+{
+ public:
+ /**
+ * Creates the model.
+ *
+ * @param predictors X, matrix of data points to create B with.
+ * @param responses y, the measured data for each point in X
+ */
+ LinearRegression(arma::mat& predictors, const arma::vec& responses);
+
+ /**
+ * Initialize the model from a file.
+ *
+ * @param filename the name of the file to load the model from.
+ */
+ LinearRegression(const std::string& filename);
+
+ /**
+ * Copy constructor.
+ *
+ * @param linearRegression the other instance to copy parameters from.
+ */
+ LinearRegression(const LinearRegression& linearRegression);
+
+ /**
+ * Default constructor.
+ */
+ LinearRegression() {}
+
+
+ /**
+ * Destructor - no work done.
+ */
+ ~LinearRegression();
+
+ /**
+ * Calculate y_i for each data point in points.
+ *
+ * @param points the data points to calculate with.
+ * @param predictions y, will contain calculated values on completion.
+ */
+ void Predict(const arma::mat& points, arma::vec& predictions);
+
+ //! Return the parameters (the b vector).
+ const arma::vec& Parameters() const { return parameters; }
+ //! Modify the parameters (the b vector).
+ arma::vec& Parameters() { return parameters; }
+
+ private:
+ /**
+ * The calculated B.
+ * Initialized and filled by constructor to hold the least squares solution.
+ */
+ arma::vec parameters;
+};
+
+}; // namespace linear_regression
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_LINEAR_REGRESSCLIN_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/linear_regression/linear_regression_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,181 +0,0 @@
-/**
- * @file linear_regression_main.cpp
- * @author James Cline
- *
- * Main function for least-squares linear regression.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include "linear_regression.hpp"
-
-PROGRAM_INFO("Simple Linear Regression Prediction",
- "An implementation of simple linear regression using ordinary least "
- "squares. This solves the problem\n\n"
- " y = X * b + e\n\n"
- "where X (--input_file) and y (the last row of --input_file, or "
- "--input_responses) are known and b is the desired variable. The "
- "calculated b is saved to disk (--output_file).\n"
- "\n"
- "Optionally, the calculated value of b is used to predict the responses for"
- " another matrix X' (--test_file):\n\n"
- " y' = X' * b\n\n"
- "and these predicted responses, y', are saved to a file "
- "(--output_predictions).");
-
-PARAM_STRING("input_file", "File containing X (regressors).", "i", "");
-PARAM_STRING("input_responses", "Optional file containing y (responses). If "
- "not given, the responses are assumed to be the last row of the input "
- "file.", "r", "");
-
-PARAM_STRING("model_file", "File containing existing model (parameters).", "m",
- "");
-
-PARAM_STRING("output_file", "File where parameters (b) will be saved.",
- "o", "parameters.csv");
-
-PARAM_STRING("test_file", "File containing X' (test regressors).", "t", "");
-PARAM_STRING("output_predictions", "If --test_file is specified, this file is "
- "where the predicted responses will be saved.", "p", "predictions.csv");
-
-using namespace mlpack;
-using namespace mlpack::regression;
-using namespace arma;
-using namespace std;
-
-int main(int argc, char* argv[])
-{
- // Handle parameters
- CLI::ParseCommandLine(argc, argv);
-
- const string modelName = CLI::GetParam<string>("model_file");
- const string outputFile = CLI::GetParam<string>("output_file");
- const string outputPredictions = CLI::GetParam<string>("output_predictions");
- const string responseName = CLI::GetParam<string>("input_responses");
- const string testName = CLI::GetParam<string>("test_file");
- const string trainName = CLI::GetParam<string>("input_file");
-
- mat regressors;
- mat responses;
-
- LinearRegression lr;
-
- bool computeModel;
-
- // We want to determine if an input file XOR model file were given
- if (trainName.empty()) // The user specified no input file
- {
- if (modelName.empty()) // The user specified no model file, error and exit
- {
- Log::Fatal << "You must specify either --input_file or --model_file." << std::endl;
- exit(1);
- }
- else // The model file was specified, no problems
- {
- computeModel = false;
- }
- }
- // The user specified an input file but no model file, no problems
- else if (modelName.empty())
- {
- computeModel = true;
- }
- // The user specified both an input file and model file.
- // This is ambiguous -- which model should we use? A generated one or given one?
- // Report error and exit.
- else
- {
- Log::Fatal << "You must specify either --input_file or --model_file, not both." << std::endl;
- exit(1);
- }
-
- // If they specified a model file, we also need a test file or we
- // have nothing to do.
- if(!computeModel && testName.empty())
- {
- Log::Fatal << "When specifying --model_file, you must also specify --test_file." << std::endl;
- exit(1);
- }
-
- // An input file was given and we need to generate the model.
- if (computeModel)
- {
- Timer::Start("load_regressors");
- data::Load(trainName.c_str(), regressors, true);
- Timer::Stop("load_regressors");
-
- // Are the responses in a separate file?
- if (responseName.empty())
- {
- // The initial predictors for y, Nx1
- responses = trans(regressors.row(regressors.n_rows - 1));
- regressors.shed_row(regressors.n_rows - 1);
- }
- else
- {
- // The initial predictors for y, Nx1
- Timer::Start("load_responses");
- data::Load(responseName.c_str(), responses, true);
- Timer::Stop("load_responses");
-
- if (responses.n_rows == 1)
- responses = trans(responses); // Probably loaded backwards, but that's ok.
-
- if (responses.n_cols > 1)
- Log::Fatal << "The responses must have one column.\n";
-
- if (responses.n_rows != regressors.n_cols)
- Log::Fatal << "The responses must have the same number of rows as the "
- "training file.\n";
- }
-
- Timer::Start("regression");
- lr = LinearRegression(regressors, responses.unsafe_col(0));
- Timer::Stop("regression");
-
- // Save the parameters.
- data::Save(outputFile.c_str(), lr.Parameters(), true);
- }
-
- // Did we want to predict, too?
- if (!testName.empty() )
- {
-
- // A model file was passed in, so load it
- if (!computeModel)
- {
- Timer::Start("load_model");
- lr = LinearRegression(modelName);
- Timer::Stop("load_model");
- }
-
- // Load the test file data
- arma::mat points;
- Timer::Stop("load_test_points");
- data::Load(testName.c_str(), points, true);
- Timer::Stop("load_test_points");
-
- // Perform the predictions using our model
- arma::vec predictions;
- Timer::Start("prediction");
- lr.Predict(points, predictions);
- Timer::Stop("prediction");
-
- // Save predictions.
- predictions = arma::trans(predictions);
- data::Save(outputPredictions.c_str(), predictions, true);
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/linear_regression/linear_regression_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/linear_regression/linear_regression_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,181 @@
+/**
+ * @file linear_regression_main.cpp
+ * @author James Cline
+ *
+ * Main function for least-squares linear regression.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include "linear_regression.hpp"
+
+PROGRAM_INFO("Simple Linear Regression Prediction",
+ "An implementation of simple linear regression using ordinary least "
+ "squares. This solves the problem\n\n"
+ " y = X * b + e\n\n"
+ "where X (--input_file) and y (the last row of --input_file, or "
+ "--input_responses) are known and b is the desired variable. The "
+ "calculated b is saved to disk (--output_file).\n"
+ "\n"
+ "Optionally, the calculated value of b is used to predict the responses for"
+ " another matrix X' (--test_file):\n\n"
+ " y' = X' * b\n\n"
+ "and these predicted responses, y', are saved to a file "
+ "(--output_predictions).");
+
+PARAM_STRING("input_file", "File containing X (regressors).", "i", "");
+PARAM_STRING("input_responses", "Optional file containing y (responses). If "
+ "not given, the responses are assumed to be the last row of the input "
+ "file.", "r", "");
+
+PARAM_STRING("model_file", "File containing existing model (parameters).", "m",
+ "");
+
+PARAM_STRING("output_file", "File where parameters (b) will be saved.",
+ "o", "parameters.csv");
+
+PARAM_STRING("test_file", "File containing X' (test regressors).", "t", "");
+PARAM_STRING("output_predictions", "If --test_file is specified, this file is "
+ "where the predicted responses will be saved.", "p", "predictions.csv");
+
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace arma;
+using namespace std;
+
+int main(int argc, char* argv[])
+{
+ // Handle parameters
+ CLI::ParseCommandLine(argc, argv);
+
+ const string modelName = CLI::GetParam<string>("model_file");
+ const string outputFile = CLI::GetParam<string>("output_file");
+ const string outputPredictions = CLI::GetParam<string>("output_predictions");
+ const string responseName = CLI::GetParam<string>("input_responses");
+ const string testName = CLI::GetParam<string>("test_file");
+ const string trainName = CLI::GetParam<string>("input_file");
+
+ mat regressors;
+ mat responses;
+
+ LinearRegression lr;
+
+ bool computeModel;
+
+ // We want to determine if an input file XOR model file were given
+ if (trainName.empty()) // The user specified no input file
+ {
+ if (modelName.empty()) // The user specified no model file, error and exit
+ {
+ Log::Fatal << "You must specify either --input_file or --model_file." << std::endl;
+ exit(1);
+ }
+ else // The model file was specified, no problems
+ {
+ computeModel = false;
+ }
+ }
+ // The user specified an input file but no model file, no problems
+ else if (modelName.empty())
+ {
+ computeModel = true;
+ }
+ // The user specified both an input file and model file.
+ // This is ambiguous -- which model should we use? A generated one or given one?
+ // Report error and exit.
+ else
+ {
+ Log::Fatal << "You must specify either --input_file or --model_file, not both." << std::endl;
+ exit(1);
+ }
+
+ // If they specified a model file, we also need a test file or we
+ // have nothing to do.
+ if(!computeModel && testName.empty())
+ {
+ Log::Fatal << "When specifying --model_file, you must also specify --test_file." << std::endl;
+ exit(1);
+ }
+
+ // An input file was given and we need to generate the model.
+ if (computeModel)
+ {
+ Timer::Start("load_regressors");
+ data::Load(trainName.c_str(), regressors, true);
+ Timer::Stop("load_regressors");
+
+ // Are the responses in a separate file?
+ if (responseName.empty())
+ {
+ // The initial predictors for y, Nx1
+ responses = trans(regressors.row(regressors.n_rows - 1));
+ regressors.shed_row(regressors.n_rows - 1);
+ }
+ else
+ {
+ // The initial predictors for y, Nx1
+ Timer::Start("load_responses");
+ data::Load(responseName.c_str(), responses, true);
+ Timer::Stop("load_responses");
+
+ if (responses.n_rows == 1)
+ responses = trans(responses); // Probably loaded backwards, but that's ok.
+
+ if (responses.n_cols > 1)
+ Log::Fatal << "The responses must have one column.\n";
+
+ if (responses.n_rows != regressors.n_cols)
+ Log::Fatal << "The responses must have the same number of rows as the "
+ "training file.\n";
+ }
+
+ Timer::Start("regression");
+ lr = LinearRegression(regressors, responses.unsafe_col(0));
+ Timer::Stop("regression");
+
+ // Save the parameters.
+ data::Save(outputFile.c_str(), lr.Parameters(), true);
+ }
+
+ // Did we want to predict, too?
+ if (!testName.empty() )
+ {
+
+ // A model file was passed in, so load it
+ if (!computeModel)
+ {
+ Timer::Start("load_model");
+ lr = LinearRegression(modelName);
+ Timer::Stop("load_model");
+ }
+
+ // Load the test file data
+ arma::mat points;
+  Timer::Start("load_test_points");
+ data::Load(testName.c_str(), points, true);
+ Timer::Stop("load_test_points");
+
+ // Perform the predictions using our model
+ arma::vec predictions;
+ Timer::Start("prediction");
+ lr.Predict(points, predictions);
+ Timer::Stop("prediction");
+
+ // Save predictions.
+ predictions = arma::trans(predictions);
+ data::Save(outputPredictions.c_str(), predictions, true);
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/local_coordinate_coding/lcc.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,171 +0,0 @@
-/**
- * @file lcc.hpp
- * @author Nishant Mehta
- *
- * Definition of the LocalCoordinateCoding class, which performs the Local
- * Coordinate Coding algorithm.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_HPP
-#define __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/methods/lars/lars.hpp>
-
-// Include three simple dictionary initializers from sparse coding.
-#include "../sparse_coding/nothing_initializer.hpp"
-#include "../sparse_coding/data_dependent_random_initializer.hpp"
-#include "../sparse_coding/random_initializer.hpp"
-
-namespace mlpack {
-namespace lcc {
-
-/**
- * An implementation of Local Coordinate Coding (LCC) that codes data which
- * approximately lives on a manifold using a variation of l1-norm regularized
- * sparse coding; in LCC, the penalty on the absolute value of each point's
- * coefficient for each atom is weighted by the squared distance of that point
- * to that atom.
- *
- * Let d be the number of dimensions in the original space, m the number of
- * training points, and k the number of atoms in the dictionary (the dimension
- * of the learned feature space). The training data X is a d-by-m matrix where
- * each column is a point and each row is a dimension. The dictionary D is a
- * d-by-k matrix, and the sparse codes matrix Z is a k-by-m matrix.
- * This program seeks to minimize the objective:
- * min_{D,Z} ||X - D Z||_{Fro}^2
- * + lambda sum_{i=1}^m sum_{j=1}^k dist(X_i,D_j)^2 Z_i^j
- * where lambda > 0.
- *
- * This problem is solved by an algorithm that alternates between a dictionary
- * learning step and a sparse coding step. The dictionary learning step updates
- * the dictionary D by solving a linear system (note that the objective is a
- * positive definite quadratic program). The sparse coding step involves
- * solving a large number of weighted l1-norm regularized linear regression
- * problems problems; this can be done efficiently using LARS, an algorithm
- * that can solve the LASSO (paper below).
- *
- * The papers are listed below.
- *
- * @code
- * @incollection{NIPS2009_0719,
- * title = {Nonlinear Learning using Local Coordinate Coding},
- * author = {Kai Yu and Tong Zhang and Yihong Gong},
- * booktitle = {Advances in Neural Information Processing Systems 22},
- * editor = {Y. Bengio and D. Schuurmans and J. Lafferty and C. K. I. Williams
- * and A. Culotta},
- * pages = {2223--2231},
- * year = {2009}
- * }
- * @endcode
- *
- * @code
- * @article{efron2004least,
- * title={Least angle regression},
- * author={Efron, B. and Hastie, T. and Johnstone, I. and Tibshirani, R.},
- * journal={The Annals of statistics},
- * volume={32},
- * number={2},
- * pages={407--499},
- * year={2004},
- * publisher={Institute of Mathematical Statistics}
- * }
- * @endcode
- */
-template<typename DictionaryInitializer =
- sparse_coding::DataDependentRandomInitializer>
-class LocalCoordinateCoding
-{
- public:
- /**
- * Set the parameters to LocalCoordinateCoding.
- *
- * @param data Data matrix.
- * @param atoms Number of atoms in dictionary.
- * @param lambda Regularization parameter for weighted l1-norm penalty.
- */
- LocalCoordinateCoding(const arma::mat& data,
- const size_t atoms,
- const double lambda);
-
- /**
- * Run local coordinate coding.
- *
- * @param nIterations Maximum number of iterations to run algorithm.
- * @param objTolerance Tolerance of objective function. When the objective
- * function changes by a value lower than this tolerance, the optimization
- * terminates.
- */
- void Encode(const size_t maxIterations = 0,
- const double objTolerance = 0.01);
-
- /**
- * Code each point via distance-weighted LARS.
- */
- void OptimizeCode();
-
- /**
- * Learn dictionary by solving linear system.
- *
- * @param adjacencies Indices of entries (unrolled column by column) of
- * the coding matrix Z that are non-zero (the adjacency matrix for the
- * bipartite graph of points and atoms)
- */
- void OptimizeDictionary(arma::uvec adjacencies);
-
- /**
- * Compute objective function given the list of adjacencies.
- */
- double Objective(arma::uvec adjacencies) const;
-
- //! Access the data.
- const arma::mat& Data() const { return data; }
-
- //! Accessor for dictionary.
- const arma::mat& Dictionary() const { return dictionary; }
- //! Mutator for dictionary.
- arma::mat& Dictionary() { return dictionary; }
-
- //! Accessor the codes.
- const arma::mat& Codes() const { return codes; }
- //! Modify the codes.
- arma::mat& Codes() { return codes; }
-
- private:
- //! Number of atoms in dictionary.
- size_t atoms;
-
- //! Data matrix (columns are points).
- const arma::mat& data;
-
- //! Dictionary (columns are atoms).
- arma::mat dictionary;
-
- //! Codes (columns are points).
- arma::mat codes;
-
- //! l1 regularization term.
- double lambda;
-};
-
-}; // namespace lcc
-}; // namespace mlpack
-
-// Include implementation.
-#include "lcc_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/local_coordinate_coding/lcc.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,171 @@
+/**
+ * @file lcc.hpp
+ * @author Nishant Mehta
+ *
+ * Definition of the LocalCoordinateCoding class, which performs the Local
+ * Coordinate Coding algorithm.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_HPP
+#define __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/lars/lars.hpp>
+
+// Include three simple dictionary initializers from sparse coding.
+#include "../sparse_coding/nothing_initializer.hpp"
+#include "../sparse_coding/data_dependent_random_initializer.hpp"
+#include "../sparse_coding/random_initializer.hpp"
+
+namespace mlpack {
+namespace lcc {
+
+/**
+ * An implementation of Local Coordinate Coding (LCC) that codes data which
+ * approximately lives on a manifold using a variation of l1-norm regularized
+ * sparse coding; in LCC, the penalty on the absolute value of each point's
+ * coefficient for each atom is weighted by the squared distance of that point
+ * to that atom.
+ *
+ * Let d be the number of dimensions in the original space, m the number of
+ * training points, and k the number of atoms in the dictionary (the dimension
+ * of the learned feature space). The training data X is a d-by-m matrix where
+ * each column is a point and each row is a dimension. The dictionary D is a
+ * d-by-k matrix, and the sparse codes matrix Z is a k-by-m matrix.
+ * This program seeks to minimize the objective:
+ * min_{D,Z} ||X - D Z||_{Fro}^2
+ * + lambda sum_{i=1}^m sum_{j=1}^k dist(X_i,D_j)^2 Z_i^j
+ * where lambda > 0.
+ *
+ * This problem is solved by an algorithm that alternates between a dictionary
+ * learning step and a sparse coding step. The dictionary learning step updates
+ * the dictionary D by solving a linear system (note that the objective is a
+ * positive definite quadratic program). The sparse coding step involves
+ * solving a large number of weighted l1-norm regularized linear regression
+ * problems; this can be done efficiently using LARS, an algorithm
+ * that can solve the LASSO (paper below).
+ *
+ * The papers are listed below.
+ *
+ * @code
+ * @incollection{NIPS2009_0719,
+ * title = {Nonlinear Learning using Local Coordinate Coding},
+ * author = {Kai Yu and Tong Zhang and Yihong Gong},
+ * booktitle = {Advances in Neural Information Processing Systems 22},
+ * editor = {Y. Bengio and D. Schuurmans and J. Lafferty and C. K. I. Williams
+ * and A. Culotta},
+ * pages = {2223--2231},
+ * year = {2009}
+ * }
+ * @endcode
+ *
+ * @code
+ * @article{efron2004least,
+ * title={Least angle regression},
+ * author={Efron, B. and Hastie, T. and Johnstone, I. and Tibshirani, R.},
+ * journal={The Annals of statistics},
+ * volume={32},
+ * number={2},
+ * pages={407--499},
+ * year={2004},
+ * publisher={Institute of Mathematical Statistics}
+ * }
+ * @endcode
+ */
+template<typename DictionaryInitializer =
+ sparse_coding::DataDependentRandomInitializer>
+class LocalCoordinateCoding
+{
+ public:
+ /**
+ * Set the parameters to LocalCoordinateCoding.
+ *
+ * @param data Data matrix.
+ * @param atoms Number of atoms in dictionary.
+ * @param lambda Regularization parameter for weighted l1-norm penalty.
+ */
+ LocalCoordinateCoding(const arma::mat& data,
+ const size_t atoms,
+ const double lambda);
+
+ /**
+ * Run local coordinate coding.
+ *
+ * @param nIterations Maximum number of iterations to run algorithm.
+ * @param objTolerance Tolerance of objective function. When the objective
+ * function changes by a value lower than this tolerance, the optimization
+ * terminates.
+ */
+ void Encode(const size_t maxIterations = 0,
+ const double objTolerance = 0.01);
+
+ /**
+ * Code each point via distance-weighted LARS.
+ */
+ void OptimizeCode();
+
+ /**
+ * Learn dictionary by solving linear system.
+ *
+ * @param adjacencies Indices of entries (unrolled column by column) of
+ * the coding matrix Z that are non-zero (the adjacency matrix for the
+ * bipartite graph of points and atoms)
+ */
+ void OptimizeDictionary(arma::uvec adjacencies);
+
+ /**
+ * Compute objective function given the list of adjacencies.
+ */
+ double Objective(arma::uvec adjacencies) const;
+
+ //! Access the data.
+ const arma::mat& Data() const { return data; }
+
+ //! Accessor for dictionary.
+ const arma::mat& Dictionary() const { return dictionary; }
+ //! Mutator for dictionary.
+ arma::mat& Dictionary() { return dictionary; }
+
+ //! Accessor the codes.
+ const arma::mat& Codes() const { return codes; }
+ //! Modify the codes.
+ arma::mat& Codes() { return codes; }
+
+ private:
+ //! Number of atoms in dictionary.
+ size_t atoms;
+
+ //! Data matrix (columns are points).
+ const arma::mat& data;
+
+ //! Dictionary (columns are atoms).
+ arma::mat dictionary;
+
+ //! Codes (columns are points).
+ arma::mat codes;
+
+ //! l1 regularization term.
+ double lambda;
+};
+
+}; // namespace lcc
+}; // namespace mlpack
+
+// Include implementation.
+#include "lcc_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,329 +0,0 @@
-/**
- * @file lcc_impl.hpp
- * @author Nishant Mehta
- *
- * Implementation of Local Coordinate Coding
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_IMPL_HPP
-#define __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "lcc.hpp"
-
-namespace mlpack {
-namespace lcc {
-
-template<typename DictionaryInitializer>
-LocalCoordinateCoding<DictionaryInitializer>::LocalCoordinateCoding(
- const arma::mat& data,
- const size_t atoms,
- const double lambda) :
- atoms(atoms),
- data(data),
- codes(atoms, data.n_cols),
- lambda(lambda)
-{
- // Initialize the dictionary.
- DictionaryInitializer::Initialize(data, atoms, dictionary);
-}
-
-template<typename DictionaryInitializer>
-void LocalCoordinateCoding<DictionaryInitializer>::Encode(
- const size_t maxIterations,
- const double objTolerance)
-{
- Timer::Start("local_coordinate_coding");
-
- double lastObjVal = DBL_MAX;
-
- // Take the initial coding step, which has to happen before entering the main
- // loop.
- Log::Info << "Initial Coding Step." << std::endl;
-
- OptimizeCode();
- arma::uvec adjacencies = find(codes);
-
- Log::Info << " Sparsity level: " << 100.0 * ((double)(adjacencies.n_elem)) /
- ((double)(atoms * data.n_cols)) << "%.\n";
- Log::Info << " Objective value: " << Objective(adjacencies) << "."
- << std::endl;
-
- for (size_t t = 1; t != maxIterations; t++)
- {
- Log::Info << "Iteration " << t << " of " << maxIterations << "."
- << std::endl;
-
- // First step: optimize the dictionary.
- Log::Info << "Performing dictionary step..." << std::endl;
- OptimizeDictionary(adjacencies);
- double dsObjVal = Objective(adjacencies);
- Log::Info << " Objective value: " << Objective(adjacencies) << "."
- << std::endl;
-
- // Second step: perform the coding.
- Log::Info << "Performing coding step..." << std::endl;
- OptimizeCode();
- adjacencies = find(codes);
- Log::Info << " Sparsity level: " << 100.0 * ((double) (adjacencies.n_elem))
- / ((double)(atoms * data.n_cols)) << "%.\n";
-
- // Terminate if the objective increased in the coding step.
- double curObjVal = Objective(adjacencies);
- if (curObjVal > dsObjVal)
- {
- Log::Warn << "Objective increased in coding step! Terminating."
- << std::endl;
- break;
- }
-
- // Find the new objective value and improvement so we can check for
- // convergence.
- double improvement = lastObjVal - curObjVal;
- Log::Info << "Objective value: " << curObjVal << " (improvement "
- << std::scientific << improvement << ")." << std::endl;
-
- if (improvement < objTolerance)
- {
- Log::Info << "Converged within tolerance " << objTolerance << ".\n";
- break;
- }
-
- lastObjVal = curObjVal;
- }
-
- Timer::Stop("local_coordinate_coding");
-}
-
-template<typename DictionaryInitializer>
-void LocalCoordinateCoding<DictionaryInitializer>::OptimizeCode()
-{
- arma::mat invSqDists = 1.0 / (repmat(trans(sum(square(dictionary))), 1,
- data.n_cols) + repmat(sum(square(data)), atoms, 1) - 2 * trans(dictionary)
- * data);
-
- arma::mat dictGram = trans(dictionary) * dictionary;
- arma::mat dictGramTD(dictGram.n_rows, dictGram.n_cols);
-
- for (size_t i = 0; i < data.n_cols; i++)
- {
- // report progress
- if ((i % 100) == 0)
- {
- Log::Debug << "Optimization at point " << i << "." << std::endl;
- }
-
- arma::vec invW = invSqDists.unsafe_col(i);
- arma::mat dictPrime = dictionary * diagmat(invW);
-
- arma::mat dictGramTD = diagmat(invW) * dictGram * diagmat(invW);
-
- bool useCholesky = false;
- regression::LARS lars(useCholesky, dictGramTD, 0.5 * lambda);
-
- // Run LARS for this point, by making an alias of the point and passing
- // that.
- arma::vec beta = codes.unsafe_col(i);
- lars.Regress(dictPrime, data.unsafe_col(i), beta, false);
- beta %= invW; // Remember, beta is an alias of codes.col(i).
- }
-}
-
-template<typename DictionaryInitializer>
-void LocalCoordinateCoding<DictionaryInitializer>::OptimizeDictionary(
- arma::uvec adjacencies)
-{
- // Count number of atomic neighbors for each point x^i.
- arma::uvec neighborCounts = arma::zeros<arma::uvec>(data.n_cols, 1);
- if (adjacencies.n_elem > 0)
- {
- // This gets the column index. Intentional integer division.
- size_t curPointInd = (size_t) (adjacencies(0) / atoms);
- ++neighborCounts(curPointInd);
-
- size_t nextColIndex = (curPointInd + 1) * atoms;
- for (size_t l = 1; l < adjacencies.n_elem; l++)
- {
- // If l no longer refers to an element in this column, advance the column
- // number accordingly.
- if (adjacencies(l) >= nextColIndex)
- {
- curPointInd = (size_t) (adjacencies(l) / atoms);
- nextColIndex = (curPointInd + 1) * atoms;
- }
-
- ++neighborCounts(curPointInd);
- }
- }
-
- // Build dataPrime := [X x^1 ... x^1 ... x^n ... x^n]
- // where each x^i is repeated for the number of neighbors x^i has.
- arma::mat dataPrime = arma::zeros(data.n_rows,
- data.n_cols + adjacencies.n_elem);
-
- dataPrime(arma::span::all, arma::span(0, data.n_cols - 1)) = data;
-
- size_t curCol = data.n_cols;
- for (size_t i = 0; i < data.n_cols; i++)
- {
- if (neighborCounts(i) > 0)
- {
- dataPrime(arma::span::all, arma::span(curCol, curCol + neighborCounts(i)
- - 1)) = repmat(data.col(i), 1, neighborCounts(i));
- }
- curCol += neighborCounts(i);
- }
-
- // Handle the case of inactive atoms (atoms not used in the given coding).
- std::vector<size_t> inactiveAtoms;
- for (size_t j = 0; j < atoms; ++j)
- if (accu(codes.row(j) != 0) == 0)
- inactiveAtoms.push_back(j);
-
- const size_t nInactiveAtoms = inactiveAtoms.size();
- const size_t nActiveAtoms = atoms - nInactiveAtoms;
-
- // Efficient construction of codes restricted to active atoms.
- arma::mat codesPrime = arma::zeros(nActiveAtoms, data.n_cols +
- adjacencies.n_elem);
- arma::vec wSquared = arma::ones(data.n_cols + adjacencies.n_elem, 1);
-
- if (nInactiveAtoms > 0)
- {
- Log::Warn << "There are " << nInactiveAtoms
- << " inactive atoms. They will be re-initialized randomly.\n";
-
- // Create matrix holding only active codes.
- arma::mat activeCodes;
- math::RemoveRows(codes, inactiveAtoms, activeCodes);
-
- // Create reverse atom lookup for active atoms.
- arma::uvec atomReverseLookup(atoms);
- size_t inactiveOffset = 0;
- for (size_t i = 0; i < atoms; ++i)
- {
- if (inactiveAtoms[inactiveOffset] == i)
- ++inactiveOffset;
- else
- atomReverseLookup(i - inactiveOffset) = i;
- }
-
- codesPrime(arma::span::all, arma::span(0, data.n_cols - 1)) = activeCodes;
-
- // Fill the rest of codesPrime.
- for (size_t l = 0; l < adjacencies.n_elem; ++l)
- {
- // Recover the location in the codes matrix that this adjacency refers to.
- size_t atomInd = adjacencies(l) % atoms;
- size_t pointInd = (size_t) (adjacencies(l) / atoms);
-
- // Fill matrix.
- codesPrime(atomReverseLookup(atomInd), data.n_cols + l) = 1.0;
- wSquared(data.n_cols + l) = codes(atomInd, pointInd);
- }
- }
- else
- {
- // All atoms are active.
- codesPrime(arma::span::all, arma::span(0, data.n_cols - 1)) = codes;
-
- for (size_t l = 0; l < adjacencies.n_elem; ++l)
- {
- // Recover the location in the codes matrix that this adjacency refers to.
- size_t atomInd = adjacencies(l) % atoms;
- size_t pointInd = (size_t) (adjacencies(l) / atoms);
-
- // Fill matrix.
- codesPrime(atomInd, data.n_cols + l) = 1.0;
- wSquared(data.n_cols + l) = codes(atomInd, pointInd);
- }
- }
-
- wSquared.subvec(data.n_cols, wSquared.n_elem - 1) = lambda *
- abs(wSquared.subvec(data.n_cols, wSquared.n_elem - 1));
-
- // Solve system.
- if (nInactiveAtoms == 0)
- {
- // No inactive atoms. We can solve directly.
- arma::mat A = codesPrime * diagmat(wSquared) * trans(codesPrime);
- arma::mat B = codesPrime * diagmat(wSquared) * trans(dataPrime);
-
- dictionary = trans(solve(A, B));
- /*
- dictionary = trans(solve(codesPrime * diagmat(wSquared) * trans(codesPrime),
- codesPrime * diagmat(wSquared) * trans(dataPrime)));
- */
- }
- else
- {
- // Inactive atoms must be reinitialized randomly, so we cannot solve
- // directly for the entire dictionary estimate.
- arma::mat dictionaryActive =
- trans(solve(codesPrime * diagmat(wSquared) * trans(codesPrime),
- codesPrime * diagmat(wSquared) * trans(dataPrime)));
-
- // Update all atoms.
- size_t currentInactiveIndex = 0;
- for (size_t i = 0; i < atoms; ++i)
- {
- if (inactiveAtoms[currentInactiveIndex] == i)
- {
- // This atom is inactive. Reinitialize it randomly.
- dictionary.col(i) = (data.col(math::RandInt(data.n_cols)) +
- data.col(math::RandInt(data.n_cols)) +
- data.col(math::RandInt(data.n_cols)));
-
- // Now normalize the atom.
- dictionary.col(i) /= norm(dictionary.col(i), 2);
-
- // Increment inactive atom counter.
- ++currentInactiveIndex;
- }
- else
- {
- // Update estimate.
- dictionary.col(i) = dictionaryActive.col(i - currentInactiveIndex);
- }
- }
- }
-}
-
-template<typename DictionaryInitializer>
-double LocalCoordinateCoding<DictionaryInitializer>::Objective(
- arma::uvec adjacencies) const
-{
- double weightedL1NormZ = 0;
-
- for (size_t l = 0; l < adjacencies.n_elem; l++)
- {
- // Map adjacency back to its location in the codes matrix.
- const size_t atomInd = adjacencies(l) % atoms;
- const size_t pointInd = (size_t) (adjacencies(l) / atoms);
-
- weightedL1NormZ += fabs(codes(atomInd, pointInd)) * arma::as_scalar(
- arma::sum(arma::square(dictionary.col(atomInd) - data.col(pointInd))));
- }
-
- double froNormResidual = norm(data - dictionary * codes, "fro");
- return std::pow(froNormResidual, 2.0) + lambda * weightedL1NormZ;
-}
-
-}; // namespace lcc
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,329 @@
+/**
+ * @file lcc_impl.hpp
+ * @author Nishant Mehta
+ *
+ * Implementation of Local Coordinate Coding
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_IMPL_HPP
+#define __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "lcc.hpp"
+
+namespace mlpack {
+namespace lcc {
+
+template<typename DictionaryInitializer>
+LocalCoordinateCoding<DictionaryInitializer>::LocalCoordinateCoding(
+ const arma::mat& data,
+ const size_t atoms,
+ const double lambda) :
+ atoms(atoms),
+ data(data),
+ codes(atoms, data.n_cols),
+ lambda(lambda)
+{
+ // Initialize the dictionary.
+ DictionaryInitializer::Initialize(data, atoms, dictionary);
+}
+
+template<typename DictionaryInitializer>
+void LocalCoordinateCoding<DictionaryInitializer>::Encode(
+ const size_t maxIterations,
+ const double objTolerance)
+{
+ Timer::Start("local_coordinate_coding");
+
+ double lastObjVal = DBL_MAX;
+
+ // Take the initial coding step, which has to happen before entering the main
+ // loop.
+ Log::Info << "Initial Coding Step." << std::endl;
+
+ OptimizeCode();
+ arma::uvec adjacencies = find(codes);
+
+ Log::Info << " Sparsity level: " << 100.0 * ((double)(adjacencies.n_elem)) /
+ ((double)(atoms * data.n_cols)) << "%.\n";
+ Log::Info << " Objective value: " << Objective(adjacencies) << "."
+ << std::endl;
+
+ for (size_t t = 1; t != maxIterations; t++)
+ {
+ Log::Info << "Iteration " << t << " of " << maxIterations << "."
+ << std::endl;
+
+ // First step: optimize the dictionary.
+ Log::Info << "Performing dictionary step..." << std::endl;
+ OptimizeDictionary(adjacencies);
+ double dsObjVal = Objective(adjacencies);
+ Log::Info << " Objective value: " << Objective(adjacencies) << "."
+ << std::endl;
+
+ // Second step: perform the coding.
+ Log::Info << "Performing coding step..." << std::endl;
+ OptimizeCode();
+ adjacencies = find(codes);
+ Log::Info << " Sparsity level: " << 100.0 * ((double) (adjacencies.n_elem))
+ / ((double)(atoms * data.n_cols)) << "%.\n";
+
+ // Terminate if the objective increased in the coding step.
+ double curObjVal = Objective(adjacencies);
+ if (curObjVal > dsObjVal)
+ {
+ Log::Warn << "Objective increased in coding step! Terminating."
+ << std::endl;
+ break;
+ }
+
+ // Find the new objective value and improvement so we can check for
+ // convergence.
+ double improvement = lastObjVal - curObjVal;
+ Log::Info << "Objective value: " << curObjVal << " (improvement "
+ << std::scientific << improvement << ")." << std::endl;
+
+ if (improvement < objTolerance)
+ {
+ Log::Info << "Converged within tolerance " << objTolerance << ".\n";
+ break;
+ }
+
+ lastObjVal = curObjVal;
+ }
+
+ Timer::Stop("local_coordinate_coding");
+}
+
+template<typename DictionaryInitializer>
+void LocalCoordinateCoding<DictionaryInitializer>::OptimizeCode()
+{
+ arma::mat invSqDists = 1.0 / (repmat(trans(sum(square(dictionary))), 1,
+ data.n_cols) + repmat(sum(square(data)), atoms, 1) - 2 * trans(dictionary)
+ * data);
+
+ arma::mat dictGram = trans(dictionary) * dictionary;
+ arma::mat dictGramTD(dictGram.n_rows, dictGram.n_cols);
+
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ // report progress
+ if ((i % 100) == 0)
+ {
+ Log::Debug << "Optimization at point " << i << "." << std::endl;
+ }
+
+ arma::vec invW = invSqDists.unsafe_col(i);
+ arma::mat dictPrime = dictionary * diagmat(invW);
+
+ arma::mat dictGramTD = diagmat(invW) * dictGram * diagmat(invW);
+
+ bool useCholesky = false;
+ regression::LARS lars(useCholesky, dictGramTD, 0.5 * lambda);
+
+ // Run LARS for this point, by making an alias of the point and passing
+ // that.
+ arma::vec beta = codes.unsafe_col(i);
+ lars.Regress(dictPrime, data.unsafe_col(i), beta, false);
+ beta %= invW; // Remember, beta is an alias of codes.col(i).
+ }
+}
+
+template<typename DictionaryInitializer>
+void LocalCoordinateCoding<DictionaryInitializer>::OptimizeDictionary(
+ arma::uvec adjacencies)
+{
+ // Count number of atomic neighbors for each point x^i.
+ arma::uvec neighborCounts = arma::zeros<arma::uvec>(data.n_cols, 1);
+ if (adjacencies.n_elem > 0)
+ {
+ // This gets the column index. Intentional integer division.
+ size_t curPointInd = (size_t) (adjacencies(0) / atoms);
+ ++neighborCounts(curPointInd);
+
+ size_t nextColIndex = (curPointInd + 1) * atoms;
+ for (size_t l = 1; l < adjacencies.n_elem; l++)
+ {
+ // If l no longer refers to an element in this column, advance the column
+ // number accordingly.
+ if (adjacencies(l) >= nextColIndex)
+ {
+ curPointInd = (size_t) (adjacencies(l) / atoms);
+ nextColIndex = (curPointInd + 1) * atoms;
+ }
+
+ ++neighborCounts(curPointInd);
+ }
+ }
+
+ // Build dataPrime := [X x^1 ... x^1 ... x^n ... x^n]
+ // where each x^i is repeated for the number of neighbors x^i has.
+ arma::mat dataPrime = arma::zeros(data.n_rows,
+ data.n_cols + adjacencies.n_elem);
+
+ dataPrime(arma::span::all, arma::span(0, data.n_cols - 1)) = data;
+
+ size_t curCol = data.n_cols;
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ if (neighborCounts(i) > 0)
+ {
+ dataPrime(arma::span::all, arma::span(curCol, curCol + neighborCounts(i)
+ - 1)) = repmat(data.col(i), 1, neighborCounts(i));
+ }
+ curCol += neighborCounts(i);
+ }
+
+ // Handle the case of inactive atoms (atoms not used in the given coding).
+ std::vector<size_t> inactiveAtoms;
+ for (size_t j = 0; j < atoms; ++j)
+ if (accu(codes.row(j) != 0) == 0)
+ inactiveAtoms.push_back(j);
+
+ const size_t nInactiveAtoms = inactiveAtoms.size();
+ const size_t nActiveAtoms = atoms - nInactiveAtoms;
+
+ // Efficient construction of codes restricted to active atoms.
+ arma::mat codesPrime = arma::zeros(nActiveAtoms, data.n_cols +
+ adjacencies.n_elem);
+ arma::vec wSquared = arma::ones(data.n_cols + adjacencies.n_elem, 1);
+
+ if (nInactiveAtoms > 0)
+ {
+ Log::Warn << "There are " << nInactiveAtoms
+ << " inactive atoms. They will be re-initialized randomly.\n";
+
+ // Create matrix holding only active codes.
+ arma::mat activeCodes;
+ math::RemoveRows(codes, inactiveAtoms, activeCodes);
+
+ // Create reverse atom lookup for active atoms.
+ arma::uvec atomReverseLookup(atoms);
+ size_t inactiveOffset = 0;
+ for (size_t i = 0; i < atoms; ++i)
+ {
+ if (inactiveAtoms[inactiveOffset] == i)
+ ++inactiveOffset;
+ else
+ atomReverseLookup(i - inactiveOffset) = i;
+ }
+
+ codesPrime(arma::span::all, arma::span(0, data.n_cols - 1)) = activeCodes;
+
+ // Fill the rest of codesPrime.
+ for (size_t l = 0; l < adjacencies.n_elem; ++l)
+ {
+ // Recover the location in the codes matrix that this adjacency refers to.
+ size_t atomInd = adjacencies(l) % atoms;
+ size_t pointInd = (size_t) (adjacencies(l) / atoms);
+
+ // Fill matrix.
+ codesPrime(atomReverseLookup(atomInd), data.n_cols + l) = 1.0;
+ wSquared(data.n_cols + l) = codes(atomInd, pointInd);
+ }
+ }
+ else
+ {
+ // All atoms are active.
+ codesPrime(arma::span::all, arma::span(0, data.n_cols - 1)) = codes;
+
+ for (size_t l = 0; l < adjacencies.n_elem; ++l)
+ {
+ // Recover the location in the codes matrix that this adjacency refers to.
+ size_t atomInd = adjacencies(l) % atoms;
+ size_t pointInd = (size_t) (adjacencies(l) / atoms);
+
+ // Fill matrix.
+ codesPrime(atomInd, data.n_cols + l) = 1.0;
+ wSquared(data.n_cols + l) = codes(atomInd, pointInd);
+ }
+ }
+
+ wSquared.subvec(data.n_cols, wSquared.n_elem - 1) = lambda *
+ abs(wSquared.subvec(data.n_cols, wSquared.n_elem - 1));
+
+ // Solve system.
+ if (nInactiveAtoms == 0)
+ {
+ // No inactive atoms. We can solve directly.
+ arma::mat A = codesPrime * diagmat(wSquared) * trans(codesPrime);
+ arma::mat B = codesPrime * diagmat(wSquared) * trans(dataPrime);
+
+ dictionary = trans(solve(A, B));
+ /*
+ dictionary = trans(solve(codesPrime * diagmat(wSquared) * trans(codesPrime),
+ codesPrime * diagmat(wSquared) * trans(dataPrime)));
+ */
+ }
+ else
+ {
+ // Inactive atoms must be reinitialized randomly, so we cannot solve
+ // directly for the entire dictionary estimate.
+ arma::mat dictionaryActive =
+ trans(solve(codesPrime * diagmat(wSquared) * trans(codesPrime),
+ codesPrime * diagmat(wSquared) * trans(dataPrime)));
+
+ // Update all atoms.
+ size_t currentInactiveIndex = 0;
+ for (size_t i = 0; i < atoms; ++i)
+ {
+ if (inactiveAtoms[currentInactiveIndex] == i)
+ {
+ // This atom is inactive. Reinitialize it randomly.
+ dictionary.col(i) = (data.col(math::RandInt(data.n_cols)) +
+ data.col(math::RandInt(data.n_cols)) +
+ data.col(math::RandInt(data.n_cols)));
+
+ // Now normalize the atom.
+ dictionary.col(i) /= norm(dictionary.col(i), 2);
+
+ // Increment inactive atom counter.
+ ++currentInactiveIndex;
+ }
+ else
+ {
+ // Update estimate.
+ dictionary.col(i) = dictionaryActive.col(i - currentInactiveIndex);
+ }
+ }
+ }
+}
+
+template<typename DictionaryInitializer>
+double LocalCoordinateCoding<DictionaryInitializer>::Objective(
+ arma::uvec adjacencies) const
+{
+ double weightedL1NormZ = 0;
+
+ for (size_t l = 0; l < adjacencies.n_elem; l++)
+ {
+ // Map adjacency back to its location in the codes matrix.
+ const size_t atomInd = adjacencies(l) % atoms;
+ const size_t pointInd = (size_t) (adjacencies(l) / atoms);
+
+ weightedL1NormZ += fabs(codes(atomInd, pointInd)) * arma::as_scalar(
+ arma::sum(arma::square(dictionary.col(atomInd) - data.col(pointInd))));
+ }
+
+ double froNormResidual = norm(data - dictionary * codes, "fro");
+ return std::pow(froNormResidual, 2.0) + lambda * weightedL1NormZ;
+}
+
+}; // namespace lcc
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,172 +0,0 @@
-/**
- * @file lcc_main.cpp
- * @author Nishant Mehta
- *
- * Executable for Local Coordinate Coding.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include "lcc.hpp"
-
-PROGRAM_INFO("Local Coordinate Coding",
- "An implementation of Local Coordinate Coding (LCC), which "
- "codes data that approximately lives on a manifold using a variation of l1-"
- "norm regularized sparse coding. Given a dense data matrix X with n points"
- " and d dimensions, LCC seeks to find a dense dictionary matrix D with k "
- "atoms in d dimensions, and a coding matrix Z with n points in k "
- "dimensions. Because of the regularization method used, the atoms in D "
- "should lie close to the manifold on which the data points lie."
- "\n\n"
- "The original data matrix X can then be reconstructed as D * Z. Therefore,"
- " this program finds a representation of each point in X as a sparse linear"
- " combination of atoms in the dictionary D."
- "\n\n"
- "The coding is found with an algorithm which alternates between a "
- "dictionary step, which updates the dictionary D, and a coding step, which "
- "updates the coding matrix Z."
- "\n\n"
- "To run this program, the input matrix X must be specified (with -i), along"
- " with the number of atoms in the dictionary (-k). An initial dictionary "
- "may also be specified with the --initial_dictionary option. The l1-norm "
- "regularization parameter is specified with -l. For example, to run LCC on"
- " the dataset in data.csv using 200 atoms and an l1-regularization "
- "parameter of 0.1, saving the dictionary into dict.csv and the codes into "
- "codes.csv, use "
- "\n\n"
- "$ local_coordinate_coding -i data.csv -k 200 -l 0.1 -d dict.csv -c "
- "codes.csv"
- "\n\n"
- "The maximum number of iterations may be specified with the -n option. "
- "Optionally, the input data matrix X can be normalized before coding with "
- "the -N option.");
-
-PARAM_STRING_REQ("input_file", "Filename of the input data.", "i");
-PARAM_INT_REQ("atoms", "Number of atoms in the dictionary.", "k");
-
-PARAM_DOUBLE("lambda", "Weighted l1-norm regularization parameter.", "l", 0.0);
-
-PARAM_INT("max_iterations", "Maximum number of iterations for LCC (0 indicates "
- "no limit).", "n", 0);
-
-PARAM_STRING("initial_dictionary", "Filename for optional initial dictionary.",
- "D", "");
-
-PARAM_STRING("dictionary_file", "Filename to save the output dictionary to.",
- "d", "dictionary.csv");
-PARAM_STRING("codes_file", "Filename to save the output codes to.", "c",
- "codes.csv");
-
-PARAM_FLAG("normalize", "If set, the input data matrix will be normalized "
- "before coding.", "N");
-
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-
-PARAM_DOUBLE("objective_tolerance", "Tolerance for objective function.", "o",
- 0.01);
-
-using namespace arma;
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::math;
-using namespace mlpack::lcc;
-using namespace mlpack::sparse_coding; // For NothingInitializer.
-
-int main(int argc, char* argv[])
-{
- CLI::ParseCommandLine(argc, argv);
-
- if (CLI::GetParam<int>("seed") != 0)
- RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- RandomSeed((size_t) std::time(NULL));
-
- const double lambda = CLI::GetParam<double>("lambda");
-
- const string inputFile = CLI::GetParam<string>("input_file");
- const string dictionaryFile = CLI::GetParam<string>("dictionary_file");
- const string codesFile = CLI::GetParam<string>("codes_file");
- const string initialDictionaryFile =
- CLI::GetParam<string>("initial_dictionary");
-
- const size_t maxIterations = CLI::GetParam<int>("max_iterations");
- const size_t atoms = CLI::GetParam<int>("atoms");
-
- const bool normalize = CLI::HasParam("normalize");
-
- const double objTolerance = CLI::GetParam<double>("objective_tolerance");
-
- mat input;
- data::Load(inputFile, input, true);
-
- Log::Info << "Loaded " << input.n_cols << " point in " << input.n_rows
- << " dimensions." << endl;
-
- // Normalize each point if the user asked for it.
- if (normalize)
- {
- Log::Info << "Normalizing data before coding..." << endl;
- for (size_t i = 0; i < input.n_cols; ++i)
- input.col(i) /= norm(input.col(i), 2);
- }
-
- // If there is an initial dictionary, be sure we do not initialize one.
- if (initialDictionaryFile != "")
- {
- LocalCoordinateCoding<NothingInitializer> lcc(input, atoms, lambda);
-
- // Load initial dictionary directly into LCC object.
- data::Load(initialDictionaryFile, lcc.Dictionary(), true);
-
- // Validate size of initial dictionary.
- if (lcc.Dictionary().n_cols != atoms)
- {
- Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_cols
- << " atoms, but the number of atoms was specified to be " << atoms
- << "!" << endl;
- }
-
- if (lcc.Dictionary().n_rows != input.n_rows)
- {
- Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_rows
- << " dimensions, but the data has " << input.n_rows << " dimensions!"
- << endl;
- }
-
- // Run LCC.
- lcc.Encode(maxIterations, objTolerance);
-
- // Save the results.
- Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
- data::Save(dictionaryFile, lcc.Dictionary());
- Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
- data::Save(codesFile, lcc.Codes());
- }
- else
- {
- // No initial dictionary.
- LocalCoordinateCoding<> lcc(input, atoms, lambda);
-
- // Run LCC.
- lcc.Encode(maxIterations, objTolerance);
-
- // Save the results.
- Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
- data::Save(dictionaryFile, lcc.Dictionary());
- Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
- data::Save(codesFile, lcc.Codes());
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,172 @@
+/**
+ * @file lcc_main.cpp
+ * @author Nishant Mehta
+ *
+ * Executable for Local Coordinate Coding.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include "lcc.hpp"
+
+PROGRAM_INFO("Local Coordinate Coding",
+ "An implementation of Local Coordinate Coding (LCC), which "
+ "codes data that approximately lives on a manifold using a variation of l1-"
+ "norm regularized sparse coding. Given a dense data matrix X with n points"
+ " and d dimensions, LCC seeks to find a dense dictionary matrix D with k "
+ "atoms in d dimensions, and a coding matrix Z with n points in k "
+ "dimensions. Because of the regularization method used, the atoms in D "
+ "should lie close to the manifold on which the data points lie."
+ "\n\n"
+ "The original data matrix X can then be reconstructed as D * Z. Therefore,"
+ " this program finds a representation of each point in X as a sparse linear"
+ " combination of atoms in the dictionary D."
+ "\n\n"
+ "The coding is found with an algorithm which alternates between a "
+ "dictionary step, which updates the dictionary D, and a coding step, which "
+ "updates the coding matrix Z."
+ "\n\n"
+ "To run this program, the input matrix X must be specified (with -i), along"
+ " with the number of atoms in the dictionary (-k). An initial dictionary "
+ "may also be specified with the --initial_dictionary option. The l1-norm "
+ "regularization parameter is specified with -l. For example, to run LCC on"
+ " the dataset in data.csv using 200 atoms and an l1-regularization "
+ "parameter of 0.1, saving the dictionary into dict.csv and the codes into "
+ "codes.csv, use "
+ "\n\n"
+ "$ local_coordinate_coding -i data.csv -k 200 -l 0.1 -d dict.csv -c "
+ "codes.csv"
+ "\n\n"
+ "The maximum number of iterations may be specified with the -n option. "
+ "Optionally, the input data matrix X can be normalized before coding with "
+ "the -N option.");
+
+PARAM_STRING_REQ("input_file", "Filename of the input data.", "i");
+PARAM_INT_REQ("atoms", "Number of atoms in the dictionary.", "k");
+
+PARAM_DOUBLE("lambda", "Weighted l1-norm regularization parameter.", "l", 0.0);
+
+PARAM_INT("max_iterations", "Maximum number of iterations for LCC (0 indicates "
+ "no limit).", "n", 0);
+
+PARAM_STRING("initial_dictionary", "Filename for optional initial dictionary.",
+ "D", "");
+
+PARAM_STRING("dictionary_file", "Filename to save the output dictionary to.",
+ "d", "dictionary.csv");
+PARAM_STRING("codes_file", "Filename to save the output codes to.", "c",
+ "codes.csv");
+
+PARAM_FLAG("normalize", "If set, the input data matrix will be normalized "
+ "before coding.", "N");
+
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+
+PARAM_DOUBLE("objective_tolerance", "Tolerance for objective function.", "o",
+ 0.01);
+
+using namespace arma;
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::math;
+using namespace mlpack::lcc;
+using namespace mlpack::sparse_coding; // For NothingInitializer.
+
+int main(int argc, char* argv[])
+{
+ CLI::ParseCommandLine(argc, argv);
+
+ if (CLI::GetParam<int>("seed") != 0)
+ RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ RandomSeed((size_t) std::time(NULL));
+
+ const double lambda = CLI::GetParam<double>("lambda");
+
+ const string inputFile = CLI::GetParam<string>("input_file");
+ const string dictionaryFile = CLI::GetParam<string>("dictionary_file");
+ const string codesFile = CLI::GetParam<string>("codes_file");
+ const string initialDictionaryFile =
+ CLI::GetParam<string>("initial_dictionary");
+
+ const size_t maxIterations = CLI::GetParam<int>("max_iterations");
+ const size_t atoms = CLI::GetParam<int>("atoms");
+
+ const bool normalize = CLI::HasParam("normalize");
+
+ const double objTolerance = CLI::GetParam<double>("objective_tolerance");
+
+ mat input;
+ data::Load(inputFile, input, true);
+
+ Log::Info << "Loaded " << input.n_cols << " point in " << input.n_rows
+ << " dimensions." << endl;
+
+ // Normalize each point if the user asked for it.
+ if (normalize)
+ {
+ Log::Info << "Normalizing data before coding..." << endl;
+ for (size_t i = 0; i < input.n_cols; ++i)
+ input.col(i) /= norm(input.col(i), 2);
+ }
+
+ // If there is an initial dictionary, be sure we do not initialize one.
+ if (initialDictionaryFile != "")
+ {
+ LocalCoordinateCoding<NothingInitializer> lcc(input, atoms, lambda);
+
+ // Load initial dictionary directly into LCC object.
+ data::Load(initialDictionaryFile, lcc.Dictionary(), true);
+
+ // Validate size of initial dictionary.
+ if (lcc.Dictionary().n_cols != atoms)
+ {
+ Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_cols
+ << " atoms, but the number of atoms was specified to be " << atoms
+ << "!" << endl;
+ }
+
+ if (lcc.Dictionary().n_rows != input.n_rows)
+ {
+ Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_rows
+ << " dimensions, but the data has " << input.n_rows << " dimensions!"
+ << endl;
+ }
+
+ // Run LCC.
+ lcc.Encode(maxIterations, objTolerance);
+
+ // Save the results.
+ Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
+ data::Save(dictionaryFile, lcc.Dictionary());
+ Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
+ data::Save(codesFile, lcc.Codes());
+ }
+ else
+ {
+ // No initial dictionary.
+ LocalCoordinateCoding<> lcc(input, atoms, lambda);
+
+ // Run LCC.
+ lcc.Encode(maxIterations, objTolerance);
+
+ // Save the results.
+ Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
+ data::Save(dictionaryFile, lcc.Dictionary());
+ Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
+ data::Save(codesFile, lcc.Codes());
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/lsh/lsh_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,169 +0,0 @@
-/**
- * @file lsh_main.cpp
- * @author Parikshit Ram
- *
- * This file computes the approximate nearest-neighbors using 2-stable
- * Locality-sensitive Hashing.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <time.h>
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include <string>
-#include <fstream>
-#include <iostream>
-
-#include "lsh_search.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-
-// Information about the program itself.
-PROGRAM_INFO("All K-Approximate-Nearest-Neighbor Search with LSH",
- "This program will calculate the k approximate-nearest-neighbors of a set "
- "of points using locality-sensitive hashing. You may specify a separate set"
- " of reference points and query points, or just a reference set which will "
- "be used as both the reference and query set. "
- "\n\n"
- "For example, the following will return 5 neighbors from the data for each "
- "point in 'input.csv' and store the distances in 'distances.csv' and the "
- "neighbors in the file 'neighbors.csv':"
- "\n\n"
- "$ lsh -k 5 -r input.csv -d distances.csv -n neighbors.csv "
- "\n\n"
- "The output files are organized such that row i and column j in the "
- "neighbors output file corresponds to the index of the point in the "
- "reference set which is the i'th nearest neighbor from the point in the "
- "query set with index j. Row i and column j in the distances output file "
- "corresponds to the distance between those two points."
- "\n\n"
- "Because this is approximate-nearest-neighbors search, results may be "
- "different from run to run. Thus, the --seed option can be specified to "
- "set the random seed.");
-
-// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
-PARAM_STRING("distances_file", "File to output distances into.", "d", "");
-PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
-
-PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
-
-PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
-
-PARAM_INT("projections", "The number of hash functions for each table", "K",
- 10);
-PARAM_INT("tables", "The number of hash tables to be used.", "L", 30);
-PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing in the "
- "LSH preprocessing. By default, the LSH class automatically estimates a "
- "hash width for its use.", "H", 0.0);
-PARAM_INT("second_hash_size", "The size of the second level hash table.", "M",
- 99901);
-PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B",
- 500);
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-
-int main(int argc, char *argv[])
-{
- // Give CLI the command line parameters the user passed in.
- CLI::ParseCommandLine(argc, argv);
-
- if (CLI::GetParam<int>("seed") != 0)
- math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- math::RandomSeed((size_t) time(NULL));
-
- // Get all the parameters.
- string referenceFile = CLI::GetParam<string>("reference_file");
- string distancesFile = CLI::GetParam<string>("distances_file");
- string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
- size_t k = CLI::GetParam<int>("k");
- size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
- size_t bucketSize = CLI::GetParam<int>("bucket_size");
-
- arma::mat referenceData;
- arma::mat queryData; // So it doesn't go out of scope.
- data::Load(referenceFile.c_str(), referenceData, true);
-
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
- // Sanity check on k value: must be greater than 0, must be less than the
- // number of reference points.
- if (k > referenceData.n_cols)
- {
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than or equal to the number of reference points (";
- Log::Fatal << referenceData.n_cols << ")." << endl;
- }
-
- // Pick up the LSH-specific parameters.
- const size_t numProj = CLI::GetParam<int>("projections");
- const size_t numTables = CLI::GetParam<int>("tables");
- const double hashWidth = CLI::GetParam<double>("hash_width");
-
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- if (CLI::GetParam<string>("query_file") != "")
- {
- string queryFile = CLI::GetParam<string>("query_file");
-
- data::Load(queryFile.c_str(), queryData, true);
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
- }
-
- if (hashWidth == 0.0)
- Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
- numTables << " tables (L) with default hash width." << endl;
- else
- Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
- numTables << " tables (L) with hash width(r): " << hashWidth << endl;
-
- Timer::Start("hash_building");
-
- LSHSearch<>* allkann;
-
- if (CLI::GetParam<string>("query_file") != "")
- allkann = new LSHSearch<>(referenceData, queryData, numProj, numTables,
- hashWidth, secondHashSize, bucketSize);
- else
- allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
- secondHashSize, bucketSize);
-
- Timer::Stop("hash_building");
-
- Log::Info << "Computing " << k << " distance approximate nearest neighbors "
- << endl;
- allkann->Search(k, neighbors, distances);
-
- Log::Info << "Neighbors computed." << endl;
-
- // Save output.
- if (distancesFile != "")
- data::Save(distancesFile, distances);
-
- if (neighborsFile != "")
- data::Save(neighborsFile, neighbors);
-
- delete allkann;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/lsh/lsh_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,169 @@
+/**
+ * @file lsh_main.cpp
+ * @author Parikshit Ram
+ *
+ * This file computes the approximate nearest-neighbors using 2-stable
+ * Locality-sensitive Hashing.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <time.h>
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "lsh_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+// Information about the program itself.
+PROGRAM_INFO("All K-Approximate-Nearest-Neighbor Search with LSH",
+ "This program will calculate the k approximate-nearest-neighbors of a set "
+ "of points using locality-sensitive hashing. You may specify a separate set"
+ " of reference points and query points, or just a reference set which will "
+ "be used as both the reference and query set. "
+ "\n\n"
+ "For example, the following will return 5 neighbors from the data for each "
+ "point in 'input.csv' and store the distances in 'distances.csv' and the "
+ "neighbors in the file 'neighbors.csv':"
+ "\n\n"
+ "$ lsh -k 5 -r input.csv -d distances.csv -n neighbors.csv "
+ "\n\n"
+ "The output files are organized such that row i and column j in the "
+ "neighbors output file corresponds to the index of the point in the "
+ "reference set which is the i'th nearest neighbor from the point in the "
+ "query set with index j. Row i and column j in the distances output file "
+ "corresponds to the distance between those two points."
+ "\n\n"
+ "Because this is approximate-nearest-neighbors search, results may be "
+ "different from run to run. Thus, the --seed option can be specified to "
+ "set the random seed.");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ "r");
+PARAM_STRING("distances_file", "File to output distances into.", "d", "");
+PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
+
+PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
+
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+
+PARAM_INT("projections", "The number of hash functions for each table", "K",
+ 10);
+PARAM_INT("tables", "The number of hash tables to be used.", "L", 30);
+PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing in the "
+ "LSH preprocessing. By default, the LSH class automatically estimates a "
+ "hash width for its use.", "H", 0.0);
+PARAM_INT("second_hash_size", "The size of the second level hash table.", "M",
+ 99901);
+PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B",
+ 500);
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+
+int main(int argc, char *argv[])
+{
+ // Give CLI the command line parameters the user passed in.
+ CLI::ParseCommandLine(argc, argv);
+
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) time(NULL));
+
+ // Get all the parameters.
+ string referenceFile = CLI::GetParam<string>("reference_file");
+ string distancesFile = CLI::GetParam<string>("distances_file");
+ string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+ size_t k = CLI::GetParam<int>("k");
+ size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
+ size_t bucketSize = CLI::GetParam<int>("bucket_size");
+
+ arma::mat referenceData;
+ arma::mat queryData; // So it doesn't go out of scope.
+ data::Load(referenceFile.c_str(), referenceData, true);
+
+ Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+ << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+
+ // Sanity check on k value: must be greater than 0, must be less than the
+ // number of reference points.
+ if (k > referenceData.n_cols)
+ {
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than or equal to the number of reference points (";
+ Log::Fatal << referenceData.n_cols << ")." << endl;
+ }
+
+ // Pick up the LSH-specific parameters.
+ const size_t numProj = CLI::GetParam<int>("projections");
+ const size_t numTables = CLI::GetParam<int>("tables");
+ const double hashWidth = CLI::GetParam<double>("hash_width");
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ data::Load(queryFile.c_str(), queryData, true);
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+ }
+
+ if (hashWidth == 0.0)
+ Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
+ numTables << " tables (L) with default hash width." << endl;
+ else
+ Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
+ numTables << " tables (L) with hash width(r): " << hashWidth << endl;
+
+ Timer::Start("hash_building");
+
+ LSHSearch<>* allkann;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ allkann = new LSHSearch<>(referenceData, queryData, numProj, numTables,
+ hashWidth, secondHashSize, bucketSize);
+ else
+ allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
+ secondHashSize, bucketSize);
+
+ Timer::Stop("hash_building");
+
+ Log::Info << "Computing " << k << " distance approximate nearest neighbors "
+ << endl;
+ allkann->Search(k, neighbors, distances);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ // Save output.
+ if (distancesFile != "")
+ data::Save(distancesFile, distances);
+
+ if (neighborsFile != "")
+ data::Save(neighborsFile, neighbors);
+
+ delete allkann;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/lsh/lsh_search.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,253 +0,0 @@
-/**
- * @file lsh_search.hpp
- * @author Parikshit Ram
- *
- * Defines the LSHSearch class, which performs an approximate
- * nearest neighbor search for a queries in a query set
- * over a given dataset using Locality-sensitive hashing
- * with 2-stable distributions.
- *
- * The details of this method can be found in the following paper:
- *
- * @inproceedings{datar2004locality,
- * title={Locality-sensitive hashing scheme based on p-stable distributions},
- * author={Datar, M. and Immorlica, N. and Indyk, P. and Mirrokni, V.S.},
- * booktitle=
- * {Proceedings of the 12th Annual Symposium on Computational Geometry},
- * pages={253--262},
- * year={2004},
- * organization={ACM}
- * }
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
-
-#include <mlpack/core.hpp>
-#include <vector>
-#include <string>
-
-#include <mlpack/core/metrics/lmetric.hpp>
-#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
-
-namespace mlpack {
-namespace neighbor {
-
-/**
- * The LSHSearch class -- This class builds a hash on the reference set
- * and uses this hash to compute the distance-approximate nearest-neighbors
- * of the given queries.
- *
- * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
- */
-template<typename SortPolicy = NearestNeighborSort>
-class LSHSearch
-{
- public:
- /**
- * This function initializes the LSH class. It builds the hash on the
- * reference set with 2-stable distributions. See the individual functions
- * performing the hashing for details on how the hashing is done.
- *
- * @param referenceSet Set of reference points.
- * @param querySet Set of query points.
- * @param numProj Number of projections in each hash table (anything between
- * 10-50 might be a decent choice).
- * @param numTables Total number of hash tables (anything between 10-20
- * should suffice).
- * @param hashWidth The width of hash for every table. If 0 (the default) is
- * provided, then the hash width is automatically obtained by computing
- * the average pairwise distance of 25 pairs. This should be a reasonable
- * upper bound on the nearest-neighbor distance in general.
- * @param secondHashSize The size of the second hash table. This should be a
- * large prime number.
- * @param bucketSize The size of the bucket in the second hash table. This is
- * the maximum number of points that can be hashed into single bucket.
- * Default values are already provided here.
- */
- LSHSearch(const arma::mat& referenceSet,
- const arma::mat& querySet,
- const size_t numProj,
- const size_t numTables,
- const double hashWidth = 0.0,
- const size_t secondHashSize = 99901,
- const size_t bucketSize = 500);
-
- /**
- * This function initializes the LSH class. It builds the hash on the
- * reference set with 2-stable distributions. See the individual functions
- * performing the hashing for details on how the hashing is done.
- *
- * @param referenceSet Set of reference points and the set of queries.
- * @param numProj Number of projections in each hash table (anything between
- * 10-50 might be a decent choice).
- * @param numTables Total number of hash tables (anything between 10-20
- * should suffice).
- * @param hashWidth The width of hash for every table. If 0 (the default) is
- * provided, then the hash width is automatically obtained by computing
- * the average pairwise distance of 25 pairs. This should be a reasonable
- * upper bound on the nearest-neighbor distance in general.
- * @param secondHashSize The size of the second hash table. This should be a
- * large prime number.
- * @param bucketSize The size of the bucket in the second hash table. This is
- * the maximum number of points that can be hashed into single bucket.
- * Default values are already provided here.
- */
- LSHSearch(const arma::mat& referenceSet,
- const size_t numProj,
- const size_t numTables,
- const double hashWidth = 0.0,
- const size_t secondHashSize = 99901,
- const size_t bucketSize = 500);
-
- /**
- * Compute the nearest neighbors and store the output in the given matrices.
- * The matrices will be set to the size of n columns by k rows, where n is
- * the number of points in the query dataset and k is the number of neighbors
- * being searched for.
- *
- * @param k Number of neighbors to search for.
- * @param resultingNeighbors Matrix storing lists of neighbors for each query
- * point.
- * @param distances Matrix storing distances of neighbors for each query
- * point.
- * @param numTablesToSearch This parameter allows the user to have control
- * over the number of hash tables to be searched. This allows
- * the user to pick the number of tables it can afford for the time
- * available without having to build hashing for every table size.
- * By default, this is set to zero in which case all tables are
- * considered.
- */
- void Search(const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
- arma::mat& distances,
- const size_t numTablesToSearch = 0);
-
- private:
- /**
- * This function builds a hash table with two levels of hashing as presented
- * in the paper. This function first hashes the points with 'numProj' random
- * projections to a single hash table creating (key, point ID) pairs where the
- * key is a 'numProj'-dimensional integer vector.
- *
- * Then each key in this hash table is hashed into a second hash table using a
- * standard hash.
- *
- * This function does not have any parameters and relies on parameters which
- * are private members of this class, intialized during the class
- * intialization.
- */
- void BuildHash();
-
- /**
- * This function takes a query and hashes it into each of the hash tables to
- * get keys for the query and then the key is hashed to a bucket of the second
- * hash table and all the points (if any) in those buckets are collected as
- * the potential neighbor candidates.
- *
- * @param queryIndex The index of the query currently being processed.
- * @param referenceIndices The list of neighbor candidates obtained from
- * hashing the query into all the hash tables and eventually into
- * multiple buckets of the second hash table.
- */
- void ReturnIndicesFromTable(const size_t queryIndex,
- arma::uvec& referenceIndices,
- size_t numTablesToSearch);
-
- /**
- * This is a helper function that computes the distance of the query to the
- * neighbor candidates and appropriately stores the best 'k' candidates
- *
- * @param queryIndex The index of the query in question
- * @param referenceIndex The index of the neighbor candidate in question
- */
- double BaseCase(const size_t queryIndex, const size_t referenceIndex);
-
- /**
- * This is a helper function that efficiently inserts better neighbor
- * candidates into an existing set of neighbor candidates. This function is
- * only called by the 'BaseCase' function.
- *
- * @param queryIndex This is the index of the query being processed currently
- * @param pos The position of the neighbor candidate in the current list of
- * neighbor candidates.
- * @param neighbor The neighbor candidate that is being inserted into the list
- * of the best 'k' candidates for the query in question.
- * @param distance The distance of the query to the neighbor candidate.
- */
- void InsertNeighbor(const size_t queryIndex, const size_t pos,
- const size_t neighbor, const double distance);
-
- private:
- //! Reference dataset.
- const arma::mat& referenceSet;
-
- //! Query dataset (may not be given).
- const arma::mat& querySet;
-
- //! The number of projections
- const size_t numProj;
-
- //! The number of hash tables
- const size_t numTables;
-
- //! The std::vector containing the projection matrix of each table
- std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
-
- //! The list of the offset 'b' for each of the projection for each table
- arma::mat offsets; // should be numProj x numTables
-
- //! The hash width
- double hashWidth;
-
- //! The big prime representing the size of the second hash
- const size_t secondHashSize;
-
- //! The weights of the second hash
- arma::vec secondHashWeights;
-
- //! The bucket size of the second hash
- const size_t bucketSize;
-
- //! Instantiation of the metric.
- metric::SquaredEuclideanDistance metric;
-
- //! The final hash table; should be (< secondHashSize) x bucketSize.
- arma::Mat<size_t> secondHashTable;
-
- //! The number of elements present in each hash bucket; should be
- //! secondHashSize.
- arma::Col<size_t> bucketContentSize;
-
- //! For a particular hash value, points to the row in secondHashTable
- //! corresponding to this value. Should be secondHashSize.
- arma::Col<size_t> bucketRowInHashTable;
-
- //! The pointer to the nearest neighbor distances.
- arma::mat* distancePtr;
-
- //! The pointer to the nearest neighbor indices.
- arma::Mat<size_t>* neighborPtr;
-}; // class LSHSearch
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-// Include implementation.
-#include "lsh_search_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/lsh/lsh_search.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,253 @@
+/**
+ * @file lsh_search.hpp
+ * @author Parikshit Ram
+ *
+ * Defines the LSHSearch class, which performs an approximate
+ * nearest neighbor search for a queries in a query set
+ * over a given dataset using Locality-sensitive hashing
+ * with 2-stable distributions.
+ *
+ * The details of this method can be found in the following paper:
+ *
+ * @inproceedings{datar2004locality,
+ * title={Locality-sensitive hashing scheme based on p-stable distributions},
+ * author={Datar, M. and Immorlica, N. and Indyk, P. and Mirrokni, V.S.},
+ * booktitle=
+ * {Proceedings of the 12th Annual Symposium on Computational Geometry},
+ * pages={253--262},
+ * year={2004},
+ * organization={ACM}
+ * }
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
+
+#include <mlpack/core.hpp>
+#include <vector>
+#include <string>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * The LSHSearch class -- This class builds a hash on the reference set
+ * and uses this hash to compute the distance-approximate nearest-neighbors
+ * of the given queries.
+ *
+ * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
+ */
+template<typename SortPolicy = NearestNeighborSort>
+class LSHSearch
+{
+ public:
+ /**
+ * This function initializes the LSH class. It builds the hash on the
+ * reference set with 2-stable distributions. See the individual functions
+ * performing the hashing for details on how the hashing is done.
+ *
+ * @param referenceSet Set of reference points.
+ * @param querySet Set of query points.
+ * @param numProj Number of projections in each hash table (anything between
+ * 10-50 might be a decent choice).
+ * @param numTables Total number of hash tables (anything between 10-20
+ * should suffice).
+ * @param hashWidth The width of hash for every table. If 0 (the default) is
+ * provided, then the hash width is automatically obtained by computing
+ * the average pairwise distance of 25 pairs. This should be a reasonable
+ * upper bound on the nearest-neighbor distance in general.
+ * @param secondHashSize The size of the second hash table. This should be a
+ * large prime number.
+ * @param bucketSize The size of the bucket in the second hash table. This is
+ * the maximum number of points that can be hashed into single bucket.
+ * Default values are already provided here.
+ */
+ LSHSearch(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ const size_t numProj,
+ const size_t numTables,
+ const double hashWidth = 0.0,
+ const size_t secondHashSize = 99901,
+ const size_t bucketSize = 500);
+
+ /**
+ * This function initializes the LSH class. It builds the hash on the
+ * reference set with 2-stable distributions. See the individual functions
+ * performing the hashing for details on how the hashing is done.
+ *
+ * @param referenceSet Set of reference points and the set of queries.
+ * @param numProj Number of projections in each hash table (anything between
+ * 10-50 might be a decent choice).
+ * @param numTables Total number of hash tables (anything between 10-20
+ * should suffice).
+ * @param hashWidth The width of hash for every table. If 0 (the default) is
+ * provided, then the hash width is automatically obtained by computing
+ * the average pairwise distance of 25 pairs. This should be a reasonable
+ * upper bound on the nearest-neighbor distance in general.
+ * @param secondHashSize The size of the second hash table. This should be a
+ * large prime number.
+ * @param bucketSize The size of the bucket in the second hash table. This is
+ * the maximum number of points that can be hashed into single bucket.
+ * Default values are already provided here.
+ */
+ LSHSearch(const arma::mat& referenceSet,
+ const size_t numProj,
+ const size_t numTables,
+ const double hashWidth = 0.0,
+ const size_t secondHashSize = 99901,
+ const size_t bucketSize = 500);
+
+ /**
+ * Compute the nearest neighbors and store the output in the given matrices.
+ * The matrices will be set to the size of n columns by k rows, where n is
+ * the number of points in the query dataset and k is the number of neighbors
+ * being searched for.
+ *
+ * @param k Number of neighbors to search for.
+ * @param resultingNeighbors Matrix storing lists of neighbors for each query
+ * point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
+ * @param numTablesToSearch This parameter allows the user to have control
+ * over the number of hash tables to be searched. This allows
+ * the user to pick the number of tables it can afford for the time
+ * available without having to build hashing for every table size.
+ * By default, this is set to zero in which case all tables are
+ * considered.
+ */
+ void Search(const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances,
+ const size_t numTablesToSearch = 0);
+
+ private:
+ /**
+ * This function builds a hash table with two levels of hashing as presented
+ * in the paper. This function first hashes the points with 'numProj' random
+ * projections to a single hash table creating (key, point ID) pairs where the
+ * key is a 'numProj'-dimensional integer vector.
+ *
+ * Then each key in this hash table is hashed into a second hash table using a
+ * standard hash.
+ *
+ * This function does not have any parameters and relies on parameters which
+ * are private members of this class, intialized during the class
+ * intialization.
+ */
+ void BuildHash();
+
+ /**
+ * This function takes a query and hashes it into each of the hash tables to
+ * get keys for the query and then the key is hashed to a bucket of the second
+ * hash table and all the points (if any) in those buckets are collected as
+ * the potential neighbor candidates.
+ *
+ * @param queryIndex The index of the query currently being processed.
+ * @param referenceIndices The list of neighbor candidates obtained from
+ * hashing the query into all the hash tables and eventually into
+ * multiple buckets of the second hash table.
+ */
+ void ReturnIndicesFromTable(const size_t queryIndex,
+ arma::uvec& referenceIndices,
+ size_t numTablesToSearch);
+
+ /**
+ * This is a helper function that computes the distance of the query to the
+ * neighbor candidates and appropriately stores the best 'k' candidates
+ *
+ * @param queryIndex The index of the query in question
+ * @param referenceIndex The index of the neighbor candidate in question
+ */
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+ /**
+ * This is a helper function that efficiently inserts better neighbor
+ * candidates into an existing set of neighbor candidates. This function is
+ * only called by the 'BaseCase' function.
+ *
+ * @param queryIndex This is the index of the query being processed currently
+ * @param pos The position of the neighbor candidate in the current list of
+ * neighbor candidates.
+ * @param neighbor The neighbor candidate that is being inserted into the list
+ * of the best 'k' candidates for the query in question.
+ * @param distance The distance of the query to the neighbor candidate.
+ */
+ void InsertNeighbor(const size_t queryIndex, const size_t pos,
+ const size_t neighbor, const double distance);
+
+ private:
+ //! Reference dataset.
+ const arma::mat& referenceSet;
+
+ //! Query dataset (may not be given).
+ const arma::mat& querySet;
+
+ //! The number of projections
+ const size_t numProj;
+
+ //! The number of hash tables
+ const size_t numTables;
+
+ //! The std::vector containing the projection matrix of each table
+ std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
+
+ //! The list of the offset 'b' for each of the projection for each table
+ arma::mat offsets; // should be numProj x numTables
+
+ //! The hash width
+ double hashWidth;
+
+ //! The big prime representing the size of the second hash
+ const size_t secondHashSize;
+
+ //! The weights of the second hash
+ arma::vec secondHashWeights;
+
+ //! The bucket size of the second hash
+ const size_t bucketSize;
+
+ //! Instantiation of the metric.
+ metric::SquaredEuclideanDistance metric;
+
+ //! The final hash table; should be (< secondHashSize) x bucketSize.
+ arma::Mat<size_t> secondHashTable;
+
+ //! The number of elements present in each hash bucket; should be
+ //! secondHashSize.
+ arma::Col<size_t> bucketContentSize;
+
+ //! For a particular hash value, points to the row in secondHashTable
+ //! corresponding to this value. Should be secondHashSize.
+ arma::Col<size_t> bucketRowInHashTable;
+
+ //! The pointer to the nearest neighbor distances.
+ arma::mat* distancePtr;
+
+ //! The pointer to the nearest neighbor indices.
+ arma::Mat<size_t>* neighborPtr;
+}; // class LSHSearch
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "lsh_search_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/lsh/lsh_search_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,409 +0,0 @@
-/**
- * @file lsh_search_impl.hpp
- * @author Parikshit Ram
- *
- * Implementation of the LSHSearch class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace neighbor {
-
-// Construct the object.
-template<typename SortPolicy>
-LSHSearch<SortPolicy>::
-LSHSearch(const arma::mat& referenceSet,
- const arma::mat& querySet,
- const size_t numProj,
- const size_t numTables,
- const double hashWidthIn,
- const size_t secondHashSize,
- const size_t bucketSize) :
- referenceSet(referenceSet),
- querySet(querySet),
- numProj(numProj),
- numTables(numTables),
- hashWidth(hashWidthIn),
- secondHashSize(secondHashSize),
- bucketSize(bucketSize)
-{
- if (hashWidth == 0.0) // The user has not provided any value.
- {
- // Compute a heuristic hash width from the data.
- for (size_t i = 0; i < 25; i++)
- {
- size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
- size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
-
- hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
- referenceSet.unsafe_col(p2)));
- }
-
- hashWidth /= 25;
- }
-
- Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
-
- BuildHash();
-}
-
-template<typename SortPolicy>
-LSHSearch<SortPolicy>::
-LSHSearch(const arma::mat& referenceSet,
- const size_t numProj,
- const size_t numTables,
- const double hashWidthIn,
- const size_t secondHashSize,
- const size_t bucketSize) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- numProj(numProj),
- numTables(numTables),
- hashWidth(hashWidthIn),
- secondHashSize(secondHashSize),
- bucketSize(bucketSize)
-{
- if (hashWidth == 0.0) // The user has not provided any value.
- {
- // Compute a heuristic hash width from the data.
- for (size_t i = 0; i < 25; i++)
- {
- size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
- size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
-
- hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
- referenceSet.unsafe_col(p2)));
- }
-
- hashWidth /= 25;
- }
-
- Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
-
- BuildHash();
-}
-
-template<typename SortPolicy>
-void LSHSearch<SortPolicy>::
-InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
-{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (distancePtr->n_rows - 1))
- {
- int len = (distancePtr->n_rows - 1) - pos;
- memmove(distancePtr->colptr(queryIndex) + (pos + 1),
- distancePtr->colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(neighborPtr->colptr(queryIndex) + (pos + 1),
- neighborPtr->colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- (*distancePtr)(pos, queryIndex) = distance;
- (*neighborPtr)(pos, queryIndex) = neighbor;
-}
-
-template<typename SortPolicy>
-inline force_inline
-double LSHSearch<SortPolicy>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
-{
- // If the datasets are the same, then this search is only using one dataset
- // and we should not return identical points.
- if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
- return 0.0;
-
- double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceIndex));
-
- // If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the position to insert it into.
- arma::vec queryDist = distancePtr->unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
-
- // SortDistance() returns (size_t() - 1) if we shouldn't add it.
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
-
- return distance;
-}
-
-template<typename SortPolicy>
-void LSHSearch<SortPolicy>::
-ReturnIndicesFromTable(const size_t queryIndex,
- arma::uvec& referenceIndices,
- size_t numTablesToSearch)
-{
- // Decide on the number of tables to look into.
- if (numTablesToSearch == 0) // If no user input is given, search all.
- numTablesToSearch = numTables;
-
- // Sanity check to make sure that the existing number of tables is not
- // exceeded.
- if (numTablesToSearch > numTables)
- numTablesToSearch = numTables;
-
- // Hash the query in each of the 'numTablesToSearch' hash tables using the
- // 'numProj' projections for each table. This gives us 'numTablesToSearch'
- // keys for the query where each key is a 'numProj' dimensional integer
- // vector.
-
- // Compute the projection of the query in each table.
- arma::mat allProjInTables(numProj, numTablesToSearch);
- for (size_t i = 0; i < numTablesToSearch; i++)
- {
- allProjInTables.unsafe_col(i) = projections[i].t() *
- querySet.unsafe_col(queryIndex);
- }
- allProjInTables += offsets.cols(0, numTablesToSearch - 1);
- allProjInTables /= hashWidth;
-
- // Compute the hash value of each key of the query into a bucket of the
- // 'secondHashTable' using the 'secondHashWeights'.
- arma::rowvec hashVec = secondHashWeights.t() * arma::floor(allProjInTables);
-
- for (size_t i = 0; i < hashVec.n_elem; i++)
- hashVec[i] = (double) ((size_t) hashVec[i] % secondHashSize);
-
- Log::Assert(hashVec.n_elem == numTablesToSearch);
-
- // For all the buckets that the query is hashed into, sequentially
- // collect the indices in those buckets.
- arma::Col<size_t> refPointsConsidered;
- refPointsConsidered.zeros(referenceSet.n_cols);
-
- for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
- {
- size_t hashInd = (size_t) hashVec[i];
-
- if (bucketContentSize[hashInd] > 0)
- {
- // Pick the indices in the bucket corresponding to 'hashInd'.
- size_t tableRow = bucketRowInHashTable[hashInd];
- assert(tableRow < secondHashSize);
- assert(tableRow < secondHashTable.n_rows);
-
- for (size_t j = 0; j < bucketContentSize[hashInd]; j++)
- refPointsConsidered[secondHashTable(tableRow, j)]++;
- }
- }
-
- referenceIndices = arma::find(refPointsConsidered > 0);
-}
-
-
-template<typename SortPolicy>
-void LSHSearch<SortPolicy>::
-Search(const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
- arma::mat& distances,
- const size_t numTablesToSearch)
-{
- neighborPtr = &resultingNeighbors;
- distancePtr = &distances;
-
- // Set the size of the neighbor and distance matrices.
- neighborPtr->set_size(k, querySet.n_cols);
- distancePtr->set_size(k, querySet.n_cols);
- distancePtr->fill(SortPolicy::WorstDistance());
- neighborPtr->fill(referenceSet.n_cols);
-
- size_t avgIndicesReturned = 0;
-
- Timer::Start("computing_neighbors");
-
- // Go through every query point sequentially.
- for (size_t i = 0; i < querySet.n_cols; i++)
- {
- // Hash every query into every hash table and eventually into the
- // 'secondHashTable' to obtain the neighbor candidates.
- arma::uvec refIndices;
- ReturnIndicesFromTable(i, refIndices, numTablesToSearch);
-
- // An informative book-keeping for the number of neighbor candidates
- // returned on average.
- avgIndicesReturned += refIndices.n_elem;
-
- // Sequentially go through all the candidates and save the best 'k'
- // candidates.
- for (size_t j = 0; j < refIndices.n_elem; j++)
- BaseCase(i, (size_t) refIndices[j]);
- }
-
- Timer::Stop("computing_neighbors");
-
- avgIndicesReturned /= querySet.n_cols;
- Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
- std::endl;
-}
-
-template<typename SortPolicy>
-void LSHSearch<SortPolicy>::
-BuildHash()
-{
- // The first level hash for a single table outputs a 'numProj'-dimensional
- // integer key for each point in the set -- (key, pointID)
- // The key creation details are presented below
- //
- // The second level hash is performed by hashing the key to
- // an integer in the range [0, 'secondHashSize').
- //
- // This is done by creating a weight vector 'secondHashWeights' of
- // length 'numProj' with each entry an integer randomly chosen
- // between [0, 'secondHashSize').
- //
- // Then the bucket for any key and its corresponding point is
- // given by <key, 'secondHashWeights'> % 'secondHashSize'
- // and the corresponding point ID is put into that bucket.
-
- // Step I: Prepare the second level hash.
-
- // Obtain the weights for the second hash.
- secondHashWeights = arma::floor(arma::randu(numProj) *
- (double) secondHashSize);
-
- // The 'secondHashTable' is initially an empty matrix of size
- // ('secondHashSize' x 'bucketSize'). But by only filling the buckets
- // as points land in them allows us to shrink the size of the
- // 'secondHashTable' at the end of the hashing.
-
- // Fill the second hash table n = referenceSet.n_cols. This is because no
- // point has index 'n' so the presence of this in the bucket denotes that
- // there are no more points in this bucket.
- secondHashTable.set_size(secondHashSize, bucketSize);
- secondHashTable.fill(referenceSet.n_cols);
-
- // Keep track of the size of each bucket in the hash. At the end of hashing
- // most buckets will be empty.
- bucketContentSize.zeros(secondHashSize);
-
- // Instead of putting the points in the row corresponding to the bucket, we
- // chose the next empty row and keep track of the row in which the bucket
- // lies. This allows us to stack together and slice out the empty buckets at
- // the end of the hashing.
- bucketRowInHashTable.set_size(secondHashSize);
- bucketRowInHashTable.fill(secondHashSize);
-
- // Keep track of number of non-empty rows in the 'secondHashTable'.
- size_t numRowsInTable = 0;
-
- // Step II: The offsets for all projections in all tables.
- // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
- // as randu(numProj, numTables) * hashWidth.
- offsets.randu(numProj, numTables);
- offsets *= hashWidth;
-
- // Step III: Create each hash table in the first level hash one by one and
- // putting them directly into the 'secondHashTable' for memory efficiency.
- for (size_t i = 0; i < numTables; i++)
- {
- // Step IV: Obtain the 'numProj' projections for each table.
-
- // For L2 metric, 2-stable distributions are used, and
- // the normal Z ~ N(0, 1) is a 2-stable distribution.
- arma::mat projMat;
- projMat.randn(referenceSet.n_rows, numProj);
-
- // Save the projection matrix for querying.
- projections.push_back(projMat);
-
- // Step V: create the 'numProj'-dimensional key for each point in each
- // table.
-
- // The following code performs the task of hashing each point to a
- // 'numProj'-dimensional integer key. Hence you get a ('numProj' x
- // 'referenceSet.n_cols') key matrix.
- //
- // For a single table, let the 'numProj' projections be denoted by 'proj_i'
- // and the corresponding offset be 'offset_i'. Then the key of a single
- // point is obtained as:
- // key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
- arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
- referenceSet.n_cols);
- arma::mat hashMat = projMat.t() * referenceSet;
- hashMat += offsetMat;
- hashMat /= hashWidth;
-
- // Step VI: Putting the points in the 'secondHashTable' by hashing the key.
- // Now we hash every key, point ID to its corresponding bucket.
- arma::rowvec secondHashVec = secondHashWeights.t()
- * arma::floor(hashMat);
-
- // This gives us the bucket for the corresponding point ID.
- for (size_t j = 0; j < secondHashVec.n_elem; j++)
- secondHashVec[j] = (double)((size_t) secondHashVec[j] % secondHashSize);
-
- Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
-
- // Insert the point in the corresponding row to its bucket in the
- // 'secondHashTable'.
- for (size_t j = 0; j < secondHashVec.n_elem; j++)
- {
- // This is the bucket number.
- size_t hashInd = (size_t) secondHashVec[j];
- // The point ID is 'j'.
-
- // If this is currently an empty bucket, start a new row keep track of
- // which row corresponds to the bucket.
- if (bucketContentSize[hashInd] == 0)
- {
- // Start a new row for hash.
- bucketRowInHashTable[hashInd] = numRowsInTable;
- secondHashTable(numRowsInTable, 0) = j;
-
- numRowsInTable++;
- }
-
- else
- {
- // If bucket is already present in the 'secondHashTable', find the
- // corresponding row and insert the point ID in this row unless the
- // bucket is full, in which case, do nothing.
- if (bucketContentSize[hashInd] < bucketSize)
- secondHashTable(bucketRowInHashTable[hashInd],
- bucketContentSize[hashInd]) = j;
- }
-
- // Increment the count of the points in this bucket.
- if (bucketContentSize[hashInd] < bucketSize)
- bucketContentSize[hashInd]++;
- } // Loop over all points in the reference set.
- } // Loop over tables.
-
- // Step VII: Condensing the 'secondHashTable'.
- size_t maxBucketSize = 0;
- for (size_t i = 0; i < bucketContentSize.n_elem; i++)
- if (bucketContentSize[i] > maxBucketSize)
- maxBucketSize = bucketContentSize[i];
-
- Log::Info << "Final hash table size: (" << numRowsInTable << " x "
- << maxBucketSize << ")" << std::endl;
- secondHashTable.resize(numRowsInTable, maxBucketSize);
-}
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/lsh/lsh_search_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/lsh/lsh_search_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,409 @@
+/**
+ * @file lsh_search_impl.hpp
+ * @author Parikshit Ram
+ *
+ * Implementation of the LSHSearch class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+// Construct the object.
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::
+LSHSearch(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ const size_t numProj,
+ const size_t numTables,
+ const double hashWidthIn,
+ const size_t secondHashSize,
+ const size_t bucketSize) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ numProj(numProj),
+ numTables(numTables),
+ hashWidth(hashWidthIn),
+ secondHashSize(secondHashSize),
+ bucketSize(bucketSize)
+{
+ if (hashWidth == 0.0) // The user has not provided any value.
+ {
+ // Compute a heuristic hash width from the data.
+ for (size_t i = 0; i < 25; i++)
+ {
+ size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
+ size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
+
+ hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
+ referenceSet.unsafe_col(p2)));
+ }
+
+ hashWidth /= 25;
+ }
+
+ Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
+
+ BuildHash();
+}
+
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::
+LSHSearch(const arma::mat& referenceSet,
+ const size_t numProj,
+ const size_t numTables,
+ const double hashWidthIn,
+ const size_t secondHashSize,
+ const size_t bucketSize) :
+ referenceSet(referenceSet),
+ querySet(referenceSet),
+ numProj(numProj),
+ numTables(numTables),
+ hashWidth(hashWidthIn),
+ secondHashSize(secondHashSize),
+ bucketSize(bucketSize)
+{
+ if (hashWidth == 0.0) // The user has not provided any value.
+ {
+ // Compute a heuristic hash width from the data.
+ for (size_t i = 0; i < 25; i++)
+ {
+ size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
+ size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
+
+ hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
+ referenceSet.unsafe_col(p2)));
+ }
+
+ hashWidth /= 25;
+ }
+
+ Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
+
+ BuildHash();
+}
+
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
+InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance)
+{
+ // We only memmove() if there is actually a need to shift something.
+ if (pos < (distancePtr->n_rows - 1))
+ {
+ int len = (distancePtr->n_rows - 1) - pos;
+ memmove(distancePtr->colptr(queryIndex) + (pos + 1),
+ distancePtr->colptr(queryIndex) + pos,
+ sizeof(double) * len);
+ memmove(neighborPtr->colptr(queryIndex) + (pos + 1),
+ neighborPtr->colptr(queryIndex) + pos,
+ sizeof(size_t) * len);
+ }
+
+ // Now put the new information in the right index.
+ (*distancePtr)(pos, queryIndex) = distance;
+ (*neighborPtr)(pos, queryIndex) = neighbor;
+}
+
+template<typename SortPolicy>
+inline force_inline
+double LSHSearch<SortPolicy>::
+BaseCase(const size_t queryIndex, const size_t referenceIndex)
+{
+ // If the datasets are the same, then this search is only using one dataset
+ // and we should not return identical points.
+ if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
+ return 0.0;
+
+ double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceIndex));
+
+ // If this distance is better than any of the current candidates, the
+ // SortDistance() function will give us the position to insert it into.
+ arma::vec queryDist = distancePtr->unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+
+ // SortDistance() returns (size_t() - 1) if we shouldn't add it.
+ if (insertPosition != (size_t() - 1))
+ InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+
+ return distance;
+}
+
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
+ReturnIndicesFromTable(const size_t queryIndex,
+ arma::uvec& referenceIndices,
+ size_t numTablesToSearch)
+{
+ // Decide on the number of tables to look into.
+ if (numTablesToSearch == 0) // If no user input is given, search all.
+ numTablesToSearch = numTables;
+
+ // Sanity check to make sure that the existing number of tables is not
+ // exceeded.
+ if (numTablesToSearch > numTables)
+ numTablesToSearch = numTables;
+
+ // Hash the query in each of the 'numTablesToSearch' hash tables using the
+ // 'numProj' projections for each table. This gives us 'numTablesToSearch'
+ // keys for the query where each key is a 'numProj' dimensional integer
+ // vector.
+
+ // Compute the projection of the query in each table.
+ arma::mat allProjInTables(numProj, numTablesToSearch);
+ for (size_t i = 0; i < numTablesToSearch; i++)
+ {
+ allProjInTables.unsafe_col(i) = projections[i].t() *
+ querySet.unsafe_col(queryIndex);
+ }
+ allProjInTables += offsets.cols(0, numTablesToSearch - 1);
+ allProjInTables /= hashWidth;
+
+ // Compute the hash value of each key of the query into a bucket of the
+ // 'secondHashTable' using the 'secondHashWeights'.
+ arma::rowvec hashVec = secondHashWeights.t() * arma::floor(allProjInTables);
+
+ for (size_t i = 0; i < hashVec.n_elem; i++)
+ hashVec[i] = (double) ((size_t) hashVec[i] % secondHashSize);
+
+ Log::Assert(hashVec.n_elem == numTablesToSearch);
+
+ // For all the buckets that the query is hashed into, sequentially
+ // collect the indices in those buckets.
+ arma::Col<size_t> refPointsConsidered;
+ refPointsConsidered.zeros(referenceSet.n_cols);
+
+ for (size_t i = 0; i < hashVec.n_elem; i++) // For all tables.
+ {
+ size_t hashInd = (size_t) hashVec[i];
+
+ if (bucketContentSize[hashInd] > 0)
+ {
+ // Pick the indices in the bucket corresponding to 'hashInd'.
+ size_t tableRow = bucketRowInHashTable[hashInd];
+ assert(tableRow < secondHashSize);
+ assert(tableRow < secondHashTable.n_rows);
+
+ for (size_t j = 0; j < bucketContentSize[hashInd]; j++)
+ refPointsConsidered[secondHashTable(tableRow, j)]++;
+ }
+ }
+
+ referenceIndices = arma::find(refPointsConsidered > 0);
+}
+
+
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
+Search(const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances,
+ const size_t numTablesToSearch)
+{
+ neighborPtr = &resultingNeighbors;
+ distancePtr = &distances;
+
+ // Set the size of the neighbor and distance matrices.
+ neighborPtr->set_size(k, querySet.n_cols);
+ distancePtr->set_size(k, querySet.n_cols);
+ distancePtr->fill(SortPolicy::WorstDistance());
+ neighborPtr->fill(referenceSet.n_cols);
+
+ size_t avgIndicesReturned = 0;
+
+ Timer::Start("computing_neighbors");
+
+ // Go through every query point sequentially.
+ for (size_t i = 0; i < querySet.n_cols; i++)
+ {
+ // Hash every query into every hash table and eventually into the
+ // 'secondHashTable' to obtain the neighbor candidates.
+ arma::uvec refIndices;
+ ReturnIndicesFromTable(i, refIndices, numTablesToSearch);
+
+ // An informative book-keeping for the number of neighbor candidates
+ // returned on average.
+ avgIndicesReturned += refIndices.n_elem;
+
+ // Sequentially go through all the candidates and save the best 'k'
+ // candidates.
+ for (size_t j = 0; j < refIndices.n_elem; j++)
+ BaseCase(i, (size_t) refIndices[j]);
+ }
+
+ Timer::Stop("computing_neighbors");
+
+ avgIndicesReturned /= querySet.n_cols;
+ Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
+ std::endl;
+}
+
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
+BuildHash()
+{
+ // The first level hash for a single table outputs a 'numProj'-dimensional
+ // integer key for each point in the set -- (key, pointID)
+ // The key creation details are presented below
+ //
+ // The second level hash is performed by hashing the key to
+ // an integer in the range [0, 'secondHashSize').
+ //
+ // This is done by creating a weight vector 'secondHashWeights' of
+ // length 'numProj' with each entry an integer randomly chosen
+ // between [0, 'secondHashSize').
+ //
+ // Then the bucket for any key and its corresponding point is
+ // given by <key, 'secondHashWeights'> % 'secondHashSize'
+ // and the corresponding point ID is put into that bucket.
+
+ // Step I: Prepare the second level hash.
+
+ // Obtain the weights for the second hash.
+ secondHashWeights = arma::floor(arma::randu(numProj) *
+ (double) secondHashSize);
+
+ // The 'secondHashTable' is initially an empty matrix of size
+ // ('secondHashSize' x 'bucketSize'). But by only filling the buckets
+ // as points land in them allows us to shrink the size of the
+ // 'secondHashTable' at the end of the hashing.
+
+ // Fill the second hash table n = referenceSet.n_cols. This is because no
+ // point has index 'n' so the presence of this in the bucket denotes that
+ // there are no more points in this bucket.
+ secondHashTable.set_size(secondHashSize, bucketSize);
+ secondHashTable.fill(referenceSet.n_cols);
+
+ // Keep track of the size of each bucket in the hash. At the end of hashing
+ // most buckets will be empty.
+ bucketContentSize.zeros(secondHashSize);
+
+ // Instead of putting the points in the row corresponding to the bucket, we
+ // chose the next empty row and keep track of the row in which the bucket
+ // lies. This allows us to stack together and slice out the empty buckets at
+ // the end of the hashing.
+ bucketRowInHashTable.set_size(secondHashSize);
+ bucketRowInHashTable.fill(secondHashSize);
+
+ // Keep track of number of non-empty rows in the 'secondHashTable'.
+ size_t numRowsInTable = 0;
+
+ // Step II: The offsets for all projections in all tables.
+ // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
+ // as randu(numProj, numTables) * hashWidth.
+ offsets.randu(numProj, numTables);
+ offsets *= hashWidth;
+
+ // Step III: Create each hash table in the first level hash one by one and
+ // putting them directly into the 'secondHashTable' for memory efficiency.
+ for (size_t i = 0; i < numTables; i++)
+ {
+ // Step IV: Obtain the 'numProj' projections for each table.
+
+ // For L2 metric, 2-stable distributions are used, and
+ // the normal Z ~ N(0, 1) is a 2-stable distribution.
+ arma::mat projMat;
+ projMat.randn(referenceSet.n_rows, numProj);
+
+ // Save the projection matrix for querying.
+ projections.push_back(projMat);
+
+ // Step V: create the 'numProj'-dimensional key for each point in each
+ // table.
+
+ // The following code performs the task of hashing each point to a
+ // 'numProj'-dimensional integer key. Hence you get a ('numProj' x
+ // 'referenceSet.n_cols') key matrix.
+ //
+ // For a single table, let the 'numProj' projections be denoted by 'proj_i'
+ // and the corresponding offset be 'offset_i'. Then the key of a single
+ // point is obtained as:
+ // key = { floor( (<proj_i, point> + offset_i) / 'hashWidth' ) forall i }
+ arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
+ referenceSet.n_cols);
+ arma::mat hashMat = projMat.t() * referenceSet;
+ hashMat += offsetMat;
+ hashMat /= hashWidth;
+
+ // Step VI: Putting the points in the 'secondHashTable' by hashing the key.
+ // Now we hash every key, point ID to its corresponding bucket.
+ arma::rowvec secondHashVec = secondHashWeights.t()
+ * arma::floor(hashMat);
+
+ // This gives us the bucket for the corresponding point ID.
+ for (size_t j = 0; j < secondHashVec.n_elem; j++)
+ secondHashVec[j] = (double)((size_t) secondHashVec[j] % secondHashSize);
+
+ Log::Assert(secondHashVec.n_elem == referenceSet.n_cols);
+
+ // Insert the point in the corresponding row to its bucket in the
+ // 'secondHashTable'.
+ for (size_t j = 0; j < secondHashVec.n_elem; j++)
+ {
+ // This is the bucket number.
+ size_t hashInd = (size_t) secondHashVec[j];
+ // The point ID is 'j'.
+
+ // If this is currently an empty bucket, start a new row keep track of
+ // which row corresponds to the bucket.
+ if (bucketContentSize[hashInd] == 0)
+ {
+ // Start a new row for hash.
+ bucketRowInHashTable[hashInd] = numRowsInTable;
+ secondHashTable(numRowsInTable, 0) = j;
+
+ numRowsInTable++;
+ }
+
+ else
+ {
+ // If bucket is already present in the 'secondHashTable', find the
+ // corresponding row and insert the point ID in this row unless the
+ // bucket is full, in which case, do nothing.
+ if (bucketContentSize[hashInd] < bucketSize)
+ secondHashTable(bucketRowInHashTable[hashInd],
+ bucketContentSize[hashInd]) = j;
+ }
+
+ // Increment the count of the points in this bucket.
+ if (bucketContentSize[hashInd] < bucketSize)
+ bucketContentSize[hashInd]++;
+ } // Loop over all points in the reference set.
+ } // Loop over tables.
+
+ // Step VII: Condensing the 'secondHashTable'.
+ size_t maxBucketSize = 0;
+ for (size_t i = 0; i < bucketContentSize.n_elem; i++)
+ if (bucketContentSize[i] > maxBucketSize)
+ maxBucketSize = bucketContentSize[i];
+
+ Log::Info << "Final hash table size: (" << numRowsInTable << " x "
+ << maxBucketSize << ")" << std::endl;
+ secondHashTable.resize(numRowsInTable, maxBucketSize);
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/mvu/mvu.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,120 +0,0 @@
-/**
- * @file mvu.cpp
- * @author Ryan Curtin
- *
- * Implementation of the MVU class and its auxiliary objective function class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "mvu.hpp"
-
-//#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp>
-#include <mlpack/core/optimizers/lrsdp/lrsdp.hpp>
-
-#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
-
-using namespace mlpack;
-using namespace mlpack::mvu;
-using namespace mlpack::optimization;
-
-MVU::MVU(const arma::mat& data) : data(data)
-{
- // Nothing to do.
-}
-
-void MVU::Unfold(const size_t newDim,
- const size_t numNeighbors,
- arma::mat& outputData)
-{
- // First we have to choose the output point. We'll take a linear projection
- // of the data for now (this is probably not a good final solution).
-// outputData = trans(data.rows(0, newDim - 1));
- // Following Nick's idea.
- outputData.randu(data.n_cols, newDim);
-
- // The number of constraints is the number of nearest neighbors plus one.
- LRSDP mvuSolver(numNeighbors * data.n_cols + 1, outputData);
-
- // Set up the objective. Because we are maximizing the trace of (R R^T),
- // we'll instead state it as min(-I_n * (R R^T)), meaning C() is -I_n.
- mvuSolver.C().eye(data.n_cols, data.n_cols);
- mvuSolver.C() *= -1;
-
- // Now set up each of the constraints.
- // The first constraint is trace(ones * R * R^T) = 0.
- mvuSolver.B()[0] = 0;
- mvuSolver.A()[0].ones(data.n_cols, data.n_cols);
-
- // All of our other constraints will be sparse except the first. So set that
- // vector of modes accordingly.
- mvuSolver.AModes().ones();
- mvuSolver.AModes()[0] = 0;
-
- // Now all of the other constraints. We first have to run AllkNN to get the
- // list of nearest neighbors.
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- AllkNN allknn(data);
- allknn.Search(numNeighbors, neighbors, distances);
-
- // Add each of the other constraints. They are sparse constraints:
- // Tr(A_ij K) = d_ij;
- // A_ij = zeros except for 1 at (i, i), (j, j); -1 at (i, j), (j, i).
- for (size_t i = 0; i < neighbors.n_cols; ++i)
- {
- for (size_t j = 0; j < numNeighbors; ++j)
- {
- // This is the index of the constraint.
- const size_t index = (i * numNeighbors) + j + 1;
-
- arma::mat& aRef = mvuSolver.A()[index];
-
- aRef.set_size(3, 4);
-
- // A_ij(i, i) = 1.
- aRef(0, 0) = i;
- aRef(1, 0) = i;
- aRef(2, 0) = 1;
-
- // A_ij(i, j) = -1.
- aRef(0, 1) = i;
- aRef(1, 1) = neighbors(j, i);
- aRef(2, 1) = -1;
-
- // A_ij(j, i) = -1.
- aRef(0, 2) = neighbors(j, i);
- aRef(1, 2) = i;
- aRef(2, 2) = -1;
-
- // A_ij(j, j) = 1.
- aRef(0, 3) = neighbors(j, i);
- aRef(1, 3) = neighbors(j, i);
- aRef(2, 3) = 1;
-
- // The constraint b_ij is the distance between these two points.
- mvuSolver.B()[index] = distances(j, i);
- }
- }
-
- // Now on with the solving.
- double objective = mvuSolver.Optimize(outputData);
-
- Log::Info << "Final objective is " << objective << "." << std::endl;
-
- // Revert to original data format.
- outputData = trans(outputData);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/mvu/mvu.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,120 @@
+/**
+ * @file mvu.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the MVU class and its auxiliary objective function class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "mvu.hpp"
+
+//#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp>
+#include <mlpack/core/optimizers/lrsdp/lrsdp.hpp>
+
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+
+using namespace mlpack;
+using namespace mlpack::mvu;
+using namespace mlpack::optimization;
+
+MVU::MVU(const arma::mat& data) : data(data)
+{
+ // Nothing to do.
+}
+
+void MVU::Unfold(const size_t newDim,
+ const size_t numNeighbors,
+ arma::mat& outputData)
+{
+ // First we have to choose the output point. We'll take a linear projection
+ // of the data for now (this is probably not a good final solution).
+// outputData = trans(data.rows(0, newDim - 1));
+ // Following Nick's idea.
+ outputData.randu(data.n_cols, newDim);
+
+ // The number of constraints is the number of points times the number of
+ // nearest neighbors, plus one (numNeighbors * data.n_cols + 1).
+ LRSDP mvuSolver(numNeighbors * data.n_cols + 1, outputData);
+
+ // Set up the objective. Because we are maximizing the trace of (R R^T),
+ // we'll instead state it as min(-I_n * (R R^T)), meaning C() is -I_n.
+ mvuSolver.C().eye(data.n_cols, data.n_cols);
+ mvuSolver.C() *= -1;
+
+ // Now set up each of the constraints.
+ // The first constraint is trace(ones * R * R^T) = 0.
+ mvuSolver.B()[0] = 0;
+ mvuSolver.A()[0].ones(data.n_cols, data.n_cols);
+
+ // All of our other constraints will be sparse except the first. So set that
+ // vector of modes accordingly.
+ mvuSolver.AModes().ones();
+ mvuSolver.AModes()[0] = 0;
+
+ // Now all of the other constraints. We first have to run AllkNN to get the
+ // list of nearest neighbors.
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ AllkNN allknn(data);
+ allknn.Search(numNeighbors, neighbors, distances);
+
+ // Add each of the other constraints. They are sparse constraints:
+ // Tr(A_ij K) = d_ij;
+ // A_ij = zeros except for 1 at (i, i), (j, j); -1 at (i, j), (j, i).
+ for (size_t i = 0; i < neighbors.n_cols; ++i)
+ {
+ for (size_t j = 0; j < numNeighbors; ++j)
+ {
+ // This is the index of the constraint.
+ const size_t index = (i * numNeighbors) + j + 1;
+
+ arma::mat& aRef = mvuSolver.A()[index];
+
+ aRef.set_size(3, 4);
+
+ // A_ij(i, i) = 1.
+ aRef(0, 0) = i;
+ aRef(1, 0) = i;
+ aRef(2, 0) = 1;
+
+ // A_ij(i, j) = -1.
+ aRef(0, 1) = i;
+ aRef(1, 1) = neighbors(j, i);
+ aRef(2, 1) = -1;
+
+ // A_ij(j, i) = -1.
+ aRef(0, 2) = neighbors(j, i);
+ aRef(1, 2) = i;
+ aRef(2, 2) = -1;
+
+ // A_ij(j, j) = 1.
+ aRef(0, 3) = neighbors(j, i);
+ aRef(1, 3) = neighbors(j, i);
+ aRef(2, 3) = 1;
+
+ // The constraint b_ij is the distance between these two points.
+ mvuSolver.B()[index] = distances(j, i);
+ }
+ }
+
+ // Now on with the solving.
+ double objective = mvuSolver.Optimize(outputData);
+
+ Log::Info << "Final objective is " << objective << "." << std::endl;
+
+ // Revert to original data format.
+ outputData = trans(outputData);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/mvu/mvu.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,56 +0,0 @@
-/**
- * @file mvu.hpp
- * @author Ryan Curtin
- *
- * An implementation of Maximum Variance Unfolding. This file defines an MVU
- * class as well as a class representing the objective function (a semidefinite
- * program) which MVU seeks to minimize. Minimization is performed by the
- * Augmented Lagrangian optimizer (which in turn uses the L-BFGS optimizer).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_MVU_MVU_HPP
-#define __MLPACK_METHODS_MVU_MVU_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace mvu {
-
-/**
- * The MVU class is meant to provide a good abstraction for users. The dataset
- * needs to be provided, as well as several parameters.
- *
- * - dataset
- * - new dimensionality
- */
-class MVU
-{
- public:
- MVU(const arma::mat& dataIn);
-
- void Unfold(const size_t newDim,
- const size_t numNeighbors,
- arma::mat& outputCoordinates);
-
- private:
- const arma::mat& data;
-};
-
-}; // namespace mvu
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/mvu/mvu.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,56 @@
+/**
+ * @file mvu.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of Maximum Variance Unfolding. This file defines an MVU
+ * class as well as a class representing the objective function (a semidefinite
+ * program) which MVU seeks to minimize. Minimization is performed by the
+ * Augmented Lagrangian optimizer (which in turn uses the L-BFGS optimizer).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_MVU_MVU_HPP
+#define __MLPACK_METHODS_MVU_MVU_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace mvu {
+
+/**
+ * The MVU class is meant to provide a good abstraction for users. The dataset
+ * needs to be provided, as well as several parameters.
+ *
+ * - dataset
+ * - new dimensionality
+ */
+class MVU
+{
+ public:
+ MVU(const arma::mat& dataIn);
+
+ void Unfold(const size_t newDim,
+ const size_t numNeighbors,
+ arma::mat& outputCoordinates);
+
+ private:
+ const arma::mat& data;
+};
+
+}; // namespace mvu
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/mvu/mvu_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,84 +0,0 @@
-/**
- * @file mvu_main.cpp
- * @author Ryan Curtin
- *
- * Executable for MVU.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include "mvu.hpp"
-
-PROGRAM_INFO("Maximum Variance Unfolding (MVU)", "This program implements "
- "Maximum Variance Unfolding, a nonlinear dimensionality reduction "
- "technique. The method minimizes dimensionality by unfolding a manifold "
- "such that the distances to the nearest neighbors of each point are held "
- "constant.");
-
-PARAM_STRING_REQ("input_file", "Filename of input dataset.", "i");
-PARAM_INT_REQ("new_dim", "New dimensionality of dataset.", "d");
-
-PARAM_STRING("output_file", "Filename to save unfolded dataset to.", "o",
- "output.csv");
-PARAM_INT("num_neighbors", "Number of nearest neighbors to consider while "
- "unfolding.", "k", 5);
-
-using namespace mlpack;
-using namespace mlpack::mvu;
-using namespace mlpack::math;
-using namespace arma;
-using namespace std;
-
-int main(int argc, char **argv)
-{
- // Read from command line.
- CLI::ParseCommandLine(argc, argv);
-
- RandomSeed(time(NULL));
-
- // Load input dataset.
- const string inputFile = CLI::GetParam<string>("input_file");
- mat data;
- data::Load(inputFile, data, true);
-
- // Verify that the requested dimensionality is valid.
- const int newDim = CLI::GetParam<int>("new_dim");
- if (newDim <= 0 || newDim > (int) data.n_rows)
- {
- Log::Fatal << "Invalid new dimensionality (" << newDim << "). Must be "
- << "between 1 and the input dataset dimensionality (" << data.n_rows
- << ")." << std::endl;
- }
-
- // Verify that the number of neighbors is valid.
- const int numNeighbors = CLI::GetParam<int>("num_neighbors");
- if (numNeighbors <= 0 || numNeighbors > (int) data.n_cols)
- {
- Log::Fatal << "Invalid number of neighbors (" << numNeighbors << "). Must "
- << "be between 1 and the number of points in the input dataset ("
- << data.n_cols << ")." << std::endl;
- }
-
- // Now run MVU.
- MVU mvu(data);
-
- mat output;
- mvu.Unfold(newDim, numNeighbors, output);
-
- // Save results to file.
- const string outputFile = CLI::GetParam<string>("output_file");
- data::Save(outputFile, output, true);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/mvu/mvu_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/mvu/mvu_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,84 @@
+/**
+ * @file mvu_main.cpp
+ * @author Ryan Curtin
+ *
+ * Executable for MVU.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include "mvu.hpp"
+
+PROGRAM_INFO("Maximum Variance Unfolding (MVU)", "This program implements "
+ "Maximum Variance Unfolding, a nonlinear dimensionality reduction "
+ "technique. The method minimizes dimensionality by unfolding a manifold "
+ "such that the distances to the nearest neighbors of each point are held "
+ "constant.");
+
+PARAM_STRING_REQ("input_file", "Filename of input dataset.", "i");
+PARAM_INT_REQ("new_dim", "New dimensionality of dataset.", "d");
+
+PARAM_STRING("output_file", "Filename to save unfolded dataset to.", "o",
+ "output.csv");
+PARAM_INT("num_neighbors", "Number of nearest neighbors to consider while "
+ "unfolding.", "k", 5);
+
+using namespace mlpack;
+using namespace mlpack::mvu;
+using namespace mlpack::math;
+using namespace arma;
+using namespace std;
+
+int main(int argc, char **argv)
+{
+ // Read from command line.
+ CLI::ParseCommandLine(argc, argv);
+
+ RandomSeed(time(NULL));
+
+ // Load input dataset.
+ const string inputFile = CLI::GetParam<string>("input_file");
+ mat data;
+ data::Load(inputFile, data, true);
+
+ // Verify that the requested dimensionality is valid.
+ const int newDim = CLI::GetParam<int>("new_dim");
+ if (newDim <= 0 || newDim > (int) data.n_rows)
+ {
+ Log::Fatal << "Invalid new dimensionality (" << newDim << "). Must be "
+ << "between 1 and the input dataset dimensionality (" << data.n_rows
+ << ")." << std::endl;
+ }
+
+ // Verify that the number of neighbors is valid.
+ const int numNeighbors = CLI::GetParam<int>("num_neighbors");
+ if (numNeighbors <= 0 || numNeighbors > (int) data.n_cols)
+ {
+ Log::Fatal << "Invalid number of neighbors (" << numNeighbors << "). Must "
+ << "be between 1 and the number of points in the input dataset ("
+ << data.n_cols << ")." << std::endl;
+ }
+
+ // Now run MVU.
+ MVU mvu(data);
+
+ mat output;
+ mvu.Unfold(newDim, numNeighbors, output);
+
+ // Save results to file.
+ const string outputFile = CLI::GetParam<string>("output_file");
+ data::Save(outputFile, output, true);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,126 +0,0 @@
-/**
- * @file naive_bayes_classifier.hpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * A Naive Bayes Classifier which parametrically estimates the distribution of
- * the features. It is assumed that the features have been sampled from a
- * Gaussian PDF.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
-#define __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/methods/gmm/phi.hpp>
-
-namespace mlpack {
-namespace naive_bayes /** The Naive Bayes Classifier. */ {
-
-/**
- * The simple Naive Bayes classifier. This class trains on the data by
- * calculating the sample mean and variance of the features with respect to each
- * of the labels, and also the class probabilities. The class labels are
- * assumed to be positive integers (starting with 0), and are expected to be the
- * last row of the data input to the constructor.
- *
- * Mathematically, it computes P(X_i = x_i | Y = y_j) for each feature X_i for
- * each of the labels y_j. Alongwith this, it also computes the classs
- * probabilities P(Y = y_j).
- *
- * For classifying a data point (x_1, x_2, ..., x_n), it computes the following:
- * arg max_y(P(Y = y)*P(X_1 = x_1 | Y = y) * ... * P(X_n = x_n | Y = y))
- *
- * Example use:
- *
- * @code
- * extern arma::mat training_data, testing_data;
- * NaiveBayesClassifier<> nbc(training_data, 5);
- * arma::vec results;
- *
- * nbc.Classify(testing_data, results);
- * @endcode
- */
-template<typename MatType = arma::mat>
-class NaiveBayesClassifier
-{
- private:
- //! Sample mean for each class.
- MatType means;
-
- //! Sample variances for each class.
- MatType variances;
-
- //! Class probabilities.
- arma::vec probabilities;
-
- public:
- /**
- * Initializes the classifier as per the input and then trains it by
- * calculating the sample mean and variances. The input data is expected to
- * have integer labels as the last row (starting with 0 and not greater than
- * the number of classes).
- *
- * Example use:
- * @code
- * extern arma::mat training_data, testing_data;
- * NaiveBayesClassifier nbc(training_data, 5);
- * @endcode
- *
- * @param data Sample data points; the last row should be labels.
- * @param classes Number of classes in this classifier.
- */
- NaiveBayesClassifier(const MatType& data, const size_t classes);
-
- /**
- * Given a bunch of data points, this function evaluates the class of each of
- * those data points, and puts it in the vector 'results'.
- *
- * @code
- * arma::mat test_data; // each column is a test point
- * arma::Col<size_t> results;
- * ...
- * nbc.Classify(test_data, &results);
- * @endcode
- *
- * @param data List of data points.
- * @param results Vector that class predictions will be placed into.
- */
- void Classify(const MatType& data, arma::Col<size_t>& results);
-
- //! Get the sample means for each class.
- const MatType& Means() const { return means; }
- //! Modify the sample means for each class.
- MatType& Means() { return means; }
-
- //! Get the sample variances for each class.
- const MatType& Variances() const { return variances; }
- //! Modify the sample variances for each class.
- MatType& Variances() { return variances; }
-
- //! Get the prior probabilities for each class.
- const arma::vec& Probabilities() const { return probabilities; }
- //! Modify the prior probabilities for each class.
- arma::vec& Probabilities() { return probabilities; }
-};
-
-}; // namespace naive_bayes
-}; // namespace mlpack
-
-// Include implementation.
-#include "naive_bayes_classifier_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,126 @@
+/**
+ * @file naive_bayes_classifier.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * A Naive Bayes Classifier which parametrically estimates the distribution of
+ * the features. It is assumed that the features have been sampled from a
+ * Gaussian PDF.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
+#define __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/gmm/phi.hpp>
+
+namespace mlpack {
+namespace naive_bayes /** The Naive Bayes Classifier. */ {
+
+/**
+ * The simple Naive Bayes classifier. This class trains on the data by
+ * calculating the sample mean and variance of the features with respect to each
+ * of the labels, and also the class probabilities. The class labels are
+ * assumed to be positive integers (starting with 0), and are expected to be the
+ * last row of the data input to the constructor.
+ *
+ * Mathematically, it computes P(X_i = x_i | Y = y_j) for each feature X_i for
+ * each of the labels y_j. Along with this, it also computes the class
+ * probabilities P(Y = y_j).
+ *
+ * For classifying a data point (x_1, x_2, ..., x_n), it computes the following:
+ * arg max_y(P(Y = y)*P(X_1 = x_1 | Y = y) * ... * P(X_n = x_n | Y = y))
+ *
+ * Example use:
+ *
+ * @code
+ * extern arma::mat training_data, testing_data;
+ * NaiveBayesClassifier<> nbc(training_data, 5);
+ * arma::Col<size_t> results;
+ *
+ * nbc.Classify(testing_data, results);
+ * @endcode
+ */
+template<typename MatType = arma::mat>
+class NaiveBayesClassifier
+{
+ private:
+ //! Sample mean for each class.
+ MatType means;
+
+ //! Sample variances for each class.
+ MatType variances;
+
+ //! Class probabilities.
+ arma::vec probabilities;
+
+ public:
+ /**
+ * Initializes the classifier as per the input and then trains it by
+ * calculating the sample mean and variances. The input data is expected to
+ * have integer labels as the last row (starting with 0 and not greater than
+ * the number of classes).
+ *
+ * Example use:
+ * @code
+ * extern arma::mat training_data, testing_data;
+ * NaiveBayesClassifier nbc(training_data, 5);
+ * @endcode
+ *
+ * @param data Sample data points; the last row should be labels.
+ * @param classes Number of classes in this classifier.
+ */
+ NaiveBayesClassifier(const MatType& data, const size_t classes);
+
+ /**
+ * Given a bunch of data points, this function evaluates the class of each of
+ * those data points, and puts it in the vector 'results'.
+ *
+ * @code
+ * arma::mat test_data; // each column is a test point
+ * arma::Col<size_t> results;
+ * ...
+ * nbc.Classify(test_data, results);
+ * @endcode
+ *
+ * @param data List of data points.
+ * @param results Vector that class predictions will be placed into.
+ */
+ void Classify(const MatType& data, arma::Col<size_t>& results);
+
+ //! Get the sample means for each class.
+ const MatType& Means() const { return means; }
+ //! Modify the sample means for each class.
+ MatType& Means() { return means; }
+
+ //! Get the sample variances for each class.
+ const MatType& Variances() const { return variances; }
+ //! Modify the sample variances for each class.
+ MatType& Variances() { return variances; }
+
+ //! Get the prior probabilities for each class.
+ const arma::vec& Probabilities() const { return probabilities; }
+ //! Modify the prior probabilities for each class.
+ arma::vec& Probabilities() { return probabilities; }
+};
+
+}; // namespace naive_bayes
+}; // namespace mlpack
+
+// Include implementation.
+#include "naive_bayes_classifier_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,116 +0,0 @@
-/**
- * @file simple_nbc_impl.hpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * A Naive Bayes Classifier which parametrically estimates the distribution of
- * the features. It is assumed that the features have been sampled from a
- * Gaussian PDF.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
-#define __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
-
-#include <mlpack/core.hpp>
-
-// In case it hasn't been included already.
-#include "naive_bayes_classifier.hpp"
-
-namespace mlpack {
-namespace naive_bayes {
-
-template<typename MatType>
-NaiveBayesClassifier<MatType>::NaiveBayesClassifier(const MatType& data,
- const size_t classes)
-{
- size_t dimensionality = data.n_rows - 1;
-
- // Update the variables according to the number of features and classes
- // present in the data.
- probabilities.set_size(classes);
- means.zeros(dimensionality, classes);
- variances.zeros(dimensionality, classes);
-
- Log::Info << "Training Naive Bayes classifier on " << data.n_cols
- << " examples with " << dimensionality << " features each." << std::endl;
-
- // Calculate the class probabilities as well as the sample mean and variance
- // for each of the features with respect to each of the labels.
- for (size_t j = 0; j < data.n_cols; ++j)
- {
- size_t label = (size_t) data(dimensionality, j);
- ++probabilities[label];
-
- means.col(label) += data(arma::span(0, dimensionality - 1), j);
- variances.col(label) += square(data(arma::span(0, dimensionality - 1), j));
- }
-
- for (size_t i = 0; i < classes; ++i)
- {
- variances.col(i) -= (square(means.col(i)) / probabilities[i]);
- means.col(i) /= probabilities[i];
- variances.col(i) /= (probabilities[i] - 1);
- }
-
- probabilities /= data.n_cols;
-}
-
-template<typename MatType>
-void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
- arma::Col<size_t>& results)
-{
- // Check that the number of features in the test data is same as in the
- // training data.
- Log::Assert(data.n_rows == means.n_rows);
-
- arma::vec probs(means.n_cols);
-
- results.zeros(data.n_cols);
-
- Log::Info << "Running Naive Bayes classifier on " << data.n_cols
- << " data points with " << data.n_rows << " features each." << std::endl;
-
- // Calculate the joint probability for each of the data points for each of the
- // means.n_cols.
-
- // Loop over every test case.
- for (size_t n = 0; n < data.n_cols; n++)
- {
- // Loop over every class.
- for (size_t i = 0; i < means.n_cols; i++)
- {
- // Use the log values to prevent floating point underflow.
- probs(i) = log(probabilities(i));
-
- // Loop over every feature.
- probs(i) += log(gmm::phi(data.unsafe_col(n), means.unsafe_col(i),
- diagmat(variances.unsafe_col(i))));
- }
-
- // Find the index of the maximum value in tmp_vals.
- arma::uword maxIndex = 0;
- probs.max(maxIndex);
-
- results[n] = maxIndex;
- }
-
- return;
-}
-
-}; // namespace naive_bayes
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,116 @@
+/**
+ * @file simple_nbc_impl.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * A Naive Bayes Classifier which parametrically estimates the distribution of
+ * the features. It is assumed that the features have been sampled from a
+ * Gaussian PDF.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
+#define __MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_IMPL_HPP
+
+#include <mlpack/core.hpp>
+
+// In case it hasn't been included already.
+#include "naive_bayes_classifier.hpp"
+
+namespace mlpack {
+namespace naive_bayes {
+
+template<typename MatType>
+NaiveBayesClassifier<MatType>::NaiveBayesClassifier(const MatType& data,
+ const size_t classes)
+{
+ size_t dimensionality = data.n_rows - 1;
+
+ // Update the variables according to the number of features and classes
+ // present in the data.
+ probabilities.set_size(classes);
+ means.zeros(dimensionality, classes);
+ variances.zeros(dimensionality, classes);
+
+ Log::Info << "Training Naive Bayes classifier on " << data.n_cols
+ << " examples with " << dimensionality << " features each." << std::endl;
+
+ // Calculate the class probabilities as well as the sample mean and variance
+ // for each of the features with respect to each of the labels.
+ for (size_t j = 0; j < data.n_cols; ++j)
+ {
+ size_t label = (size_t) data(dimensionality, j);
+ ++probabilities[label];
+
+ means.col(label) += data(arma::span(0, dimensionality - 1), j);
+ variances.col(label) += square(data(arma::span(0, dimensionality - 1), j));
+ }
+
+ for (size_t i = 0; i < classes; ++i)
+ {
+ variances.col(i) -= (square(means.col(i)) / probabilities[i]);
+ means.col(i) /= probabilities[i];
+ variances.col(i) /= (probabilities[i] - 1);
+ }
+
+ probabilities /= data.n_cols;
+}
+
+template<typename MatType>
+void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
+ arma::Col<size_t>& results)
+{
+ // Check that the number of features in the test data is same as in the
+ // training data.
+ Log::Assert(data.n_rows == means.n_rows);
+
+ arma::vec probs(means.n_cols);
+
+ results.zeros(data.n_cols);
+
+ Log::Info << "Running Naive Bayes classifier on " << data.n_cols
+ << " data points with " << data.n_rows << " features each." << std::endl;
+
+ // Calculate the joint probability for each of the data points for each of the
+ // means.n_cols.
+
+ // Loop over every test case.
+ for (size_t n = 0; n < data.n_cols; n++)
+ {
+ // Loop over every class.
+ for (size_t i = 0; i < means.n_cols; i++)
+ {
+ // Use the log values to prevent floating point underflow.
+ probs(i) = log(probabilities(i));
+
+ // Loop over every feature.
+ probs(i) += log(gmm::phi(data.unsafe_col(n), means.unsafe_col(i),
+ diagmat(variances.unsafe_col(i))));
+ }
+
+ // Find the index of the maximum value in tmp_vals.
+ arma::uword maxIndex = 0;
+ probs.max(maxIndex);
+
+ results[n] = maxIndex;
+ }
+
+ return;
+}
+
+}; // namespace naive_bayes
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/nbc_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/naive_bayes/nbc_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/nbc_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,105 +0,0 @@
-/**
- * @author Parikshit Ram (pram at cc.gatech.edu)
- * @file nbc_main.cpp
- *
- * This program runs the Simple Naive Bayes Classifier.
- *
- * This classifier does parametric naive bayes classification assuming that the
- * features are sampled from a Gaussian distribution.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "naive_bayes_classifier.hpp"
-
-PROGRAM_INFO("Parametric Naive Bayes Classifier",
- "This program trains the Naive Bayes classifier on the given labeled "
- "training set and then uses the trained classifier to classify the points "
- "in the given test set.\n"
- "\n"
- "Labels are expected to be the last row of the training set (--train_file),"
- " but labels can also be passed in separately as their own file "
- "(--labels_file).");
-
-PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
-PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
-
-PARAM_STRING("labels_file", "A file containing labels for the training set.",
- "l", "");
-PARAM_STRING("output", "The file in which the output of the test would "
- "be written, defaults to 'output.csv')", "o", "output.csv");
-
-using namespace mlpack;
-using namespace mlpack::naive_bayes;
-using namespace std;
-using namespace arma;
-
-int main(int argc, char* argv[])
-{
- CLI::ParseCommandLine(argc, argv);
-
- // Check input parameters.
- const string trainingDataFilename = CLI::GetParam<string>("train_file");
- mat trainingData;
- data::Load(trainingDataFilename.c_str(), trainingData, true);
-
- // Did the user pass in labels?
- const string labelsFilename = CLI::GetParam<string>("labels_file");
- if (labelsFilename != "")
- {
- // Load labels.
- arma::mat labels;
- data::Load(labelsFilename.c_str(), labels, true);
-
- // Not incredibly efficient...
- if (labels.n_rows == 1)
- trainingData.insert_rows(trainingData.n_rows, trans(labels));
- else if (labels.n_cols == 1)
- trainingData.insert_rows(trainingData.n_rows, labels);
- else
- Log::Fatal << "Labels must have only one column or row!" << endl;
- }
-
- const string testingDataFilename = CLI::GetParam<std::string>("test_file");
- mat testingData;
- data::Load(testingDataFilename.c_str(), testingData, true);
-
- if (testingData.n_rows != trainingData.n_rows - 1)
- Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
- << "must be the same as training data (" << trainingData.n_rows - 1
- << ")!" << std::endl;
-
- // Calculate number of classes.
- size_t classes = (size_t) max(trainingData.row(trainingData.n_rows - 1)) + 1;
-
- // Create and train the classifier.
- Timer::Start("training");
- NaiveBayesClassifier<> nbc(trainingData, classes);
- Timer::Stop("training");
-
- // Timing the running of the Naive Bayes Classifier.
- arma::Col<size_t> results;
- Timer::Start("testing");
- nbc.Classify(testingData, results);
- Timer::Stop("testing");
-
- // Output results.
- const string outputFilename = CLI::GetParam<string>("output");
- data::Save(outputFilename.c_str(), results, true);
-
- return 0;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/nbc_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/naive_bayes/nbc_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/nbc_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/naive_bayes/nbc_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,105 @@
+/**
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ * @file nbc_main.cpp
+ *
+ * This program runs the Simple Naive Bayes Classifier.
+ *
+ * This classifier does parametric naive bayes classification assuming that the
+ * features are sampled from a Gaussian distribution.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "naive_bayes_classifier.hpp"
+
+PROGRAM_INFO("Parametric Naive Bayes Classifier",
+ "This program trains the Naive Bayes classifier on the given labeled "
+ "training set and then uses the trained classifier to classify the points "
+ "in the given test set.\n"
+ "\n"
+ "Labels are expected to be the last row of the training set (--train_file),"
+ " but labels can also be passed in separately as their own file "
+ "(--labels_file).");
+
+PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
+PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
+
+PARAM_STRING("labels_file", "A file containing labels for the training set.",
+ "l", "");
+PARAM_STRING("output", "The file in which the output of the test would "
+ "be written, defaults to 'output.csv')", "o", "output.csv");
+
+using namespace mlpack;
+using namespace mlpack::naive_bayes;
+using namespace std;
+using namespace arma;
+
+int main(int argc, char* argv[])
+{
+ CLI::ParseCommandLine(argc, argv);
+
+ // Check input parameters.
+ const string trainingDataFilename = CLI::GetParam<string>("train_file");
+ mat trainingData;
+ data::Load(trainingDataFilename.c_str(), trainingData, true);
+
+ // Did the user pass in labels?
+ const string labelsFilename = CLI::GetParam<string>("labels_file");
+ if (labelsFilename != "")
+ {
+ // Load labels.
+ arma::mat labels;
+ data::Load(labelsFilename.c_str(), labels, true);
+
+ // Not incredibly efficient...
+ if (labels.n_rows == 1)
+ trainingData.insert_rows(trainingData.n_rows, trans(labels));
+ else if (labels.n_cols == 1)
+ trainingData.insert_rows(trainingData.n_rows, labels);
+ else
+ Log::Fatal << "Labels must have only one column or row!" << endl;
+ }
+
+ const string testingDataFilename = CLI::GetParam<std::string>("test_file");
+ mat testingData;
+ data::Load(testingDataFilename.c_str(), testingData, true);
+
+ if (testingData.n_rows != trainingData.n_rows - 1)
+ Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
+ << "must be the same as training data (" << trainingData.n_rows - 1
+ << ")!" << std::endl;
+
+ // Calculate number of classes.
+ size_t classes = (size_t) max(trainingData.row(trainingData.n_rows - 1)) + 1;
+
+ // Create and train the classifier.
+ Timer::Start("training");
+ NaiveBayesClassifier<> nbc(trainingData, classes);
+ Timer::Stop("training");
+
+ // Timing the running of the Naive Bayes Classifier.
+ arma::Col<size_t> results;
+ Timer::Start("testing");
+ nbc.Classify(testingData, results);
+ Timer::Stop("testing");
+
+ // Output results.
+ const string outputFilename = CLI::GetParam<string>("output");
+ data::Save(outputFilename.c_str(), results, true);
+
+ return 0;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,123 +0,0 @@
-/**
- * @file nca.hpp
- * @author Ryan Curtin
- *
- * Declaration of NCA class (Neighborhood Components Analysis).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NCA_NCA_HPP
-#define __MLPACK_METHODS_NCA_NCA_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-#include <mlpack/core/optimizers/sgd/sgd.hpp>
-
-#include "nca_softmax_error_function.hpp"
-
-namespace mlpack {
-namespace nca /** Neighborhood Components Analysis. */ {
-
-/**
- * An implementation of Neighborhood Components Analysis, both a linear
- * dimensionality reduction technique and a distance learning technique. The
- * method seeks to improve k-nearest-neighbor classification on a dataset by
- * scaling the dimensions. The method is nonparametric, and does not require a
- * value of k. It works by using stochastic ("soft") neighbor assignments and
- * using optimization techniques over the gradient of the accuracy of the
- * neighbor assignments.
- *
- * For more details, see the following published paper:
- *
- * @code
- * @inproceedings{Goldberger2004,
- * author = {Goldberger, Jacob and Roweis, Sam and Hinton, Geoff and
- * Salakhutdinov, Ruslan},
- * booktitle = {Advances in Neural Information Processing Systems 17},
- * pages = {513--520},
- * publisher = {MIT Press},
- * title = {{Neighbourhood Components Analysis}},
- * year = {2004}
- * }
- * @endcode
- */
-template<typename MetricType = metric::SquaredEuclideanDistance,
- template<typename> class OptimizerType = optimization::SGD>
-class NCA
-{
- public:
- /**
- * Construct the Neighborhood Components Analysis object. This simply stores
- * the reference to the dataset and labels as well as the parameters for
- * optimization before the actual optimization is performed.
- *
- * @param dataset Input dataset.
- * @param labels Input dataset labels.
- * @param stepSize Step size for stochastic gradient descent.
- * @param maxIterations Maximum iterations for stochastic gradient descent.
- * @param tolerance Tolerance for termination of stochastic gradient descent.
- * @param shuffle Whether or not to shuffle the dataset during SGD.
- * @param metric Instantiated metric to use.
- */
- NCA(const arma::mat& dataset,
- const arma::uvec& labels,
- MetricType metric = MetricType());
-
- /**
- * Perform Neighborhood Components Analysis. The output distance learning
- * matrix is written into the passed reference. If LearnDistance() is called
- * with an outputMatrix which has the correct size (dataset.n_rows x
- * dataset.n_rows), that matrix will be used as the starting point for
- * optimization.
- *
- * @param output_matrix Covariance matrix of Mahalanobis distance.
- */
- void LearnDistance(arma::mat& outputMatrix);
-
- //! Get the dataset reference.
- const arma::mat& Dataset() const { return dataset; }
- //! Get the labels reference.
- const arma::uvec& Labels() const { return labels; }
-
- //! Get the optimizer.
- const OptimizerType<SoftmaxErrorFunction<MetricType> >& Optimizer() const
- { return optimizer; }
- OptimizerType<SoftmaxErrorFunction<MetricType> >& Optimizer()
- { return optimizer; }
-
- private:
- //! Dataset reference.
- const arma::mat& dataset;
- //! Labels reference.
- const arma::uvec& labels;
-
- //! Metric to be used.
- MetricType metric;
-
- //! The function to optimize.
- SoftmaxErrorFunction<MetricType> errorFunction;
-
- //! The optimizer to use.
- OptimizerType<SoftmaxErrorFunction<MetricType> > optimizer;
-};
-
-}; // namespace nca
-}; // namespace mlpack
-
-// Include the implementation.
-#include "nca_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,123 @@
+/**
+ * @file nca.hpp
+ * @author Ryan Curtin
+ *
+ * Declaration of NCA class (Neighborhood Components Analysis).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NCA_NCA_HPP
+#define __MLPACK_METHODS_NCA_NCA_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/optimizers/sgd/sgd.hpp>
+
+#include "nca_softmax_error_function.hpp"
+
+namespace mlpack {
+namespace nca /** Neighborhood Components Analysis. */ {
+
+/**
+ * An implementation of Neighborhood Components Analysis, both a linear
+ * dimensionality reduction technique and a distance learning technique. The
+ * method seeks to improve k-nearest-neighbor classification on a dataset by
+ * scaling the dimensions. The method is nonparametric, and does not require a
+ * value of k. It works by using stochastic ("soft") neighbor assignments and
+ * using optimization techniques over the gradient of the accuracy of the
+ * neighbor assignments.
+ *
+ * For more details, see the following published paper:
+ *
+ * @code
+ * @inproceedings{Goldberger2004,
+ * author = {Goldberger, Jacob and Roweis, Sam and Hinton, Geoff and
+ * Salakhutdinov, Ruslan},
+ * booktitle = {Advances in Neural Information Processing Systems 17},
+ * pages = {513--520},
+ * publisher = {MIT Press},
+ * title = {{Neighbourhood Components Analysis}},
+ * year = {2004}
+ * }
+ * @endcode
+ */
+template<typename MetricType = metric::SquaredEuclideanDistance,
+ template<typename> class OptimizerType = optimization::SGD>
+class NCA
+{
+ public:
+ /**
+ * Construct the Neighborhood Components Analysis object. This simply stores
+ * the reference to the dataset and labels as well as the parameters for
+ * optimization before the actual optimization is performed.
+ *
+ * @param dataset Input dataset.
+ * @param labels Input dataset labels.
+ * @param stepSize Step size for stochastic gradient descent.
+ * @param maxIterations Maximum iterations for stochastic gradient descent.
+ * @param tolerance Tolerance for termination of stochastic gradient descent.
+ * @param shuffle Whether or not to shuffle the dataset during SGD.
+ * @param metric Instantiated metric to use.
+ */
+ NCA(const arma::mat& dataset,
+ const arma::uvec& labels,
+ MetricType metric = MetricType());
+
+ /**
+ * Perform Neighborhood Components Analysis. The output distance learning
+ * matrix is written into the passed reference. If LearnDistance() is called
+ * with an outputMatrix which has the correct size (dataset.n_rows x
+ * dataset.n_rows), that matrix will be used as the starting point for
+ * optimization.
+ *
+ * @param output_matrix Covariance matrix of Mahalanobis distance.
+ */
+ void LearnDistance(arma::mat& outputMatrix);
+
+ //! Get the dataset reference.
+ const arma::mat& Dataset() const { return dataset; }
+ //! Get the labels reference.
+ const arma::uvec& Labels() const { return labels; }
+
+ //! Get the optimizer.
+ const OptimizerType<SoftmaxErrorFunction<MetricType> >& Optimizer() const
+ { return optimizer; }
+ OptimizerType<SoftmaxErrorFunction<MetricType> >& Optimizer()
+ { return optimizer; }
+
+ private:
+ //! Dataset reference.
+ const arma::mat& dataset;
+ //! Labels reference.
+ const arma::uvec& labels;
+
+ //! Metric to be used.
+ MetricType metric;
+
+ //! The function to optimize.
+ SoftmaxErrorFunction<MetricType> errorFunction;
+
+ //! The optimizer to use.
+ OptimizerType<SoftmaxErrorFunction<MetricType> > optimizer;
+};
+
+}; // namespace nca
+}; // namespace mlpack
+
+// Include the implementation.
+#include "nca_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,61 +0,0 @@
-/**
- * @file nca_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of templated NCA class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NCA_NCA_IMPL_HPP
-#define __MLPACK_METHODS_NCA_NCA_IMPL_HPP
-
-// In case it was not already included.
-#include "nca.hpp"
-
-namespace mlpack {
-namespace nca {
-
-// Just set the internal matrix reference.
-template<typename MetricType, template<typename> class OptimizerType>
-NCA<MetricType, OptimizerType>::NCA(const arma::mat& dataset,
- const arma::uvec& labels,
- MetricType metric) :
- dataset(dataset),
- labels(labels),
- metric(metric),
- errorFunction(dataset, labels, metric),
- optimizer(OptimizerType<SoftmaxErrorFunction<MetricType> >(errorFunction))
-{ /* Nothing to do. */ }
-
-template<typename MetricType, template<typename> class OptimizerType>
-void NCA<MetricType, OptimizerType>::LearnDistance(arma::mat& outputMatrix)
-{
- // See if we were passed an initialized matrix.
- if ((outputMatrix.n_rows != dataset.n_rows) ||
- (outputMatrix.n_cols != dataset.n_rows))
- outputMatrix.eye(dataset.n_rows, dataset.n_rows);
-
- Timer::Start("nca_sgd_optimization");
-
- optimizer.Optimize(outputMatrix);
-
- Timer::Stop("nca_sgd_optimization");
-}
-
-}; // namespace nca
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,61 @@
+/**
+ * @file nca_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of templated NCA class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NCA_NCA_IMPL_HPP
+#define __MLPACK_METHODS_NCA_NCA_IMPL_HPP
+
+// In case it was not already included.
+#include "nca.hpp"
+
+namespace mlpack {
+namespace nca {
+
+// Just set the internal matrix reference.
+template<typename MetricType, template<typename> class OptimizerType>
+NCA<MetricType, OptimizerType>::NCA(const arma::mat& dataset,
+ const arma::uvec& labels,
+ MetricType metric) :
+ dataset(dataset),
+ labels(labels),
+ metric(metric),
+ errorFunction(dataset, labels, metric),
+ optimizer(OptimizerType<SoftmaxErrorFunction<MetricType> >(errorFunction))
+{ /* Nothing to do. */ }
+
+template<typename MetricType, template<typename> class OptimizerType>
+void NCA<MetricType, OptimizerType>::LearnDistance(arma::mat& outputMatrix)
+{
+ // See if we were passed an initialized matrix.
+ if ((outputMatrix.n_rows != dataset.n_rows) ||
+ (outputMatrix.n_cols != dataset.n_rows))
+ outputMatrix.eye(dataset.n_rows, dataset.n_rows);
+
+ Timer::Start("nca_sgd_optimization");
+
+ optimizer.Optimize(outputMatrix);
+
+ Timer::Stop("nca_sgd_optimization");
+}
+
+}; // namespace nca
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,260 +0,0 @@
-/**
- * @file nca_main.cpp
- * @author Ryan Curtin
- *
- * Executable for Neighborhood Components Analysis.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include "nca.hpp"
-
-#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
-
-// Define parameters.
-PROGRAM_INFO("Neighborhood Components Analysis (NCA)",
- "This program implements Neighborhood Components Analysis, both a linear "
- "dimensionality reduction technique and a distance learning technique. The"
- " method seeks to improve k-nearest-neighbor classification on a dataset "
- "by scaling the dimensions. The method is nonparametric, and does not "
- "require a value of k. It works by using stochastic (\"soft\") neighbor "
- "assignments and using optimization techniques over the gradient of the "
- "accuracy of the neighbor assignments.\n"
- "\n"
- "To work, this algorithm needs labeled data. It can be given as the last "
- "row of the input dataset (--input_file), or alternatively in a separate "
- "file (--labels_file).\n"
- "\n"
- "This implementation of NCA uses either stochastic gradient descent or the "
- "L_BFGS optimizer. Both of these optimizers do not guarantee global "
- "convergence for a nonconvex objective function (NCA's objective function "
- "is nonconvex), so the final results could depend on the random seed or "
- "other optimizer parameters.\n"
- "\n"
- "Stochastic gradient descent, specified by --optimizer \"sgd\", depends "
- "primarily on two parameters: the step size (--step_size) and the maximum "
- "number of iterations (--max_iterations). In addition, a normalized "
- "starting point can be used (--normalize), which is necessary if many "
- "warnings of the form 'Denominator of p_i is 0!' are given. Tuning the "
- "step size can be a tedious affair. In general, the step size is too large"
- " if the objective is not mostly uniformly decreasing, or if zero-valued "
- "denominator warnings are being issued. The step size is too small if the "
- "objective is changing very slowly. Setting the termination condition can "
- "be done easily once a good step size parameter is found; either increase "
- "the maximum iterations to a large number and allow SGD to find a minimum, "
- "or set the maximum iterations to 0 (allowing infinite iterations) and set "
- "the tolerance (--tolerance) to define the maximum allowed difference "
- "between objectives for SGD to terminate. Be careful -- setting the "
- "tolerance instead of the maximum iterations can take a very long time and "
- "may actually never converge due to the properties of the SGD optimizer.\n"
- "\n"
- "The L-BFGS optimizer, specified by --optimizer \"lbfgs\", uses a "
- "back-tracking line search algorithm to minimize a function. The "
- "following parameters are used by L-BFGS: --num_basis (specifies the number"
- " of memory points used by L-BFGS), --max_iterations, --armijo_constant, "
- "--wolfe, --tolerance (the optimization is terminated when the gradient "
- "norm is below this value), --max_line_search_trials, --min_step and "
- "--max_step (which both refer to the line search routine). For more "
- "details on the L-BFGS optimizer, consult either the MLPACK L-BFGS "
- "documentation (in lbfgs.hpp) or the vast set of published literature on "
- "L-BFGS.\n"
- "\n"
- "By default, the SGD optimizer is used.");
-
-PARAM_STRING_REQ("input_file", "Input dataset to run NCA on.", "i");
-PARAM_STRING_REQ("output_file", "Output file for learned distance matrix.",
- "o");
-PARAM_STRING("labels_file", "File of labels for input dataset.", "l", "");
-PARAM_STRING("optimizer", "Optimizer to use; \"sgd\" or \"lbfgs\".", "O", "");
-
-PARAM_FLAG("normalize", "Use a normalized starting point for optimization. This"
- " is useful for when points are far apart, or when SGD is returning NaN.",
- "N");
-
-PARAM_INT("max_iterations", "Maximum number of iterations for SGD or L-BFGS (0 "
- "indicates no limit).", "n", 500000);
-PARAM_DOUBLE("tolerance", "Maximum tolerance for termination of SGD or L-BFGS.",
- "t", 1e-7);
-
-PARAM_DOUBLE("step_size", "Step size for stochastic gradient descent (alpha).",
- "a", 0.01);
-PARAM_FLAG("linear_scan", "Don't shuffle the order in which data points are "
- "visited for SGD.", "L");
-
-PARAM_INT("num_basis", "Number of memory points to be stored for L-BFGS.", "N",
- 5);
-PARAM_DOUBLE("armijo_constant", "Armijo constant for L-BFGS.", "A", 1e-4);
-PARAM_DOUBLE("wolfe", "Wolfe condition parameter for L-BFGS.", "w", 0.9);
-PARAM_INT("max_line_search_trials", "Maximum number of line search trials for "
- "L-BFGS.", "L", 50);
-PARAM_DOUBLE("min_step", "Minimum step of line search for L-BFGS.", "m", 1e-20);
-PARAM_DOUBLE("max_step", "Maximum step of line search for L-BFGS.", "M", 1e20);
-
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-
-
-using namespace mlpack;
-using namespace mlpack::nca;
-using namespace mlpack::metric;
-using namespace mlpack::optimization;
-using namespace std;
-
-int main(int argc, char* argv[])
-{
- // Parse command line.
- CLI::ParseCommandLine(argc, argv);
-
- if (CLI::GetParam<int>("seed") != 0)
- math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- math::RandomSeed((size_t) std::time(NULL));
-
- const string inputFile = CLI::GetParam<string>("input_file");
- const string labelsFile = CLI::GetParam<string>("labels_file");
- const string outputFile = CLI::GetParam<string>("output_file");
-
- const string optimizerType = CLI::GetParam<string>("optimizer");
-
- if ((optimizerType != "sgd") && (optimizerType != "lbfgs"))
- {
- Log::Fatal << "Optimizer type '" << optimizerType << "' unknown; must be "
- << "'sgd' or 'lbfgs'!" << std::endl;
- }
-
- // Warn on unused parameters.
- if (optimizerType == "sgd")
- {
- if (CLI::HasParam("num_basis"))
- Log::Warn << "Parameter --num_basis ignored (not using 'lbfgs' "
- << "optimizer)." << std::endl;
-
- if (CLI::HasParam("armijo_constant"))
- Log::Warn << "Parameter --armijo_constant ignored (not using 'lbfgs' "
- << "optimizer)." << std::endl;
-
- if (CLI::HasParam("wolfe"))
- Log::Warn << "Parameter --wolfe ignored (not using 'lbfgs' optimizer).\n";
-
- if (CLI::HasParam("max_line_search_trials"))
- Log::Warn << "Parameter --max_line_search_trials ignored (not using "
- << "'lbfgs' optimizer." << std::endl;
-
- if (CLI::HasParam("min_step"))
- Log::Warn << "Parameter --min_step ignored (not using 'lbfgs' optimizer)."
- << std::endl;
-
- if (CLI::HasParam("max_step"))
- Log::Warn << "Parameter --max_step ignored (not using 'lbfgs' optimizer)."
- << std::endl;
- }
- else if (optimizerType == "lbfgs")
- {
- if (CLI::HasParam("step_size"))
- Log::Warn << "Parameter --step_size ignored (not using 'sgd' optimizer)."
- << std::endl;
-
- if (CLI::HasParam("linear_scan"))
- Log::Warn << "Parameter --linear_scan ignored (not using 'sgd' "
- << "optimizer)." << std::endl;
- }
-
- const double stepSize = CLI::GetParam<double>("step_size");
- const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
- const double tolerance = CLI::GetParam<double>("tolerance");
- const bool normalize = CLI::HasParam("normalize");
- const bool shuffle = !CLI::HasParam("linear_scan");
- const int numBasis = CLI::GetParam<int>("num_basis");
- const double armijoConstant = CLI::GetParam<double>("armijo_constant");
- const double wolfe = CLI::GetParam<double>("wolfe");
- const int maxLineSearchTrials = CLI::GetParam<int>("max_line_search_trials");
- const double minStep = CLI::GetParam<double>("min_step");
- const double maxStep = CLI::GetParam<double>("max_step");
-
- // Load data.
- arma::mat data;
- data::Load(inputFile.c_str(), data, true);
-
- // Do we want to load labels separately?
- arma::umat labels(data.n_cols, 1);
- if (labelsFile != "")
- {
- data::Load(labelsFile.c_str(), labels, true);
-
- if (labels.n_rows == 1)
- labels = trans(labels);
-
- if (labels.n_cols > 1)
- Log::Fatal << "Labels must have only one column or row!" << endl;
- }
- else
- {
- for (size_t i = 0; i < data.n_cols; i++)
- labels[i] = (int) data(data.n_rows - 1, i);
-
- data.shed_row(data.n_rows - 1);
- }
-
- arma::mat distance;
-
- // Normalize the data, if necessary.
- if (normalize)
- {
- // Find the minimum and maximum values for each dimension.
- arma::vec ranges = arma::max(data, 1) - arma::min(data, 1);
- for (size_t d = 0; d < ranges.n_elem; ++d)
- if (ranges[d] == 0.0)
- ranges[d] = 1; // A range of 0 produces NaN later on.
-
- distance = diagmat(1.0 / ranges);
- Log::Info << "Using normalized starting point for optimization."
- << std::endl;
- }
- else
- {
- distance.eye();
- }
-
- // Now create the NCA object and run the optimization.
- if (optimizerType == "sgd")
- {
- NCA<LMetric<2> > nca(data, labels.unsafe_col(0));
- nca.Optimizer().StepSize() = stepSize;
- nca.Optimizer().MaxIterations() = maxIterations;
- nca.Optimizer().Tolerance() = tolerance;
- nca.Optimizer().Shuffle() = shuffle;
-
- nca.LearnDistance(distance);
- }
- else if (optimizerType == "lbfgs")
- {
- NCA<LMetric<2>, L_BFGS> nca(data, labels.unsafe_col(0));
- nca.Optimizer().NumBasis() = numBasis;
- nca.Optimizer().MaxIterations() = maxIterations;
- nca.Optimizer().ArmijoConstant() = armijoConstant;
- nca.Optimizer().Wolfe() = wolfe;
- nca.Optimizer().MinGradientNorm() = tolerance;
- nca.Optimizer().MaxLineSearchTrials() = maxLineSearchTrials;
- nca.Optimizer().MinStep() = minStep;
- nca.Optimizer().MaxStep() = maxStep;
-
- nca.LearnDistance(distance);
- }
-
- // Save the output.
- data::Save(CLI::GetParam<string>("output_file").c_str(), distance, true);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,260 @@
+/**
+ * @file nca_main.cpp
+ * @author Ryan Curtin
+ *
+ * Executable for Neighborhood Components Analysis.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include "nca.hpp"
+
+#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+
+// Define parameters.
+PROGRAM_INFO("Neighborhood Components Analysis (NCA)",
+ "This program implements Neighborhood Components Analysis, both a linear "
+ "dimensionality reduction technique and a distance learning technique. The"
+ " method seeks to improve k-nearest-neighbor classification on a dataset "
+ "by scaling the dimensions. The method is nonparametric, and does not "
+ "require a value of k. It works by using stochastic (\"soft\") neighbor "
+ "assignments and using optimization techniques over the gradient of the "
+ "accuracy of the neighbor assignments.\n"
+ "\n"
+ "To work, this algorithm needs labeled data. It can be given as the last "
+ "row of the input dataset (--input_file), or alternatively in a separate "
+ "file (--labels_file).\n"
+ "\n"
+ "This implementation of NCA uses either stochastic gradient descent or the "
+ "L_BFGS optimizer. Both of these optimizers do not guarantee global "
+ "convergence for a nonconvex objective function (NCA's objective function "
+ "is nonconvex), so the final results could depend on the random seed or "
+ "other optimizer parameters.\n"
+ "\n"
+ "Stochastic gradient descent, specified by --optimizer \"sgd\", depends "
+ "primarily on two parameters: the step size (--step_size) and the maximum "
+ "number of iterations (--max_iterations). In addition, a normalized "
+ "starting point can be used (--normalize), which is necessary if many "
+ "warnings of the form 'Denominator of p_i is 0!' are given. Tuning the "
+ "step size can be a tedious affair. In general, the step size is too large"
+ " if the objective is not mostly uniformly decreasing, or if zero-valued "
+ "denominator warnings are being issued. The step size is too small if the "
+ "objective is changing very slowly. Setting the termination condition can "
+ "be done easily once a good step size parameter is found; either increase "
+ "the maximum iterations to a large number and allow SGD to find a minimum, "
+ "or set the maximum iterations to 0 (allowing infinite iterations) and set "
+ "the tolerance (--tolerance) to define the maximum allowed difference "
+ "between objectives for SGD to terminate. Be careful -- setting the "
+ "tolerance instead of the maximum iterations can take a very long time and "
+ "may actually never converge due to the properties of the SGD optimizer.\n"
+ "\n"
+ "The L-BFGS optimizer, specified by --optimizer \"lbfgs\", uses a "
+ "back-tracking line search algorithm to minimize a function. The "
+ "following parameters are used by L-BFGS: --num_basis (specifies the number"
+ " of memory points used by L-BFGS), --max_iterations, --armijo_constant, "
+ "--wolfe, --tolerance (the optimization is terminated when the gradient "
+ "norm is below this value), --max_line_search_trials, --min_step and "
+ "--max_step (which both refer to the line search routine). For more "
+ "details on the L-BFGS optimizer, consult either the MLPACK L-BFGS "
+ "documentation (in lbfgs.hpp) or the vast set of published literature on "
+ "L-BFGS.\n"
+ "\n"
+ "By default, the SGD optimizer is used.");
+
+PARAM_STRING_REQ("input_file", "Input dataset to run NCA on.", "i");
+PARAM_STRING_REQ("output_file", "Output file for learned distance matrix.",
+ "o");
+PARAM_STRING("labels_file", "File of labels for input dataset.", "l", "");
+PARAM_STRING("optimizer", "Optimizer to use; \"sgd\" or \"lbfgs\".", "O", "");
+
+PARAM_FLAG("normalize", "Use a normalized starting point for optimization. This"
+ " is useful for when points are far apart, or when SGD is returning NaN.",
+ "N");
+
+PARAM_INT("max_iterations", "Maximum number of iterations for SGD or L-BFGS (0 "
+ "indicates no limit).", "n", 500000);
+PARAM_DOUBLE("tolerance", "Maximum tolerance for termination of SGD or L-BFGS.",
+ "t", 1e-7);
+
+PARAM_DOUBLE("step_size", "Step size for stochastic gradient descent (alpha).",
+ "a", 0.01);
+PARAM_FLAG("linear_scan", "Don't shuffle the order in which data points are "
+ "visited for SGD.", "L");
+
+PARAM_INT("num_basis", "Number of memory points to be stored for L-BFGS.", "N",
+ 5);
+PARAM_DOUBLE("armijo_constant", "Armijo constant for L-BFGS.", "A", 1e-4);
+PARAM_DOUBLE("wolfe", "Wolfe condition parameter for L-BFGS.", "w", 0.9);
+PARAM_INT("max_line_search_trials", "Maximum number of line search trials for "
+ "L-BFGS.", "L", 50);
+PARAM_DOUBLE("min_step", "Minimum step of line search for L-BFGS.", "m", 1e-20);
+PARAM_DOUBLE("max_step", "Maximum step of line search for L-BFGS.", "M", 1e20);
+
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+
+
+using namespace mlpack;
+using namespace mlpack::nca;
+using namespace mlpack::metric;
+using namespace mlpack::optimization;
+using namespace std;
+
+int main(int argc, char* argv[])
+{
+ // Parse command line.
+ CLI::ParseCommandLine(argc, argv);
+
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) std::time(NULL));
+
+ const string inputFile = CLI::GetParam<string>("input_file");
+ const string labelsFile = CLI::GetParam<string>("labels_file");
+ const string outputFile = CLI::GetParam<string>("output_file");
+
+ const string optimizerType = CLI::GetParam<string>("optimizer");
+
+ if ((optimizerType != "sgd") && (optimizerType != "lbfgs"))
+ {
+ Log::Fatal << "Optimizer type '" << optimizerType << "' unknown; must be "
+ << "'sgd' or 'lbfgs'!" << std::endl;
+ }
+
+ // Warn on unused parameters.
+ if (optimizerType == "sgd")
+ {
+ if (CLI::HasParam("num_basis"))
+ Log::Warn << "Parameter --num_basis ignored (not using 'lbfgs' "
+ << "optimizer)." << std::endl;
+
+ if (CLI::HasParam("armijo_constant"))
+ Log::Warn << "Parameter --armijo_constant ignored (not using 'lbfgs' "
+ << "optimizer)." << std::endl;
+
+ if (CLI::HasParam("wolfe"))
+ Log::Warn << "Parameter --wolfe ignored (not using 'lbfgs' optimizer).\n";
+
+ if (CLI::HasParam("max_line_search_trials"))
+ Log::Warn << "Parameter --max_line_search_trials ignored (not using "
+ << "'lbfgs' optimizer." << std::endl;
+
+ if (CLI::HasParam("min_step"))
+ Log::Warn << "Parameter --min_step ignored (not using 'lbfgs' optimizer)."
+ << std::endl;
+
+ if (CLI::HasParam("max_step"))
+ Log::Warn << "Parameter --max_step ignored (not using 'lbfgs' optimizer)."
+ << std::endl;
+ }
+ else if (optimizerType == "lbfgs")
+ {
+ if (CLI::HasParam("step_size"))
+ Log::Warn << "Parameter --step_size ignored (not using 'sgd' optimizer)."
+ << std::endl;
+
+ if (CLI::HasParam("linear_scan"))
+ Log::Warn << "Parameter --linear_scan ignored (not using 'sgd' "
+ << "optimizer)." << std::endl;
+ }
+
+ const double stepSize = CLI::GetParam<double>("step_size");
+ const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
+ const double tolerance = CLI::GetParam<double>("tolerance");
+ const bool normalize = CLI::HasParam("normalize");
+ const bool shuffle = !CLI::HasParam("linear_scan");
+ const int numBasis = CLI::GetParam<int>("num_basis");
+ const double armijoConstant = CLI::GetParam<double>("armijo_constant");
+ const double wolfe = CLI::GetParam<double>("wolfe");
+ const int maxLineSearchTrials = CLI::GetParam<int>("max_line_search_trials");
+ const double minStep = CLI::GetParam<double>("min_step");
+ const double maxStep = CLI::GetParam<double>("max_step");
+
+ // Load data.
+ arma::mat data;
+ data::Load(inputFile.c_str(), data, true);
+
+ // Do we want to load labels separately?
+ arma::umat labels(data.n_cols, 1);
+ if (labelsFile != "")
+ {
+ data::Load(labelsFile.c_str(), labels, true);
+
+ if (labels.n_rows == 1)
+ labels = trans(labels);
+
+ if (labels.n_cols > 1)
+ Log::Fatal << "Labels must have only one column or row!" << endl;
+ }
+ else
+ {
+ for (size_t i = 0; i < data.n_cols; i++)
+ labels[i] = (int) data(data.n_rows - 1, i);
+
+ data.shed_row(data.n_rows - 1);
+ }
+
+ arma::mat distance;
+
+ // Normalize the data, if necessary.
+ if (normalize)
+ {
+ // Find the minimum and maximum values for each dimension.
+ arma::vec ranges = arma::max(data, 1) - arma::min(data, 1);
+ for (size_t d = 0; d < ranges.n_elem; ++d)
+ if (ranges[d] == 0.0)
+ ranges[d] = 1; // A range of 0 produces NaN later on.
+
+ distance = diagmat(1.0 / ranges);
+ Log::Info << "Using normalized starting point for optimization."
+ << std::endl;
+ }
+ else
+ {
+ distance.eye();
+ }
+
+ // Now create the NCA object and run the optimization.
+ if (optimizerType == "sgd")
+ {
+ NCA<LMetric<2> > nca(data, labels.unsafe_col(0));
+ nca.Optimizer().StepSize() = stepSize;
+ nca.Optimizer().MaxIterations() = maxIterations;
+ nca.Optimizer().Tolerance() = tolerance;
+ nca.Optimizer().Shuffle() = shuffle;
+
+ nca.LearnDistance(distance);
+ }
+ else if (optimizerType == "lbfgs")
+ {
+ NCA<LMetric<2>, L_BFGS> nca(data, labels.unsafe_col(0));
+ nca.Optimizer().NumBasis() = numBasis;
+ nca.Optimizer().MaxIterations() = maxIterations;
+ nca.Optimizer().ArmijoConstant() = armijoConstant;
+ nca.Optimizer().Wolfe() = wolfe;
+ nca.Optimizer().MinGradientNorm() = tolerance;
+ nca.Optimizer().MaxLineSearchTrials() = maxLineSearchTrials;
+ nca.Optimizer().MinStep() = minStep;
+ nca.Optimizer().MaxStep() = maxStep;
+
+ nca.LearnDistance(distance);
+ }
+
+ // Save the output.
+ data::Save(CLI::GetParam<string>("output_file").c_str(), distance, true);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca_softmax_error_function.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,167 +0,0 @@
-/**
- * @file nca_softmax_error_function.hpp
- * @author Ryan Curtin
- *
- * Implementation of the stochastic neighbor assignment probability error
- * function (the "softmax error").
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NCA_NCA_SOFTMAX_ERROR_FUNCTION_HPP
-#define __MLPACK_METHODS_NCA_NCA_SOFTMAX_ERROR_FUNCTION_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace nca {
-
-/**
- * The "softmax" stochastic neighbor assignment probability function.
- *
- * The actual function is
- *
- * p_ij = (exp(-|| A x_i - A x_j || ^ 2)) /
- * (sum_{k != i} (exp(-|| A x_i - A x_k || ^ 2)))
- *
- * where x_n represents a point and A is the current scaling matrix.
- *
- * This class is more flexible than the original paper, allowing an arbitrary
- * metric function to be used in place of || A x_i - A x_j ||^2, meaning that
- * the squared Euclidean distance is not the only allowed metric for NCA.
- * However, that is probably the best way to use this class.
- *
- * In addition to the standard Evaluate() and Gradient() functions which MLPACK
- * optimizers use, overloads of Evaluate() and Gradient() are given which only
- * operate on one point in the dataset. This is useful for optimizers like
- * stochastic gradient descent (see mlpack::optimization::SGD).
- */
-template<typename MetricType = metric::SquaredEuclideanDistance>
-class SoftmaxErrorFunction
-{
- public:
- /**
- * Initialize with the given kernel; useful when the kernel has some state to
- * store, which is set elsewhere. If no kernel is given, an empty kernel is
- * used; this way, you can call the constructor with no arguments. A
- * reference to the dataset we will be optimizing over is also required.
- *
- * @param dataset Matrix containing the dataset.
- * @param labels Vector of class labels for each point in the dataset.
- * @param kernel Instantiated kernel (optional).
- */
- SoftmaxErrorFunction(const arma::mat& dataset,
- const arma::uvec& labels,
- MetricType metric = MetricType());
-
- /**
- * Evaluate the softmax function for the given covariance matrix. This is the
- * non-separable implementation, where the objective function is not
- * decomposed into the sum of several objective functions.
- *
- * @param covariance Covariance matrix of Mahalanobis distance.
- */
- double Evaluate(const arma::mat& covariance);
-
- /**
- * Evaluate the softmax objective function for the given covariance matrix on
- * only one point of the dataset. This is the separable implementation, where
- * the objective function is decomposed into the sum of many objective
- * functions, and here, only one of those constituent objective functions is
- * returned.
- *
- * @param covariance Covariance matrix of Mahalanobis distance.
- * @param i Index of point to use for objective function.
- */
- double Evaluate(const arma::mat& covariance, const size_t i);
-
- /**
- * Evaluate the gradient of the softmax function for the given covariance
- * matrix. This is the non-separable implementation, where the objective
- * function is not decomposed into the sum of several objective functions.
- *
- * @param covariance Covariance matrix of Mahalanobis distance.
- * @param gradient Matrix to store the calculated gradient in.
- */
- void Gradient(const arma::mat& covariance, arma::mat& gradient);
-
- /**
- * Evaluate the gradient of the softmax function for the given covariance
- * matrix on only one point of the dataset. This is the separable
- * implementation, where the objective function is decomposed into the sum of
- * many objective functions, and here, only one of those constituent objective
- * functions is returned.
- *
- * @param covariance Covariance matrix of Mahalanobis distance.
- * @param i Index of point to use for objective function.
- * @param gradient Matrix to store the calculated gradient in.
- */
- void Gradient(const arma::mat& covariance,
- const size_t i,
- arma::mat& gradient);
-
- /**
- * Get the initial point.
- */
- const arma::mat GetInitialPoint() const;
-
- /**
- * Get the number of functions the objective function can be decomposed into.
- * This is just the number of points in the dataset.
- */
- size_t NumFunctions() const { return dataset.n_cols; }
-
- private:
- const arma::mat& dataset;
- const arma::uvec& labels;
-
- MetricType metric;
-
- //! Last coordinates. Used for the non-separable Evaluate() and Gradient().
- arma::mat lastCoordinates;
- //! Stretched dataset. Kept internal to avoid memory reallocations.
- arma::mat stretchedDataset;
- //! Holds calculated p_i, for the non-separable Evaluate() and Gradient().
- arma::vec p;
- //! Holds denominators for calculation of p_ij, for the non-separable
- //! Evaluate() and Gradient().
- arma::vec denominators;
-
- //! False if nothing has ever been precalculated (only at construction time).
- bool precalculated;
-
- /**
- * Precalculate the denominators and numerators that will make up the p_ij,
- * but only if the coordinates matrix is different than the last coordinates
- * the Precalculate() method was run with. This method is only called by the
- * non-separable Evaluate() and Gradient().
- *
- * This will update last_coordinates_ and stretched_dataset_, and also
- * calculate the p_i and denominators_ which are used in the calculation of
- * p_i or p_ij. The calculation will be O((n * (n + 1)) / 2), which is not
- * great.
- *
- * @param coordinates Coordinates matrix to use for precalculation.
- */
- void Precalculate(const arma::mat& coordinates);
-};
-
-}; // namespace nca
-}; // namespace mlpack
-
-// Include implementation.
-#include "nca_softmax_error_function_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca_softmax_error_function.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,167 @@
+/**
+ * @file nca_softmax_error_function.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the stochastic neighbor assignment probability error
+ * function (the "softmax error").
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NCA_NCA_SOFTMAX_ERROR_FUNCTION_HPP
+#define __MLPACK_METHODS_NCA_NCA_SOFTMAX_ERROR_FUNCTION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nca {
+
+/**
+ * The "softmax" stochastic neighbor assignment probability function.
+ *
+ * The actual function is
+ *
+ * p_ij = (exp(-|| A x_i - A x_j || ^ 2)) /
+ * (sum_{k != i} (exp(-|| A x_i - A x_k || ^ 2)))
+ *
+ * where x_n represents a point and A is the current scaling matrix.
+ *
+ * This class is more flexible than the original paper, allowing an arbitrary
+ * metric function to be used in place of || A x_i - A x_j ||^2, meaning that
+ * the squared Euclidean distance is not the only allowed metric for NCA.
+ * However, that is probably the best way to use this class.
+ *
+ * In addition to the standard Evaluate() and Gradient() functions which MLPACK
+ * optimizers use, overloads of Evaluate() and Gradient() are given which only
+ * operate on one point in the dataset. This is useful for optimizers like
+ * stochastic gradient descent (see mlpack::optimization::SGD).
+ */
+template<typename MetricType = metric::SquaredEuclideanDistance>
+class SoftmaxErrorFunction
+{
+ public:
+ /**
+ * Initialize with the given kernel; useful when the kernel has some state to
+ * store, which is set elsewhere. If no kernel is given, an empty kernel is
+ * used; this way, you can call the constructor with no arguments. A
+ * reference to the dataset we will be optimizing over is also required.
+ *
+ * @param dataset Matrix containing the dataset.
+ * @param labels Vector of class labels for each point in the dataset.
+ * @param kernel Instantiated kernel (optional).
+ */
+ SoftmaxErrorFunction(const arma::mat& dataset,
+ const arma::uvec& labels,
+ MetricType metric = MetricType());
+
+ /**
+ * Evaluate the softmax function for the given covariance matrix. This is the
+ * non-separable implementation, where the objective function is not
+ * decomposed into the sum of several objective functions.
+ *
+ * @param covariance Covariance matrix of Mahalanobis distance.
+ */
+ double Evaluate(const arma::mat& covariance);
+
+ /**
+ * Evaluate the softmax objective function for the given covariance matrix on
+ * only one point of the dataset. This is the separable implementation, where
+ * the objective function is decomposed into the sum of many objective
+ * functions, and here, only one of those constituent objective functions is
+ * returned.
+ *
+ * @param covariance Covariance matrix of Mahalanobis distance.
+ * @param i Index of point to use for objective function.
+ */
+ double Evaluate(const arma::mat& covariance, const size_t i);
+
+ /**
+ * Evaluate the gradient of the softmax function for the given covariance
+ * matrix. This is the non-separable implementation, where the objective
+ * function is not decomposed into the sum of several objective functions.
+ *
+ * @param covariance Covariance matrix of Mahalanobis distance.
+ * @param gradient Matrix to store the calculated gradient in.
+ */
+ void Gradient(const arma::mat& covariance, arma::mat& gradient);
+
+ /**
+ * Evaluate the gradient of the softmax function for the given covariance
+ * matrix on only one point of the dataset. This is the separable
+ * implementation, where the objective function is decomposed into the sum of
+ * many objective functions, and here, only one of those constituent objective
+ * functions is returned.
+ *
+ * @param covariance Covariance matrix of Mahalanobis distance.
+ * @param i Index of point to use for objective function.
+ * @param gradient Matrix to store the calculated gradient in.
+ */
+ void Gradient(const arma::mat& covariance,
+ const size_t i,
+ arma::mat& gradient);
+
+ /**
+ * Get the initial point.
+ */
+ const arma::mat GetInitialPoint() const;
+
+ /**
+ * Get the number of functions the objective function can be decomposed into.
+ * This is just the number of points in the dataset.
+ */
+ size_t NumFunctions() const { return dataset.n_cols; }
+
+ private:
+ const arma::mat& dataset;
+ const arma::uvec& labels;
+
+ MetricType metric;
+
+ //! Last coordinates. Used for the non-separable Evaluate() and Gradient().
+ arma::mat lastCoordinates;
+ //! Stretched dataset. Kept internal to avoid memory reallocations.
+ arma::mat stretchedDataset;
+ //! Holds calculated p_i, for the non-separable Evaluate() and Gradient().
+ arma::vec p;
+ //! Holds denominators for calculation of p_ij, for the non-separable
+ //! Evaluate() and Gradient().
+ arma::vec denominators;
+
+ //! False if nothing has ever been precalculated (only at construction time).
+ bool precalculated;
+
+ /**
+ * Precalculate the denominators and numerators that will make up the p_ij,
+ * but only if the coordinates matrix is different than the last coordinates
+ * the Precalculate() method was run with. This method is only called by the
+ * non-separable Evaluate() and Gradient().
+ *
+ * This will update last_coordinates_ and stretched_dataset_, and also
+ * calculate the p_i and denominators_ which are used in the calculation of
+ * p_i or p_ij. The calculation will be O((n * (n + 1)) / 2), which is not
+ * great.
+ *
+ * @param coordinates Coordinates matrix to use for precalculation.
+ */
+ void Precalculate(const arma::mat& coordinates);
+};
+
+}; // namespace nca
+}; // namespace mlpack
+
+// Include implementation.
+#include "nca_softmax_error_function_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,287 +0,0 @@
-/**
- * @file nca_softmax_impl.h
- * @author Ryan Curtin
- *
- * Implementation of the Softmax error function.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NCA_NCA_SOFTMAX_ERROR_FUNCTCLIN_IMPL_H
-#define __MLPACK_METHODS_NCA_NCA_SOFTMAX_ERROR_FUNCTCLIN_IMPL_H
-
-// In case it hasn't been included already.
-#include "nca_softmax_error_function.hpp"
-
-namespace mlpack {
-namespace nca {
-
-// Initialize with the given kernel.
-template<typename MetricType>
-SoftmaxErrorFunction<MetricType>::SoftmaxErrorFunction(const arma::mat& dataset,
- const arma::uvec& labels,
- MetricType metric) :
- dataset(dataset),
- labels(labels),
- metric(metric),
- precalculated(false)
-{ /* nothing to do */ }
-
-//! The non-separable implementation, which uses Precalculate() to save time.
-template<typename MetricType>
-double SoftmaxErrorFunction<MetricType>::Evaluate(const arma::mat& coordinates)
-{
- // Calculate the denominators and numerators, if necessary.
- Precalculate(coordinates);
-
- return -accu(p); // Sum of p_i for all i. We negate because our solver
- // minimizes, not maximizes.
-};
-
-//! The separated objective function, which does not use Precalculate().
-template<typename MetricType>
-double SoftmaxErrorFunction<MetricType>::Evaluate(const arma::mat& coordinates,
- const size_t i)
-{
- // Unfortunately each evaluation will take O(N) time because it requires a
- // scan over all points in the dataset. Our objective is to compute p_i.
- double denominator = 0;
- double numerator = 0;
-
- // It's quicker to do this now than one point at a time later.
- stretchedDataset = coordinates * dataset;
-
- for (size_t k = 0; k < dataset.n_cols; ++k)
- {
- // Don't consider the case where the points are the same.
- if (k == i)
- continue;
-
- // We want to evaluate exp(-D(A x_i, A x_k)).
- double eval = std::exp(-metric.Evaluate(stretchedDataset.unsafe_col(i),
- stretchedDataset.unsafe_col(k)));
-
- // If they are in the same
- if (labels[i] == labels[k])
- numerator += eval;
-
- denominator += eval;
- }
-
- // Now the result is just a simple division, but we have to be sure that the
- // denominator is not 0.
- if (denominator == 0.0)
- {
- Log::Warn << "Denominator of p_" << i << " is 0!" << std::endl;
- return 0;
- }
-
- return -(numerator / denominator); // Negate because the optimizer is a
- // minimizer.
-}
-
-//! The non-separable implementation, where Precalculate() is used.
-template<typename MetricType>
-void SoftmaxErrorFunction<MetricType>::Gradient(const arma::mat& coordinates,
- arma::mat& gradient)
-{
- // Calculate the denominators and numerators, if necessary.
- Precalculate(coordinates);
-
- // Now, we handle the summation over i:
- // sum_i (p_i sum_k (p_ik x_ik x_ik^T) -
- // sum_{j in class of i} (p_ij x_ij x_ij^T)
- // We can algebraically manipulate the whole thing to produce a more
- // memory-friendly way to calculate this. Looping over each i and k (again
- // O((n * (n + 1)) / 2) as with the last step, we can add the following to the
- // sum:
- //
- // if class of i is the same as the class of k, add
- // (((p_i - (1 / p_i)) p_ik) + ((p_k - (1 / p_k)) p_ki)) x_ik x_ik^T
- // otherwise, add
- // (p_i p_ik + p_k p_ki) x_ik x_ik^T
- arma::mat sum;
- sum.zeros(stretchedDataset.n_rows, stretchedDataset.n_rows);
- for (size_t i = 0; i < stretchedDataset.n_cols; i++)
- {
- for (size_t k = (i + 1); k < stretchedDataset.n_cols; k++)
- {
- // Calculate p_ik and p_ki first.
- double eval = exp(-metric.Evaluate(stretchedDataset.unsafe_col(i),
- stretchedDataset.unsafe_col(k)));
- double p_ik = 0, p_ki = 0;
- p_ik = eval / denominators(i);
- p_ki = eval / denominators(k);
-
- // Subtract x_i from x_k. We are not using stretched points here.
- arma::vec x_ik = dataset.col(i) - dataset.col(k);
- arma::mat secondTerm = (x_ik * trans(x_ik));
-
- if (labels[i] == labels[k])
- sum += ((p[i] - 1) * p_ik + (p[k] - 1) * p_ki) * secondTerm;
- else
- sum += (p[i] * p_ik + p[k] * p_ki) * secondTerm;
- }
- }
-
- // Assemble the final gradient.
- gradient = -2 * coordinates * sum;
-}
-
-//! The separable implementation.
-template<typename MetricType>
-void SoftmaxErrorFunction<MetricType>::Gradient(const arma::mat& coordinates,
- const size_t i,
- arma::mat& gradient)
-{
- // We will need to calculate p_i before this evaluation is done, so these two
- // variables will hold the information necessary for that.
- double numerator = 0;
- double denominator = 0;
-
- // The gradient involves two matrix terms which are eventually combined into
- // one.
- arma::mat firstTerm;
- arma::mat secondTerm;
-
- firstTerm.zeros(coordinates.n_rows, coordinates.n_cols);
- secondTerm.zeros(coordinates.n_rows, coordinates.n_cols);
-
- // Compute the stretched dataset.
- stretchedDataset = coordinates * dataset;
-
- for (size_t k = 0; k < dataset.n_cols; ++k)
- {
- // Don't consider the case where the points are the same.
- if (i == k)
- continue;
-
- // Calculate the numerator of p_ik.
- double eval = exp(-metric.Evaluate(stretchedDataset.unsafe_col(i),
- stretchedDataset.unsafe_col(k)));
-
- // If the points are in the same class, we must add to the second term of
- // the gradient as well as the numerator of p_i. We will divide by the
- // denominator of p_ik later. For x_ik we are not using stretched points.
- arma::vec x_ik = dataset.col(i) - dataset.col(k);
- if (labels[i] == labels[k])
- {
- numerator += eval;
- secondTerm += eval * x_ik * trans(x_ik);
- }
-
- // We always have to add to the denominator of p_i and the first term of the
- // gradient computation. We will divide by the denominator of p_ik later.
- denominator += eval;
- firstTerm += eval * x_ik * trans(x_ik);
- }
-
- // Calculate p_i.
- double p = 0;
- if (denominator == 0)
- {
- Log::Warn << "Denominator of p_" << i << " is 0!" << std::endl;
- // If the denominator is zero, then all p_ik should be zero and there is
- // no gradient contribution from this point.
- gradient.zeros(coordinates.n_rows, coordinates.n_rows);
- return;
- }
- else
- {
- p = numerator / denominator;
- firstTerm /= denominator;
- secondTerm /= denominator;
- }
-
- // Now multiply the first term by p_i, and add the two together and multiply
- // all by 2 * A. We negate it though, because our optimizer is a minimizer.
- gradient = -2 * coordinates * (p * firstTerm - secondTerm);
-}
-
-template<typename MetricType>
-const arma::mat SoftmaxErrorFunction<MetricType>::GetInitialPoint() const
-{
- return arma::eye<arma::mat>(dataset.n_rows, dataset.n_rows);
-}
-
-template<typename MetricType>
-void SoftmaxErrorFunction<MetricType>::Precalculate(
- const arma::mat& coordinates)
-{
- // Ensure it is the right size.
- lastCoordinates.set_size(coordinates.n_rows, coordinates.n_cols);
-
- // Make sure the calculation is necessary.
- if ((accu(coordinates == lastCoordinates) == coordinates.n_elem) &&
- precalculated)
- return; // No need to calculate; we already have this stuff saved.
-
- // Coordinates are different; save the new ones, and stretch the dataset.
- lastCoordinates = coordinates;
- stretchedDataset = coordinates * dataset;
-
- // For each point i, we must evaluate the softmax function:
- // p_ij = exp( -K(x_i, x_j) ) / ( sum_{k != i} ( exp( -K(x_i, x_k) )))
- // p_i = sum_{j in class of i} p_ij
- // We will do this by keeping track of the denominators for each i as well as
- // the numerators (the sum for all j in class of i). This will be on the
- // order of O((n * (n + 1)) / 2), which really isn't all that great.
- p.zeros(stretchedDataset.n_cols);
- denominators.zeros(stretchedDataset.n_cols);
- for (size_t i = 0; i < stretchedDataset.n_cols; i++)
- {
- for (size_t j = (i + 1); j < stretchedDataset.n_cols; j++)
- {
- // Evaluate exp(-d(x_i, x_j)).
- double eval = exp(-metric.Evaluate(stretchedDataset.unsafe_col(i),
- stretchedDataset.unsafe_col(j)));
-
- // Add this to the denominators of both p_i and p_j: K(i, j) = K(j, i).
- denominators[i] += eval;
- denominators[j] += eval;
-
- // If i and j are the same class, add to numerator of both.
- if (labels[i] == labels[j])
- {
- p[i] += eval;
- p[j] += eval;
- }
- }
- }
-
- // Divide p_i by their denominators.
- p /= denominators;
-
- // Clean up any bad values.
- for (size_t i = 0; i < stretchedDataset.n_cols; i++)
- {
- if (denominators[i] == 0.0)
- {
- Log::Debug << "Denominator of p_{" << i << ", j} is 0." << std::endl;
-
- // Set to usable values.
- denominators[i] = std::numeric_limits<double>::infinity();
- p[i] = 0;
- }
- }
-
- // We've done a precalculation. Mark it as done.
- precalculated = true;
-}
-
-}; // namespace nca
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,287 @@
+/**
+ * @file nca_softmax_error_function_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the Softmax error function.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NCA_NCA_SOFTMAX_ERROR_FUNCTION_IMPL_H
+#define __MLPACK_METHODS_NCA_NCA_SOFTMAX_ERROR_FUNCTION_IMPL_H
+
+// In case it hasn't been included already.
+#include "nca_softmax_error_function.hpp"
+
+namespace mlpack {
+namespace nca {
+
+// Initialize with the given kernel.
+template<typename MetricType>
+SoftmaxErrorFunction<MetricType>::SoftmaxErrorFunction(const arma::mat& dataset,
+ const arma::uvec& labels,
+ MetricType metric) :
+ dataset(dataset),
+ labels(labels),
+ metric(metric),
+ precalculated(false)
+{ /* nothing to do */ }
+
+//! The non-separable implementation, which uses Precalculate() to save time.
+template<typename MetricType>
+double SoftmaxErrorFunction<MetricType>::Evaluate(const arma::mat& coordinates)
+{
+ // Calculate the denominators and numerators, if necessary.
+ Precalculate(coordinates);
+
+ return -accu(p); // Sum of p_i for all i. We negate because our solver
+ // minimizes, not maximizes.
+};
+
+//! The separated objective function, which does not use Precalculate().
+template<typename MetricType>
+double SoftmaxErrorFunction<MetricType>::Evaluate(const arma::mat& coordinates,
+ const size_t i)
+{
+ // Unfortunately each evaluation will take O(N) time because it requires a
+ // scan over all points in the dataset. Our objective is to compute p_i.
+ double denominator = 0;
+ double numerator = 0;
+
+ // It's quicker to do this now than one point at a time later.
+ stretchedDataset = coordinates * dataset;
+
+ for (size_t k = 0; k < dataset.n_cols; ++k)
+ {
+ // Don't consider the case where the points are the same.
+ if (k == i)
+ continue;
+
+ // We want to evaluate exp(-D(A x_i, A x_k)).
+ double eval = std::exp(-metric.Evaluate(stretchedDataset.unsafe_col(i),
+ stretchedDataset.unsafe_col(k)));
+
+ // If they are in the same class, add to the numerator.
+ if (labels[i] == labels[k])
+ numerator += eval;
+
+ denominator += eval;
+ }
+
+ // Now the result is just a simple division, but we have to be sure that the
+ // denominator is not 0.
+ if (denominator == 0.0)
+ {
+ Log::Warn << "Denominator of p_" << i << " is 0!" << std::endl;
+ return 0;
+ }
+
+ return -(numerator / denominator); // Negate because the optimizer is a
+ // minimizer.
+}
+
+//! The non-separable implementation, where Precalculate() is used.
+template<typename MetricType>
+void SoftmaxErrorFunction<MetricType>::Gradient(const arma::mat& coordinates,
+ arma::mat& gradient)
+{
+ // Calculate the denominators and numerators, if necessary.
+ Precalculate(coordinates);
+
+ // Now, we handle the summation over i:
+ // sum_i (p_i sum_k (p_ik x_ik x_ik^T) -
+ // sum_{j in class of i} (p_ij x_ij x_ij^T)
+ // We can algebraically manipulate the whole thing to produce a more
+ // memory-friendly way to calculate this. Looping over each i and k (again
+ // O((n * (n + 1)) / 2) as with the last step, we can add the following to the
+ // sum:
+ //
+ // if class of i is the same as the class of k, add
+ // (((p_i - (1 / p_i)) p_ik) + ((p_k - (1 / p_k)) p_ki)) x_ik x_ik^T
+ // otherwise, add
+ // (p_i p_ik + p_k p_ki) x_ik x_ik^T
+ arma::mat sum;
+ sum.zeros(stretchedDataset.n_rows, stretchedDataset.n_rows);
+ for (size_t i = 0; i < stretchedDataset.n_cols; i++)
+ {
+ for (size_t k = (i + 1); k < stretchedDataset.n_cols; k++)
+ {
+ // Calculate p_ik and p_ki first.
+ double eval = exp(-metric.Evaluate(stretchedDataset.unsafe_col(i),
+ stretchedDataset.unsafe_col(k)));
+ double p_ik = 0, p_ki = 0;
+ p_ik = eval / denominators(i);
+ p_ki = eval / denominators(k);
+
+ // Subtract x_i from x_k. We are not using stretched points here.
+ arma::vec x_ik = dataset.col(i) - dataset.col(k);
+ arma::mat secondTerm = (x_ik * trans(x_ik));
+
+ if (labels[i] == labels[k])
+ sum += ((p[i] - 1) * p_ik + (p[k] - 1) * p_ki) * secondTerm;
+ else
+ sum += (p[i] * p_ik + p[k] * p_ki) * secondTerm;
+ }
+ }
+
+ // Assemble the final gradient.
+ gradient = -2 * coordinates * sum;
+}
+
+//! The separable implementation.
+template<typename MetricType>
+void SoftmaxErrorFunction<MetricType>::Gradient(const arma::mat& coordinates,
+ const size_t i,
+ arma::mat& gradient)
+{
+ // We will need to calculate p_i before this evaluation is done, so these two
+ // variables will hold the information necessary for that.
+ double numerator = 0;
+ double denominator = 0;
+
+ // The gradient involves two matrix terms which are eventually combined into
+ // one.
+ arma::mat firstTerm;
+ arma::mat secondTerm;
+
+ firstTerm.zeros(coordinates.n_rows, coordinates.n_cols);
+ secondTerm.zeros(coordinates.n_rows, coordinates.n_cols);
+
+ // Compute the stretched dataset.
+ stretchedDataset = coordinates * dataset;
+
+ for (size_t k = 0; k < dataset.n_cols; ++k)
+ {
+ // Don't consider the case where the points are the same.
+ if (i == k)
+ continue;
+
+ // Calculate the numerator of p_ik.
+ double eval = exp(-metric.Evaluate(stretchedDataset.unsafe_col(i),
+ stretchedDataset.unsafe_col(k)));
+
+ // If the points are in the same class, we must add to the second term of
+ // the gradient as well as the numerator of p_i. We will divide by the
+ // denominator of p_ik later. For x_ik we are not using stretched points.
+ arma::vec x_ik = dataset.col(i) - dataset.col(k);
+ if (labels[i] == labels[k])
+ {
+ numerator += eval;
+ secondTerm += eval * x_ik * trans(x_ik);
+ }
+
+ // We always have to add to the denominator of p_i and the first term of the
+ // gradient computation. We will divide by the denominator of p_ik later.
+ denominator += eval;
+ firstTerm += eval * x_ik * trans(x_ik);
+ }
+
+ // Calculate p_i.
+ double p = 0;
+ if (denominator == 0)
+ {
+ Log::Warn << "Denominator of p_" << i << " is 0!" << std::endl;
+ // If the denominator is zero, then all p_ik should be zero and there is
+ // no gradient contribution from this point.
+ gradient.zeros(coordinates.n_rows, coordinates.n_rows);
+ return;
+ }
+ else
+ {
+ p = numerator / denominator;
+ firstTerm /= denominator;
+ secondTerm /= denominator;
+ }
+
+ // Now multiply the first term by p_i, and add the two together and multiply
+ // all by 2 * A. We negate it though, because our optimizer is a minimizer.
+ gradient = -2 * coordinates * (p * firstTerm - secondTerm);
+}
+
+template<typename MetricType>
+const arma::mat SoftmaxErrorFunction<MetricType>::GetInitialPoint() const
+{
+ return arma::eye<arma::mat>(dataset.n_rows, dataset.n_rows);
+}
+
+template<typename MetricType>
+void SoftmaxErrorFunction<MetricType>::Precalculate(
+ const arma::mat& coordinates)
+{
+ // Ensure it is the right size.
+ lastCoordinates.set_size(coordinates.n_rows, coordinates.n_cols);
+
+ // Make sure the calculation is necessary.
+ if ((accu(coordinates == lastCoordinates) == coordinates.n_elem) &&
+ precalculated)
+ return; // No need to calculate; we already have this stuff saved.
+
+ // Coordinates are different; save the new ones, and stretch the dataset.
+ lastCoordinates = coordinates;
+ stretchedDataset = coordinates * dataset;
+
+ // For each point i, we must evaluate the softmax function:
+ // p_ij = exp( -K(x_i, x_j) ) / ( sum_{k != i} ( exp( -K(x_i, x_k) )))
+ // p_i = sum_{j in class of i} p_ij
+ // We will do this by keeping track of the denominators for each i as well as
+ // the numerators (the sum for all j in class of i). This will be on the
+ // order of O((n * (n + 1)) / 2), which really isn't all that great.
+ p.zeros(stretchedDataset.n_cols);
+ denominators.zeros(stretchedDataset.n_cols);
+ for (size_t i = 0; i < stretchedDataset.n_cols; i++)
+ {
+ for (size_t j = (i + 1); j < stretchedDataset.n_cols; j++)
+ {
+ // Evaluate exp(-d(x_i, x_j)).
+ double eval = exp(-metric.Evaluate(stretchedDataset.unsafe_col(i),
+ stretchedDataset.unsafe_col(j)));
+
+ // Add this to the denominators of both p_i and p_j: K(i, j) = K(j, i).
+ denominators[i] += eval;
+ denominators[j] += eval;
+
+ // If i and j are the same class, add to numerator of both.
+ if (labels[i] == labels[j])
+ {
+ p[i] += eval;
+ p[j] += eval;
+ }
+ }
+ }
+
+ // Divide p_i by their denominators.
+ p /= denominators;
+
+ // Clean up any bad values.
+ for (size_t i = 0; i < stretchedDataset.n_cols; i++)
+ {
+ if (denominators[i] == 0.0)
+ {
+ Log::Debug << "Denominator of p_{" << i << ", j} is 0." << std::endl;
+
+ // Set to usable values.
+ denominators[i] = std::numeric_limits<double>::infinity();
+ p[i] = 0;
+ }
+ }
+
+ // We've done a precalculation. Mark it as done.
+ precalculated = true;
+}
+
+}; // namespace nca
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,210 +0,0 @@
-/**
- * @file allkfn_main.cpp
- * @author Ryan Curtin
- *
- * Implementation of the AllkFN executable. Allows some number of standard
- * options.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include <string>
-#include <fstream>
-#include <iostream>
-
-#include "neighbor_search.hpp"
-#include "unmap.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-using namespace mlpack::tree;
-
-// Information about the program itself.
-PROGRAM_INFO("All K-Furthest-Neighbors",
- "This program will calculate the all k-furthest-neighbors of a set of "
- "points. You may specify a separate set of reference points and query "
- "points, or just a reference set which will be used as both the reference "
- "and query set."
- "\n\n"
- "For example, the following will calculate the 5 furthest neighbors of each"
- "point in 'input.csv' and store the distances in 'distances.csv' and the "
- "neighbors in the file 'neighbors.csv':"
- "\n\n"
- "$ allkfn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
- " --neighbors_file=neighbors.csv"
- "\n\n"
- "The output files are organized such that row i and column j in the "
- "neighbors output file corresponds to the index of the point in the "
- "reference set which is the i'th furthest neighbor from the point in the "
- "query set with index j. Row i and column j in the distances output file "
- "corresponds to the distance between those two points.");
-
-// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
-PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "k");
-PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
-PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
-
-PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
-
-PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
-PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
-PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
- "dual-tree search.", "s");
-
-int main(int argc, char *argv[])
-{
- // Give CLI the command line parameters the user passed in.
- CLI::ParseCommandLine(argc, argv);
-
- // Get all the parameters.
- string referenceFile = CLI::GetParam<string>("reference_file");
-
- string distancesFile = CLI::GetParam<string>("distances_file");
- string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
- int lsInt = CLI::GetParam<int>("leaf_size");
-
- size_t k = CLI::GetParam<int>("k");
-
- bool naive = CLI::HasParam("naive");
- bool singleMode = CLI::HasParam("single_mode");
-
- arma::mat referenceData;
- arma::mat queryData; // So it doesn't go out of scope.
- data::Load(referenceFile.c_str(), referenceData, true);
-
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
- // Sanity check on k value: must be greater than 0, must be less than the
- // number of reference points.
- if (k > referenceData.n_cols)
- {
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than or equal to the number of reference points (";
- Log::Fatal << referenceData.n_cols << ")." << endl;
- }
-
- // Sanity check on leaf size.
- if (lsInt < 0)
- {
- Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
- "than or equal to 0." << endl;
- }
- size_t leafSize = lsInt;
-
- // Naive mode overrides single mode.
- if (singleMode && naive)
- {
- Log::Warn << "--single_mode ignored because --naive is present." << endl;
- }
-
- if (naive)
- leafSize = referenceData.n_cols;
-
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- AllkFN* allkfn = NULL;
-
- std::vector<size_t> oldFromNewRefs;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("reference_tree_building");
-
- BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >
- refTree(referenceData, oldFromNewRefs, leafSize);
- BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >*
- queryTree = NULL; // Empty for now.
-
- Timer::Stop("reference_tree_building");
-
- std::vector<size_t> oldFromNewQueries;
-
- if (CLI::GetParam<string>("query_file") != "")
- {
- string queryFile = CLI::GetParam<string>("query_file");
-
- data::Load(queryFile.c_str(), queryData, true);
-
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
- Log::Info << "Building query tree..." << endl;
-
- if (naive && leafSize < queryData.n_cols)
- leafSize = queryData.n_cols;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Timer::Start("query_tree_building");
-
- queryTree = new BinarySpaceTree<bound::HRectBound<2>,
- QueryStat<FurthestNeighborSort> >(queryData, oldFromNewQueries,
- leafSize);
-
- Timer::Stop("query_tree_building");
-
- allkfn = new AllkFN(&refTree, queryTree, referenceData, queryData,
- singleMode);
-
- Log::Info << "Tree built." << endl;
- }
- else
- {
- allkfn = new AllkFN(&refTree, referenceData, singleMode);
-
- Log::Info << "Trees built." << endl;
- }
-
- Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allkfn->Search(k, neighbors, distances);
-
- Log::Info << "Neighbors computed." << endl;
-
- // We have to map back to the original indices from before the tree
- // construction.
- Log::Info << "Re-mapping indices..." << endl;
-
- arma::mat distancesOut(distances.n_rows, distances.n_cols);
- arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
-
- // Map the points back to their original locations.
- if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
- Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries, neighborsOut,
- distancesOut);
- else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
- Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut);
- else
- Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs, neighborsOut,
- distancesOut);
-
- // Clean up.
- if (queryTree)
- delete queryTree;
-
- // Save output.
- data::Save(distancesFile, distancesOut);
- data::Save(neighborsFile, neighborsOut);
-
- delete allkfn;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allkfn_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/allkfn_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allkfn_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,210 @@
+/**
+ * @file allkfn_main.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the AllkFN executable. Allows some number of standard
+ * options.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "neighbor_search.hpp"
+#include "unmap.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+using namespace mlpack::tree;
+
+// Information about the program itself.
+PROGRAM_INFO("All K-Furthest-Neighbors",
+ "This program will calculate the all k-furthest-neighbors of a set of "
+ "points. You may specify a separate set of reference points and query "
+ "points, or just a reference set which will be used as both the reference "
+ "and query set."
+ "\n\n"
+ "For example, the following will calculate the 5 furthest neighbors of each "
+ "point in 'input.csv' and store the distances in 'distances.csv' and the "
+ "neighbors in the file 'neighbors.csv':"
+ "\n\n"
+ "$ allkfn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
+ " --neighbors_file=neighbors.csv"
+ "\n\n"
+ "The output files are organized such that row i and column j in the "
+ "neighbors output file corresponds to the index of the point in the "
+ "reference set which is the i'th furthest neighbor from the point in the "
+ "query set with index j. Row i and column j in the distances output file "
+ "corresponds to the distance between those two points.");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ "r");
+PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "k");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
+
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+
+PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+ "dual-tree search).", "s");
+
+int main(int argc, char *argv[])
+{
+ // Give CLI the command line parameters the user passed in.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Get all the parameters.
+ string referenceFile = CLI::GetParam<string>("reference_file");
+
+ string distancesFile = CLI::GetParam<string>("distances_file");
+ string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+ int lsInt = CLI::GetParam<int>("leaf_size");
+
+ size_t k = CLI::GetParam<int>("k");
+
+ bool naive = CLI::HasParam("naive");
+ bool singleMode = CLI::HasParam("single_mode");
+
+ arma::mat referenceData;
+ arma::mat queryData; // So it doesn't go out of scope.
+ data::Load(referenceFile.c_str(), referenceData, true);
+
+ Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+ << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+
+ // Sanity check on k value: must be greater than 0, must be less than the
+ // number of reference points.
+ if (k > referenceData.n_cols)
+ {
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than or equal to the number of reference points (";
+ Log::Fatal << referenceData.n_cols << ")." << endl;
+ }
+
+ // Sanity check on leaf size.
+ if (lsInt < 0)
+ {
+ Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
+ "than or equal to 0." << endl;
+ }
+ size_t leafSize = lsInt;
+
+ // Naive mode overrides single mode.
+ if (singleMode && naive)
+ {
+ Log::Warn << "--single_mode ignored because --naive is present." << endl;
+ }
+
+ if (naive)
+ leafSize = referenceData.n_cols;
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ AllkFN* allkfn = NULL;
+
+ std::vector<size_t> oldFromNewRefs;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("reference_tree_building");
+
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >
+ refTree(referenceData, oldFromNewRefs, leafSize);
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<FurthestNeighborSort> >*
+ queryTree = NULL; // Empty for now.
+
+ Timer::Stop("reference_tree_building");
+
+ std::vector<size_t> oldFromNewQueries;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ data::Load(queryFile.c_str(), queryData, true);
+
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ Log::Info << "Building query tree..." << endl;
+
+ if (naive && leafSize < queryData.n_cols)
+ leafSize = queryData.n_cols;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Timer::Start("query_tree_building");
+
+ queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+ QueryStat<FurthestNeighborSort> >(queryData, oldFromNewQueries,
+ leafSize);
+
+ Timer::Stop("query_tree_building");
+
+ allkfn = new AllkFN(&refTree, queryTree, referenceData, queryData,
+ singleMode);
+
+ Log::Info << "Tree built." << endl;
+ }
+ else
+ {
+ allkfn = new AllkFN(&refTree, referenceData, singleMode);
+
+ Log::Info << "Trees built." << endl;
+ }
+
+ Log::Info << "Computing " << k << " furthest neighbors..." << endl;
+ allkfn->Search(k, neighbors, distances);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
+
+ arma::mat distancesOut(distances.n_rows, distances.n_cols);
+ arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+
+ // Map the points back to their original locations.
+ if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
+ Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries, neighborsOut,
+ distancesOut);
+ else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
+ Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut);
+ else
+ Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs, neighborsOut,
+ distancesOut);
+
+ // Clean up.
+ if (queryTree)
+ delete queryTree;
+
+ // Save output.
+ data::Save(distancesFile, distancesOut);
+ data::Save(neighborsFile, neighborsOut);
+
+ delete allkfn;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/allknn_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allknn_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,328 +0,0 @@
-/**
- * @file allknn_main.cpp
- * @author Ryan Curtin
- *
- * Implementation of the AllkNN executable. Allows some number of standard
- * options.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/tree/cover_tree.hpp>
-
-#include <string>
-#include <fstream>
-#include <iostream>
-
-#include "neighbor_search.hpp"
-#include "unmap.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-using namespace mlpack::tree;
-
-// Information about the program itself.
-PROGRAM_INFO("All K-Nearest-Neighbors",
- "This program will calculate the all k-nearest-neighbors of a set of "
- "points using kd-trees or cover trees (cover tree support is experimental "
- "and may not be optimally fast). You may specify a separate set of "
- "reference points and query points, or just a reference set which will be "
- "used as both the reference and query set."
- "\n\n"
- "For example, the following will calculate the 5 nearest neighbors of each"
- "point in 'input.csv' and store the distances in 'distances.csv' and the "
- "neighbors in the file 'neighbors.csv':"
- "\n\n"
- "$ allknn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
- " --neighbors_file=neighbors.csv"
- "\n\n"
- "The output files are organized such that row i and column j in the "
- "neighbors output file corresponds to the index of the point in the "
- "reference set which is the i'th nearest neighbor from the point in the "
- "query set with index j. Row i and column j in the distances output file "
- "corresponds to the distance between those two points.");
-
-// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
-PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
-PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
-
-PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "k");
-
-PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
-
-PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
-PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
-PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
- "dual-tree search).", "s");
-PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
- "(experimental, may be slow).", "c");
-PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
- "random orthogonal basis.", "R");
-PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
-
-int main(int argc, char *argv[])
-{
- // Give CLI the command line parameters the user passed in.
- CLI::ParseCommandLine(argc, argv);
-
- if (CLI::GetParam<int>("seed") != 0)
- math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- math::RandomSeed((size_t) std::time(NULL));
-
- // Get all the parameters.
- const string referenceFile = CLI::GetParam<string>("reference_file");
- const string queryFile = CLI::GetParam<string>("query_file");
-
- const string distancesFile = CLI::GetParam<string>("distances_file");
- const string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
- int lsInt = CLI::GetParam<int>("leaf_size");
-
- size_t k = CLI::GetParam<int>("k");
-
- bool naive = CLI::HasParam("naive");
- bool singleMode = CLI::HasParam("single_mode");
- const bool randomBasis = CLI::HasParam("random_basis");
-
- arma::mat referenceData;
- arma::mat queryData; // So it doesn't go out of scope.
- data::Load(referenceFile, referenceData, true);
-
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
- if (queryFile != "")
- {
- data::Load(queryFile, queryData, true);
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
- }
-
- // Sanity check on k value: must be greater than 0, must be less than the
- // number of reference points.
- if (k > referenceData.n_cols)
- {
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than or equal to the number of reference points (";
- Log::Fatal << referenceData.n_cols << ")." << endl;
- }
-
- // Sanity check on leaf size.
- if (lsInt < 0)
- {
- Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
- "than or equal to 0." << endl;
- }
- size_t leafSize = lsInt;
-
- // Naive mode overrides single mode.
- if (singleMode && naive)
- {
- Log::Warn << "--single_mode ignored because --naive is present." << endl;
- }
-
- if (naive)
- leafSize = referenceData.n_cols;
-
- // See if we want to project onto a random basis.
- if (randomBasis)
- {
- // Generate the random basis.
- while (true)
- {
- // [Q, R] = qr(randn(d, d));
- // Q = Q * diag(sign(diag(R)));
- arma::mat q, r;
- if (arma::qr(q, r, arma::randn<arma::mat>(referenceData.n_rows,
- referenceData.n_rows)))
- {
- arma::vec rDiag(r.n_rows);
- for (size_t i = 0; i < rDiag.n_elem; ++i)
- {
- if (r(i, i) < 0)
- rDiag(i) = -1;
- else if (r(i, i) > 0)
- rDiag(i) = 1;
- else
- rDiag(i) = 0;
- }
-
- q *= arma::diagmat(rDiag);
-
- // Check if the determinant is positive.
- if (arma::det(q) >= 0)
- {
- referenceData = q * referenceData;
- if (queryFile != "")
- queryData = q * queryData;
- break;
- }
- }
- }
- }
-
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- if (!CLI::HasParam("cover_tree"))
- {
- // Because we may construct it differently, we need a pointer.
- AllkNN* allknn = NULL;
-
- // Mappings for when we build the tree.
- std::vector<size_t> oldFromNewRefs;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("tree_building");
-
- BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
- refTree(referenceData, oldFromNewRefs, leafSize);
- BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
- queryTree = NULL; // Empty for now.
-
- Timer::Stop("tree_building");
-
- std::vector<size_t> oldFromNewQueries;
-
- if (CLI::GetParam<string>("query_file") != "")
- {
- if (naive && leafSize < queryData.n_cols)
- leafSize = queryData.n_cols;
-
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
- Log::Info << "Building query tree..." << endl;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- if (!singleMode)
- {
- Timer::Start("tree_building");
-
- queryTree = new BinarySpaceTree<bound::HRectBound<2>,
- QueryStat<NearestNeighborSort> >(queryData, oldFromNewQueries,
- leafSize);
-
- Timer::Stop("tree_building");
- }
-
- allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
- singleMode);
-
- Log::Info << "Tree built." << endl;
- }
- else
- {
- allknn = new AllkNN(&refTree, referenceData, singleMode);
-
- Log::Info << "Trees built." << endl;
- }
-
- arma::mat distancesOut;
- arma::Mat<size_t> neighborsOut;
-
- Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allknn->Search(k, neighborsOut, distancesOut);
-
- Log::Info << "Neighbors computed." << endl;
-
- // We have to map back to the original indices from before the tree
- // construction.
- Log::Info << "Re-mapping indices..." << endl;
-
- // Map the results back to the correct places.
- if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
- Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
- neighbors, distances);
- else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
- Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances);
- else
- Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
- neighbors, distances);
-
- // Clean up.
- if (queryTree)
- delete queryTree;
-
- delete allknn;
- }
- else // Cover trees.
- {
- // Make sure to notify the user that they are using cover trees.
- Log::Info << "Using cover trees for nearest-neighbor calculation." << endl;
-
- // Build our reference tree.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("tree_building");
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > referenceTree(referenceData, 1.3);
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> >* queryTree = NULL;
- Timer::Stop("tree_building");
-
- NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >* allknn = NULL;
-
- // See if we have query data.
- if (CLI::HasParam("query_file"))
- {
- // Build query tree.
- if (!singleMode)
- {
- Log::Info << "Building query tree..." << endl;
- Timer::Start("tree_building");
- queryTree = new CoverTree<metric::LMetric<2, true>,
- tree::FirstPointIsRoot, QueryStat<NearestNeighborSort> >(queryData,
- 1.3);
- Timer::Stop("tree_building");
- }
-
- allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >(&referenceTree, queryTree,
- referenceData, queryData, singleMode);
- }
- else
- {
- allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >(&referenceTree, referenceData,
- singleMode);
- }
-
- Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allknn->Search(k, neighbors, distances);
-
- Log::Info << "Neighbors computed." << endl;
-
- delete allknn;
-
- if (queryTree)
- delete queryTree;
- }
-
- // Save output.
- data::Save(distancesFile, distances);
- data::Save(neighborsFile, neighbors);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allknn_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/allknn_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allknn_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/allknn_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,328 @@
+/**
+ * @file allknn_main.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the AllkNN executable. Allows some number of standard
+ * options.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "neighbor_search.hpp"
+#include "unmap.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+using namespace mlpack::tree;
+
+// Information about the program itself.
+PROGRAM_INFO("All K-Nearest-Neighbors",
+ "This program will calculate the all k-nearest-neighbors of a set of "
+ "points using kd-trees or cover trees (cover tree support is experimental "
+ "and may not be optimally fast). You may specify a separate set of "
+ "reference points and query points, or just a reference set which will be "
+ "used as both the reference and query set."
+ "\n\n"
+ "For example, the following will calculate the 5 nearest neighbors of each "
+ "point in 'input.csv' and store the distances in 'distances.csv' and the "
+ "neighbors in the file 'neighbors.csv':"
+ "\n\n"
+ "$ allknn --k=5 --reference_file=input.csv --distances_file=distances.csv\n"
+ " --neighbors_file=neighbors.csv"
+ "\n\n"
+ "The output files are organized such that row i and column j in the "
+ "neighbors output file corresponds to the index of the point in the "
+ "reference set which is the i'th nearest neighbor from the point in the "
+ "query set with index j. Row i and column j in the distances output file "
+ "corresponds to the distance between those two points.");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ "r");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
+
+PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
+
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+
+PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+ "dual-tree search).", "s");
+PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
+ "(experimental, may be slow).", "c");
+PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
+ "random orthogonal basis.", "R");
+PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
+
+int main(int argc, char *argv[])
+{
+ // Give CLI the command line parameters the user passed in.
+ CLI::ParseCommandLine(argc, argv);
+
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) std::time(NULL));
+
+ // Get all the parameters.
+ const string referenceFile = CLI::GetParam<string>("reference_file");
+ const string queryFile = CLI::GetParam<string>("query_file");
+
+ const string distancesFile = CLI::GetParam<string>("distances_file");
+ const string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+ int lsInt = CLI::GetParam<int>("leaf_size");
+
+ size_t k = CLI::GetParam<int>("k");
+
+ bool naive = CLI::HasParam("naive");
+ bool singleMode = CLI::HasParam("single_mode");
+ const bool randomBasis = CLI::HasParam("random_basis");
+
+ arma::mat referenceData;
+ arma::mat queryData; // So it doesn't go out of scope.
+ data::Load(referenceFile, referenceData, true);
+
+ Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+ << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+
+ if (queryFile != "")
+ {
+ data::Load(queryFile, queryData, true);
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+ }
+
+ // Sanity check on k value: must be greater than 0, must be less than the
+ // number of reference points.
+ if (k > referenceData.n_cols)
+ {
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than or equal to the number of reference points (";
+ Log::Fatal << referenceData.n_cols << ")." << endl;
+ }
+
+ // Sanity check on leaf size.
+ if (lsInt < 0)
+ {
+ Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
+ "than or equal to 0." << endl;
+ }
+ size_t leafSize = lsInt;
+
+ // Naive mode overrides single mode.
+ if (singleMode && naive)
+ {
+ Log::Warn << "--single_mode ignored because --naive is present." << endl;
+ }
+
+ if (naive)
+ leafSize = referenceData.n_cols;
+
+ // See if we want to project onto a random basis.
+ if (randomBasis)
+ {
+ // Generate the random basis.
+ while (true)
+ {
+ // [Q, R] = qr(randn(d, d));
+ // Q = Q * diag(sign(diag(R)));
+ arma::mat q, r;
+ if (arma::qr(q, r, arma::randn<arma::mat>(referenceData.n_rows,
+ referenceData.n_rows)))
+ {
+ arma::vec rDiag(r.n_rows);
+ for (size_t i = 0; i < rDiag.n_elem; ++i)
+ {
+ if (r(i, i) < 0)
+ rDiag(i) = -1;
+ else if (r(i, i) > 0)
+ rDiag(i) = 1;
+ else
+ rDiag(i) = 0;
+ }
+
+ q *= arma::diagmat(rDiag);
+
+ // Check if the determinant is positive.
+ if (arma::det(q) >= 0)
+ {
+ referenceData = q * referenceData;
+ if (queryFile != "")
+ queryData = q * queryData;
+ break;
+ }
+ }
+ }
+ }
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ if (!CLI::HasParam("cover_tree"))
+ {
+ // Because we may construct it differently, we need a pointer.
+ AllkNN* allknn = NULL;
+
+ // Mappings for when we build the tree.
+ std::vector<size_t> oldFromNewRefs;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
+ refTree(referenceData, oldFromNewRefs, leafSize);
+ BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
+ queryTree = NULL; // Empty for now.
+
+ Timer::Stop("tree_building");
+
+ std::vector<size_t> oldFromNewQueries;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ if (naive && leafSize < queryData.n_cols)
+ leafSize = queryData.n_cols;
+
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ Log::Info << "Building query tree..." << endl;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ if (!singleMode)
+ {
+ Timer::Start("tree_building");
+
+ queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+ QueryStat<NearestNeighborSort> >(queryData, oldFromNewQueries,
+ leafSize);
+
+ Timer::Stop("tree_building");
+ }
+
+ allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
+ singleMode);
+
+ Log::Info << "Tree built." << endl;
+ }
+ else
+ {
+ allknn = new AllkNN(&refTree, referenceData, singleMode);
+
+ Log::Info << "Trees built." << endl;
+ }
+
+ arma::mat distancesOut;
+ arma::Mat<size_t> neighborsOut;
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allknn->Search(k, neighborsOut, distancesOut);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
+
+ // Map the results back to the correct places.
+ if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
+ Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
+ neighbors, distances);
+ else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
+ Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances);
+ else
+ Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
+ neighbors, distances);
+
+ // Clean up.
+ if (queryTree)
+ delete queryTree;
+
+ delete allknn;
+ }
+ else // Cover trees.
+ {
+ // Make sure to notify the user that they are using cover trees.
+ Log::Info << "Using cover trees for nearest-neighbor calculation." << endl;
+
+ // Build our reference tree.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > referenceTree(referenceData, 1.3);
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> >* queryTree = NULL;
+ Timer::Stop("tree_building");
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >* allknn = NULL;
+
+ // See if we have query data.
+ if (CLI::HasParam("query_file"))
+ {
+ // Build query tree.
+ if (!singleMode)
+ {
+ Log::Info << "Building query tree..." << endl;
+ Timer::Start("tree_building");
+ queryTree = new CoverTree<metric::LMetric<2, true>,
+ tree::FirstPointIsRoot, QueryStat<NearestNeighborSort> >(queryData,
+ 1.3);
+ Timer::Stop("tree_building");
+ }
+
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >(&referenceTree, queryTree,
+ referenceData, queryData, singleMode);
+ }
+ else
+ {
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >(&referenceTree, referenceData,
+ singleMode);
+ }
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allknn->Search(k, neighbors, distances);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ delete allknn;
+
+ if (queryTree)
+ delete queryTree;
+ }
+
+ // Save output.
+ data::Save(distancesFile, distances);
+ data::Save(neighborsFile, neighbors);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,311 +0,0 @@
-/**
- * @file neighbor_search.hpp
- * @author Ryan Curtin
- *
- * Defines the NeighborSearch class, which performs an abstract
- * nearest-neighbor-like query on two datasets.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
-
-#include <mlpack/core.hpp>
-#include <vector>
-#include <string>
-
-#include <mlpack/core/tree/binary_space_tree.hpp>
-
-#include <mlpack/core/metrics/lmetric.hpp>
-#include "sort_policies/nearest_neighbor_sort.hpp"
-
-namespace mlpack {
-namespace neighbor /** Neighbor-search routines. These include
- * all-nearest-neighbors and all-furthest-neighbors
- * searches. */ {
-
-/**
- * Extra data for each node in the tree. For neighbor searches, each node only
- * needs to store a bound on neighbor distances.
- */
-template<typename SortPolicy>
-class QueryStat
-{
- private:
- //! The first bound on the node's neighbor distances (B_1). This represents
- //! the worst candidate distance of any descendants of this node.
- double firstBound;
- //! The second bound on the node's neighbor distances (B_2). This represents
- //! a bound on the worst distance of any descendants of this node assembled
- //! using the best descendant candidate distance modified by the furthest
- //! descendant distance.
- double secondBound;
- //! The better of the two bounds.
- double bound;
-
- public:
- /**
- * Initialize the statistic with the worst possible distance according to
- * our sorting policy.
- */
- QueryStat() :
- firstBound(SortPolicy::WorstDistance()),
- secondBound(SortPolicy::WorstDistance()),
- bound(SortPolicy::WorstDistance()) { }
-
- /**
- * Initialization for a fully initialized node. In this case, we don't need
- * to worry about the node.
- */
- template<typename TreeType>
- QueryStat(TreeType& /* node */) :
- firstBound(SortPolicy::WorstDistance()),
- secondBound(SortPolicy::WorstDistance()),
- bound(SortPolicy::WorstDistance()) { }
-
- //! Get the first bound.
- double FirstBound() const { return firstBound; }
- //! Modify the first bound.
- double& FirstBound() { return firstBound; }
- //! Get the second bound.
- double SecondBound() const { return secondBound; }
- //! Modify the second bound.
- double& SecondBound() { return secondBound; }
- //! Get the overall bound (the better of the two bounds).
- double Bound() const { return bound; }
- //! Modify the overall bound (it should be the better of the two bounds).
- double& Bound() { return bound; }
-};
-
-/**
- * The NeighborSearch class is a template class for performing distance-based
- * neighbor searches. It takes a query dataset and a reference dataset (or just
- * a reference dataset) and, for each point in the query dataset, finds the k
- * neighbors in the reference dataset which have the 'best' distance according
- * to a given sorting policy. A constructor is given which takes only a
- * reference dataset, and if that constructor is used, the given reference
- * dataset is also used as the query dataset.
- *
- * The template parameters SortPolicy and Metric define the sort function used
- * and the metric (distance function) used. More information on those classes
- * can be found in the NearestNeighborSort class and the kernel::ExampleKernel
- * class.
- *
- * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
- * @tparam MetricType The metric to use for computation.
- * @tparam TreeType The tree type to use.
- */
-template<typename SortPolicy = NearestNeighborSort,
- typename MetricType = mlpack::metric::SquaredEuclideanDistance,
- typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
- QueryStat<SortPolicy> > >
-class NeighborSearch
-{
- public:
- /**
- * Initialize the NeighborSearch object, passing both a query and reference
- * dataset. Optionally, perform the computation in naive mode or single-tree
- * mode, and set the leaf size used for tree-building. An initialized
- * distance metric can be given, for cases where the metric has internal data
- * (i.e. the distance::MahalanobisDistance class).
- *
- * This method will copy the matrices to internal copies, which are rearranged
- * during tree-building. You can avoid this extra copy by pre-constructing
- * the trees and passing them using a diferent constructor.
- *
- * @param referenceSet Set of reference points.
- * @param querySet Set of query points.
- * @param naive If true, O(n^2) naive search will be used (as opposed to
- * dual-tree search). This overrides singleMode (if it is set to true).
- * @param singleMode If true, single-tree search will be used (as opposed to
- * dual-tree search).
- * @param leafSize Leaf size for tree construction (ignored if tree is given).
- * @param metric An optional instance of the MetricType class.
- */
- NeighborSearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool naive = false,
- const bool singleMode = false,
- const size_t leafSize = 20,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the NeighborSearch object, passing only one dataset, which is
- * used as both the query and the reference dataset. Optionally, perform the
- * computation in naive mode or single-tree mode, and set the leaf size used
- * for tree-building. An initialized distance metric can be given, for cases
- * where the metric has internal data (i.e. the distance::MahalanobisDistance
- * class).
- *
- * If naive mode is being used and a pre-built tree is given, it may not work:
- * naive mode operates by building a one-node tree (the root node holds all
- * the points). If that condition is not satisfied with the pre-built tree,
- * then naive mode will not work.
- *
- * @param referenceSet Set of reference points.
- * @param naive If true, O(n^2) naive search will be used (as opposed to
- * dual-tree search). This overrides singleMode (if it is set to true).
- * @param singleMode If true, single-tree search will be used (as opposed to
- * dual-tree search).
- * @param leafSize Leaf size for tree construction (ignored if tree is given).
- * @param metric An optional instance of the MetricType class.
- */
- NeighborSearch(const typename TreeType::Mat& referenceSet,
- const bool naive = false,
- const bool singleMode = false,
- const size_t leafSize = 20,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the NeighborSearch object with the given datasets and
- * pre-constructed trees. It is assumed that the points in referenceSet and
- * querySet correspond to the points in referenceTree and queryTree,
- * respectively. Optionally, choose to use single-tree mode. Naive mode is
- * not available as an option for this constructor; instead, to run naive
- * computation, construct a tree with all of the points in one leaf (i.e.
- * leafSize = number of points). Additionally, an instantiated distance
- * metric can be given, for cases where the distance metric holds data.
- *
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
- *
- * @note
- * Because tree-building (at least with BinarySpaceTree) modifies the ordering
- * of a matrix, be sure you pass the modified matrix to this object! In
- * addition, mapping the points of the matrix back to their original indices
- * is not done when this constructor is used.
- * @endnote
- *
- * @param referenceTree Pre-built tree for reference points.
- * @param queryTree Pre-built tree for query points.
- * @param referenceSet Set of reference points corresponding to referenceTree.
- * @param querySet Set of query points corresponding to queryTree.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param metric Instantiated distance metric.
- */
- NeighborSearch(TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool singleMode = false,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the NeighborSearch object with the given reference dataset and
- * pre-constructed tree. It is assumed that the points in referenceSet
- * correspond to the points in referenceTree. Optionally, choose to use
- * single-tree mode. Naive mode is not available as an option for this
- * constructor; instead, to run naive computation, construct a tree with all
- * the points in one leaf (i.e. leafSize = number of points). Additionally,
- * an instantiated distance metric can be given, for the case where the
- * distance metric holds data.
- *
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
- *
- * @note
- * Because tree-building (at least with BinarySpaceTree) modifies the ordering
- * of a matrix, be sure you pass the modified matrix to this object! In
- * addition, mapping the points of the matrix back to their original indices
- * is not done when this constructor is used.
- * @endnote
- *
- * @param referenceTree Pre-built tree for reference points.
- * @param referenceSet Set of reference points corresponding to referenceTree.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param metric Instantiated distance metric.
- */
- NeighborSearch(TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode = false,
- const MetricType metric = MetricType());
-
-
- /**
- * Delete the NeighborSearch object. The tree is the only member we are
- * responsible for deleting. The others will take care of themselves.
- */
- ~NeighborSearch();
-
- /**
- * Compute the nearest neighbors and store the output in the given matrices.
- * The matrices will be set to the size of n columns by k rows, where n is the
- * number of points in the query dataset and k is the number of neighbors
- * being searched for.
- *
- * @param k Number of neighbors to search for.
- * @param resultingNeighbors Matrix storing lists of neighbors for each query
- * point.
- * @param distances Matrix storing distances of neighbors for each query
- * point.
- */
- void Search(const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
- arma::mat& distances);
-
- private:
- //! Copy of reference dataset (if we need it, because tree building modifies
- //! it).
- arma::mat referenceCopy;
- //! Copy of query dataset (if we need it, because tree building modifies it).
- arma::mat queryCopy;
-
- //! Reference dataset.
- const arma::mat& referenceSet;
- //! Query dataset (may not be given).
- const arma::mat& querySet;
-
- //! Pointer to the root of the reference tree.
- TreeType* referenceTree;
- //! Pointer to the root of the query tree (might not exist).
- TreeType* queryTree;
-
- //! Indicates if we should free the reference tree at deletion time.
- bool ownReferenceTree;
- //! Indicates if we should free the query tree at deletion time.
- bool ownQueryTree;
-
- //! Indicates if O(n^2) naive search is being used.
- bool naive;
- //! Indicates if single-tree search is being used (opposed to dual-tree).
- bool singleMode;
-
- //! Instantiation of kernel.
- MetricType metric;
-
- //! Permutations of reference points during tree building.
- std::vector<size_t> oldFromNewReferences;
- //! Permutations of query points during tree building.
- std::vector<size_t> oldFromNewQueries;
-
- //! Total number of pruned nodes during the neighbor search.
- size_t numberOfPrunes;
-}; // class NeighborSearch
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-// Include implementation.
-#include "neighbor_search_impl.hpp"
-
-// Include convenience typedefs.
-#include "typedef.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,311 @@
+/**
+ * @file neighbor_search.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the NeighborSearch class, which performs an abstract
+ * nearest-neighbor-like query on two datasets.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
+
+#include <mlpack/core.hpp>
+#include <vector>
+#include <string>
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+#include "sort_policies/nearest_neighbor_sort.hpp"
+
+namespace mlpack {
+namespace neighbor /** Neighbor-search routines. These include
+ * all-nearest-neighbors and all-furthest-neighbors
+ * searches. */ {
+
+/**
+ * Extra data for each node in the tree. For neighbor searches, each node only
+ * needs to store a bound on neighbor distances.
+ */
+template<typename SortPolicy>
+class QueryStat
+{
+ private:
+ //! The first bound on the node's neighbor distances (B_1). This represents
+ //! the worst candidate distance of any descendants of this node.
+ double firstBound;
+ //! The second bound on the node's neighbor distances (B_2). This represents
+ //! a bound on the worst distance of any descendants of this node assembled
+ //! using the best descendant candidate distance modified by the furthest
+ //! descendant distance.
+ double secondBound;
+ //! The better of the two bounds.
+ double bound;
+
+ public:
+ /**
+ * Initialize the statistic with the worst possible distance according to
+ * our sorting policy.
+ */
+ QueryStat() :
+ firstBound(SortPolicy::WorstDistance()),
+ secondBound(SortPolicy::WorstDistance()),
+ bound(SortPolicy::WorstDistance()) { }
+
+ /**
+ * Initialization for a fully initialized node. In this case, we don't need
+ * to worry about the node.
+ */
+ template<typename TreeType>
+ QueryStat(TreeType& /* node */) :
+ firstBound(SortPolicy::WorstDistance()),
+ secondBound(SortPolicy::WorstDistance()),
+ bound(SortPolicy::WorstDistance()) { }
+
+ //! Get the first bound.
+ double FirstBound() const { return firstBound; }
+ //! Modify the first bound.
+ double& FirstBound() { return firstBound; }
+ //! Get the second bound.
+ double SecondBound() const { return secondBound; }
+ //! Modify the second bound.
+ double& SecondBound() { return secondBound; }
+ //! Get the overall bound (the better of the two bounds).
+ double Bound() const { return bound; }
+ //! Modify the overall bound (it should be the better of the two bounds).
+ double& Bound() { return bound; }
+};
+
+/**
+ * The NeighborSearch class is a template class for performing distance-based
+ * neighbor searches. It takes a query dataset and a reference dataset (or just
+ * a reference dataset) and, for each point in the query dataset, finds the k
+ * neighbors in the reference dataset which have the 'best' distance according
+ * to a given sorting policy. A constructor is given which takes only a
+ * reference dataset, and if that constructor is used, the given reference
+ * dataset is also used as the query dataset.
+ *
+ * The template parameters SortPolicy and Metric define the sort function used
+ * and the metric (distance function) used. More information on those classes
+ * can be found in the NearestNeighborSort class and the kernel::ExampleKernel
+ * class.
+ *
+ * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
+ * @tparam MetricType The metric to use for computation.
+ * @tparam TreeType The tree type to use.
+ */
+template<typename SortPolicy = NearestNeighborSort,
+ typename MetricType = mlpack::metric::SquaredEuclideanDistance,
+ typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
+ QueryStat<SortPolicy> > >
+class NeighborSearch
+{
+ public:
+ /**
+ * Initialize the NeighborSearch object, passing both a query and reference
+ * dataset. Optionally, perform the computation in naive mode or single-tree
+ * mode, and set the leaf size used for tree-building. An initialized
+ * distance metric can be given, for cases where the metric has internal data
+ * (i.e. the distance::MahalanobisDistance class).
+ *
+ * This method will copy the matrices to internal copies, which are rearranged
+ * during tree-building. You can avoid this extra copy by pre-constructing
+ * the trees and passing them using a different constructor.
+ *
+ * @param referenceSet Set of reference points.
+ * @param querySet Set of query points.
+ * @param naive If true, O(n^2) naive search will be used (as opposed to
+ * dual-tree search). This overrides singleMode (if it is set to true).
+ * @param singleMode If true, single-tree search will be used (as opposed to
+ * dual-tree search).
+ * @param leafSize Leaf size for tree construction (ignored if tree is given).
+ * @param metric An optional instance of the MetricType class.
+ */
+ NeighborSearch(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const size_t leafSize = 20,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the NeighborSearch object, passing only one dataset, which is
+ * used as both the query and the reference dataset. Optionally, perform the
+ * computation in naive mode or single-tree mode, and set the leaf size used
+ * for tree-building. An initialized distance metric can be given, for cases
+ * where the metric has internal data (i.e. the distance::MahalanobisDistance
+ * class).
+ *
+ * If naive mode is being used and a pre-built tree is given, it may not work:
+ * naive mode operates by building a one-node tree (the root node holds all
+ * the points). If that condition is not satisfied with the pre-built tree,
+ * then naive mode will not work.
+ *
+ * @param referenceSet Set of reference points.
+ * @param naive If true, O(n^2) naive search will be used (as opposed to
+ * dual-tree search). This overrides singleMode (if it is set to true).
+ * @param singleMode If true, single-tree search will be used (as opposed to
+ * dual-tree search).
+ * @param leafSize Leaf size for tree construction (ignored if tree is given).
+ * @param metric An optional instance of the MetricType class.
+ */
+ NeighborSearch(const typename TreeType::Mat& referenceSet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const size_t leafSize = 20,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the NeighborSearch object with the given datasets and
+ * pre-constructed trees. It is assumed that the points in referenceSet and
+ * querySet correspond to the points in referenceTree and queryTree,
+ * respectively. Optionally, choose to use single-tree mode. Naive mode is
+ * not available as an option for this constructor; instead, to run naive
+ * computation, construct a tree with all of the points in one leaf (i.e.
+ * leafSize = number of points). Additionally, an instantiated distance
+ * metric can be given, for cases where the distance metric holds data.
+ *
+ * There is no copying of the data matrices in this constructor (because
+ * tree-building is not necessary), so this is the constructor to use when
+ * copies absolutely must be avoided.
+ *
+ * @note
+ * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+ * of a matrix, be sure you pass the modified matrix to this object! In
+ * addition, mapping the points of the matrix back to their original indices
+ * is not done when this constructor is used.
+ * @endnote
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ * @param queryTree Pre-built tree for query points.
+ * @param referenceSet Set of reference points corresponding to referenceTree.
+ * @param querySet Set of query points corresponding to queryTree.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param metric Instantiated distance metric.
+ */
+ NeighborSearch(TreeType* referenceTree,
+ TreeType* queryTree,
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the NeighborSearch object with the given reference dataset and
+ * pre-constructed tree. It is assumed that the points in referenceSet
+ * correspond to the points in referenceTree. Optionally, choose to use
+ * single-tree mode. Naive mode is not available as an option for this
+ * constructor; instead, to run naive computation, construct a tree with all
+ * the points in one leaf (i.e. leafSize = number of points). Additionally,
+ * an instantiated distance metric can be given, for the case where the
+ * distance metric holds data.
+ *
+ * There is no copying of the data matrices in this constructor (because
+ * tree-building is not necessary), so this is the constructor to use when
+ * copies absolutely must be avoided.
+ *
+ * @note
+ * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+ * of a matrix, be sure you pass the modified matrix to this object! In
+ * addition, mapping the points of the matrix back to their original indices
+ * is not done when this constructor is used.
+ * @endnote
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ * @param referenceSet Set of reference points corresponding to referenceTree.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param metric Instantiated distance metric.
+ */
+ NeighborSearch(TreeType* referenceTree,
+ const typename TreeType::Mat& referenceSet,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+
+ /**
+ * Delete the NeighborSearch object. The tree is the only member we are
+ * responsible for deleting. The others will take care of themselves.
+ */
+ ~NeighborSearch();
+
+ /**
+ * Compute the nearest neighbors and store the output in the given matrices.
+ * The matrices will be set to the size of n columns by k rows, where n is the
+ * number of points in the query dataset and k is the number of neighbors
+ * being searched for.
+ *
+ * @param k Number of neighbors to search for.
+ * @param resultingNeighbors Matrix storing lists of neighbors for each query
+ * point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
+ */
+ void Search(const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances);
+
+ private:
+ //! Copy of reference dataset (if we need it, because tree building modifies
+ //! it).
+ arma::mat referenceCopy;
+ //! Copy of query dataset (if we need it, because tree building modifies it).
+ arma::mat queryCopy;
+
+ //! Reference dataset.
+ const arma::mat& referenceSet;
+ //! Query dataset (may not be given).
+ const arma::mat& querySet;
+
+ //! Pointer to the root of the reference tree.
+ TreeType* referenceTree;
+ //! Pointer to the root of the query tree (might not exist).
+ TreeType* queryTree;
+
+ //! Indicates if we should free the reference tree at deletion time.
+ bool ownReferenceTree;
+ //! Indicates if we should free the query tree at deletion time.
+ bool ownQueryTree;
+
+ //! Indicates if O(n^2) naive search is being used.
+ bool naive;
+ //! Indicates if single-tree search is being used (opposed to dual-tree).
+ bool singleMode;
+
+ //! Instantiation of kernel.
+ MetricType metric;
+
+ //! Permutations of reference points during tree building.
+ std::vector<size_t> oldFromNewReferences;
+ //! Permutations of query points during tree building.
+ std::vector<size_t> oldFromNewQueries;
+
+ //! Total number of pruned nodes during the neighbor search.
+ size_t numberOfPrunes;
+}; // class NeighborSearch
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "neighbor_search_impl.hpp"
+
+// Include convenience typedefs.
+#include "typedef.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,315 +0,0 @@
-/**
- * @file neighbor_search_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of Neighbor-Search class to perform all-nearest-neighbors on
- * two specified data sets.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
-
-#include <mlpack/core.hpp>
-
-#include "neighbor_search_rules.hpp"
-
-using namespace mlpack::neighbor;
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool naive,
- const bool singleMode,
- const size_t leafSize,
- const MetricType metric) :
- referenceCopy(referenceSet),
- queryCopy(querySet),
- referenceSet(referenceCopy),
- querySet(queryCopy),
- referenceTree(NULL),
- queryTree(NULL),
- ownReferenceTree(true), // False if a tree was passed.
- ownQueryTree(true), // False if a tree was passed.
- naive(naive),
- singleMode(!naive && singleMode), // No single mode if naive.
- metric(metric),
- numberOfPrunes(0)
-{
- // C++11 will allow us to call out to other constructors so we can avoid this
- // copypasta problem.
-
- // We'll time tree building, but only if we are building trees.
- if (!referenceTree || !queryTree)
- Timer::Start("tree_building");
-
- // Construct as a naive object if we need to.
- referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
- (naive ? referenceCopy.n_cols : leafSize));
-
- queryTree = new TreeType(queryCopy, oldFromNewQueries,
- (naive ? querySet.n_cols : leafSize));
-
- // Stop the timer we started above (if we need to).
- if (!referenceTree || !queryTree)
- Timer::Stop("tree_building");
-}
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSet,
- const bool naive,
- const bool singleMode,
- const size_t leafSize,
- const MetricType metric) :
- referenceCopy(referenceSet),
- referenceSet(referenceCopy),
- querySet(referenceCopy),
- referenceTree(NULL),
- queryTree(NULL),
- ownReferenceTree(true),
- ownQueryTree(false), // Since it will be the same as referenceTree.
- naive(naive),
- singleMode(!naive && singleMode), // No single mode if naive.
- metric(metric),
- numberOfPrunes(0)
-{
- // We'll time tree building, but only if we are building trees.
- Timer::Start("tree_building");
-
- // Construct as a naive object if we need to.
- referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
- (naive ? referenceSet.n_cols : leafSize));
-
- // Stop the timer we started above.
- Timer::Stop("tree_building");
-}
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
- TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(referenceSet),
- querySet(querySet),
- referenceTree(referenceTree),
- queryTree(queryTree),
- ownReferenceTree(false),
- ownQueryTree(false),
- naive(false),
- singleMode(singleMode),
- metric(metric),
- numberOfPrunes(0)
-{
- // Nothing else to initialize.
-}
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
- TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- referenceTree(referenceTree),
- queryTree(NULL),
- ownReferenceTree(false),
- ownQueryTree(false),
- naive(false),
- singleMode(singleMode),
- metric(metric),
- numberOfPrunes(0)
-{
- // Nothing else to initialize.
-}
-
-/**
- * The tree is the only member we may be responsible for deleting. The others
- * will take care of themselves.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
-{
- if (ownReferenceTree)
- delete referenceTree;
- if (ownQueryTree)
- delete queryTree;
-}
-
-/**
- * Computes the best neighbors and stores them in resultingNeighbors and
- * distances.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
- const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
- arma::mat& distances)
-{
- Timer::Start("computing_neighbors");
-
- // If we have built the trees ourselves, then we will have to map all the
- // indices back to their original indices when this computation is finished.
- // To avoid an extra copy, we will store the neighbors and distances in a
- // separate matrix.
- arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
- arma::mat* distancePtr = &distances;
-
- if (ownQueryTree || (ownReferenceTree && !queryTree))
- distancePtr = new arma::mat; // Query indices need to be mapped.
- if (ownReferenceTree || ownQueryTree)
- neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
-
- // Set the size of the neighbor and distance matrices.
- neighborPtr->set_size(k, querySet.n_cols);
- distancePtr->set_size(k, querySet.n_cols);
- distancePtr->fill(SortPolicy::WorstDistance());
-
- size_t numPrunes = 0;
-
- if (singleMode)
- {
- // Create the helper object for the tree traversal.
- typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
- RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
-
- // Create the traverser.
- typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
-
- // Now have it traverse for each point.
- for (size_t i = 0; i < querySet.n_cols; ++i)
- traverser.Traverse(i, *referenceTree);
-
- numPrunes = traverser.NumPrunes();
- }
- else // Dual-tree recursion.
- {
- // Create the helper object for the tree traversal.
- typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
- RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
-
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
- if (queryTree)
- traverser.Traverse(*queryTree, *referenceTree);
- else
- traverser.Traverse(*referenceTree, *referenceTree);
-
- numPrunes = traverser.NumPrunes();
- }
-
- Log::Debug << "Pruned " << numPrunes << " nodes." << std::endl;
-
- Timer::Stop("computing_neighbors");
-
- // Now, do we need to do mapping of indices?
- if (!ownReferenceTree && !ownQueryTree)
- {
- // No mapping needed. We are done.
- return;
- }
- else if (ownReferenceTree && ownQueryTree) // Map references and queries.
- {
- // Set size of output matrices correctly.
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
-
- for (size_t i = 0; i < distances.n_cols; i++)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; j++)
- {
- resultingNeighbors(j, oldFromNewQueries[i]) =
- oldFromNewReferences[(*neighborPtr)(j, i)];
- }
- }
-
- // Finished with temporary matrices.
- delete neighborPtr;
- delete distancePtr;
- }
- else if (ownReferenceTree)
- {
- if (!queryTree) // No query tree -- map both references and queries.
- {
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
-
- for (size_t i = 0; i < distances.n_cols; i++)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; j++)
- {
- resultingNeighbors(j, oldFromNewReferences[i]) =
- oldFromNewReferences[(*neighborPtr)(j, i)];
- }
- }
- }
- else // Map only references.
- {
- // Set size of neighbor indices matrix correctly.
- resultingNeighbors.set_size(k, querySet.n_cols);
-
- // Map indices of neighbors.
- for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
- {
- for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
- {
- resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
- }
- }
- }
-
- // Finished with temporary matrix.
- delete neighborPtr;
- }
- else if (ownQueryTree)
- {
- // Set size of matrices correctly.
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
-
- for (size_t i = 0; i < distances.n_cols; i++)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
-
- // Map indices of neighbors.
- resultingNeighbors.col(oldFromNewQueries[i]) = neighborPtr->col(i);
- }
-
- // Finished with temporary matrices.
- delete neighborPtr;
- delete distancePtr;
- }
-} // Search
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,315 @@
+/**
+ * @file neighbor_search_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of Neighbor-Search class to perform all-nearest-neighbors on
+ * two specified data sets.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
+
+#include <mlpack/core.hpp>
+
+#include "neighbor_search_rules.hpp"
+
+using namespace mlpack::neighbor;
+
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::
+NeighborSearch(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ const MetricType metric) :
+ referenceCopy(referenceSet),
+ queryCopy(querySet),
+ referenceSet(referenceCopy),
+ querySet(queryCopy),
+ referenceTree(NULL),
+ queryTree(NULL),
+ ownReferenceTree(true), // False if a tree was passed.
+ ownQueryTree(true), // False if a tree was passed.
+ naive(naive),
+ singleMode(!naive && singleMode), // No single mode if naive.
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // C++11 will allow us to call out to other constructors so we can avoid this
+ // copypasta problem.
+
+ // We'll time tree building, but only if we are building trees.
+ if (!referenceTree || !queryTree)
+ Timer::Start("tree_building");
+
+ // Construct as a naive object if we need to.
+ referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ (naive ? referenceCopy.n_cols : leafSize));
+
+ queryTree = new TreeType(queryCopy, oldFromNewQueries,
+ (naive ? querySet.n_cols : leafSize));
+
+ // Stop the timer we started above (if we need to).
+ if (!referenceTree || !queryTree)
+ Timer::Stop("tree_building");
+}
+
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::
+NeighborSearch(const typename TreeType::Mat& referenceSet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ const MetricType metric) :
+ referenceCopy(referenceSet),
+ referenceSet(referenceCopy),
+ querySet(referenceCopy),
+ referenceTree(NULL),
+ queryTree(NULL),
+ ownReferenceTree(true),
+ ownQueryTree(false), // Since it will be the same as referenceTree.
+ naive(naive),
+ singleMode(!naive && singleMode), // No single mode if naive.
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // We'll time tree building, but only if we are building trees.
+ Timer::Start("tree_building");
+
+ // Construct as a naive object if we need to.
+ referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ (naive ? referenceSet.n_cols : leafSize));
+
+ // Stop the timer we started above.
+ Timer::Stop("tree_building");
+}
+
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
+ TreeType* referenceTree,
+ TreeType* queryTree,
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool singleMode,
+ const MetricType metric) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ referenceTree(referenceTree),
+ queryTree(queryTree),
+ ownReferenceTree(false),
+ ownQueryTree(false),
+ naive(false),
+ singleMode(singleMode),
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // Nothing else to initialize.
+}
+
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
+ TreeType* referenceTree,
+ const typename TreeType::Mat& referenceSet,
+ const bool singleMode,
+ const MetricType metric) :
+ referenceSet(referenceSet),
+ querySet(referenceSet),
+ referenceTree(referenceTree),
+ queryTree(NULL),
+ ownReferenceTree(false),
+ ownQueryTree(false),
+ naive(false),
+ singleMode(singleMode),
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // Nothing else to initialize.
+}
+
+/**
+ * The tree is the only member we may be responsible for deleting. The others
+ * will take care of themselves.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearch<SortPolicy, MetricType, TreeType>::~NeighborSearch()
+{
+ if (ownReferenceTree)
+ delete referenceTree;
+ if (ownQueryTree)
+ delete queryTree;
+}
+
+/**
+ * Computes the best neighbors and stores them in resultingNeighbors and
+ * distances.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
+ const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances)
+{
+ Timer::Start("computing_neighbors");
+
+ // If we have built the trees ourselves, then we will have to map all the
+ // indices back to their original indices when this computation is finished.
+ // To avoid an extra copy, we will store the neighbors and distances in a
+ // separate matrix.
+ arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
+ arma::mat* distancePtr = &distances;
+
+ if (ownQueryTree || (ownReferenceTree && !queryTree))
+ distancePtr = new arma::mat; // Query indices need to be mapped.
+ if (ownReferenceTree || ownQueryTree)
+ neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
+
+ // Set the size of the neighbor and distance matrices.
+ neighborPtr->set_size(k, querySet.n_cols);
+ distancePtr->set_size(k, querySet.n_cols);
+ distancePtr->fill(SortPolicy::WorstDistance());
+
+ size_t numPrunes = 0;
+
+ if (singleMode)
+ {
+ // Create the helper object for the tree traversal.
+ typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+
+ // Create the traverser.
+ typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+
+ // Now have it traverse for each point.
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
+
+ numPrunes = traverser.NumPrunes();
+ }
+ else // Dual-tree recursion.
+ {
+ // Create the helper object for the tree traversal.
+ typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ if (queryTree)
+ traverser.Traverse(*queryTree, *referenceTree);
+ else
+ traverser.Traverse(*referenceTree, *referenceTree);
+
+ numPrunes = traverser.NumPrunes();
+ }
+
+ Log::Debug << "Pruned " << numPrunes << " nodes." << std::endl;
+
+ Timer::Stop("computing_neighbors");
+
+ // Now, do we need to do mapping of indices?
+ if (!ownReferenceTree && !ownQueryTree)
+ {
+ // No mapping needed. We are done.
+ return;
+ }
+ else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+ {
+ // Set size of output matrices correctly.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; i++)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; j++)
+ {
+ resultingNeighbors(j, oldFromNewQueries[i]) =
+ oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
+ }
+
+ // Finished with temporary matrices.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+ else if (ownReferenceTree)
+ {
+ if (!queryTree) // No query tree -- map both references and queries.
+ {
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; i++)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; j++)
+ {
+ resultingNeighbors(j, oldFromNewReferences[i]) =
+ oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
+ }
+ }
+ else // Map only references.
+ {
+ // Set size of neighbor indices matrix correctly.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+
+ // Map indices of neighbors.
+ for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
+ {
+ for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
+ {
+ resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
+ }
+ }
+
+ // Finished with temporary matrix.
+ delete neighborPtr;
+ }
+ else if (ownQueryTree)
+ {
+ // Set size of matrices correctly.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; i++)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ resultingNeighbors.col(oldFromNewQueries[i]) = neighborPtr->col(i);
+ }
+
+ // Finished with temporary matrices.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+} // Search
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,179 +0,0 @@
-/**
- * @file neighbor_search_rules.hpp
- * @author Ryan Curtin
- *
- * Defines the pruning rules and base case rules necessary to perform a
- * tree-based search (with an arbitrary tree) for the NeighborSearch class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-class NeighborSearchRules
-{
- public:
- NeighborSearchRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- MetricType& metric);
-
- double BaseCase(const size_t queryIndex, const size_t referenceIndex);
-
- /**
- * Get the score for the recursion order, in general before the base case is
- * computed. This is useful for cover trees or other trees that can cache
- * some statistic that could be used to make a prune of a child before its
- * base case is computed.
- *
- * @param queryNode Query node.
- * @param referenceNode Reference node.
- */
- double Prescore(TreeType& queryNode,
- TreeType& referenceNode,
- TreeType& referenceChildNode,
- const double baseCaseResult) const;
- double PrescoreQ(TreeType& queryNode,
- TreeType& queryChildNode,
- TreeType& referenceNode,
- const double baseCaseResult) const;
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- */
- double Score(const size_t queryIndex, TreeType& referenceNode) const;
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult) const;
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param oldScore Old score produced by Score() (or Rescore()).
- */
- double Rescore(const size_t queryIndex,
- TreeType& referenceNode,
- const double oldScore) const;
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursionm while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- */
- double Score(TreeType& queryNode, TreeType& referenceNode) const;
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult) const;
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- * @param oldScore Old score produced by Socre() (or Rescore()).
- */
- double Rescore(TreeType& queryNode,
- TreeType& referenceNode,
- const double oldScore) const;
-
- private:
- //! The reference set.
- const arma::mat& referenceSet;
-
- //! The query set.
- const arma::mat& querySet;
-
- //! The matrix the resultant neighbor indices should be stored in.
- arma::Mat<size_t>& neighbors;
-
- //! The matrix the resultant neighbor distances should be stored in.
- arma::mat& distances;
-
- //! The instantiated metric.
- MetricType& metric;
-
- /**
- * Recalculate the bound for a given query node.
- */
- double CalculateBound(TreeType& queryNode) const;
-
- /**
- * Insert a point into the neighbors and distances matrices; this is a helper
- * function.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
- void InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance);
-};
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-// Include implementation.
-#include "neighbor_search_rules_impl.hpp"
-
-#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp (from rev 14999, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,170 @@
+/**
+ * @file neighbor_search_rules.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the pruning rules and base case rules necessary to perform a
+ * tree-based search (with an arbitrary tree) for the NeighborSearch class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+
+namespace mlpack {
+namespace neighbor {
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+class NeighborSearchRules
+{
+ public:
+ NeighborSearchRules(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ MetricType& metric);
+
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+ /**
+ * Get the score for the recursion order, in general before the base case is
+ * computed. This is useful for cover trees or other trees that can cache
+ * some statistic that could be used to make a prune of a child before its
+ * base case is computed.
+ *
+ * @param queryNode Query node.
+ * @param referenceNode Reference node.
+ */
+ double Prescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ TreeType& referenceChildNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ */
+ double Score(const size_t queryIndex, TreeType& referenceNode) const;
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ */
+ double Score(TreeType& queryNode, TreeType& referenceNode) const;
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ private:
+ //! The reference set.
+ const arma::mat& referenceSet;
+
+ //! The query set.
+ const arma::mat& querySet;
+
+ //! The matrix the resultant neighbor indices should be stored in.
+ arma::Mat<size_t>& neighbors;
+
+ //! The matrix the resultant neighbor distances should be stored in.
+ arma::mat& distances;
+
+ //! The instantiated metric.
+ MetricType& metric;
+
+ /**
+ * Insert a point into the neighbors and distances matrices; this is a helper
+ * function.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+ void InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance);
+};
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "neighbor_search_rules_impl.hpp"
+
+#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,332 +0,0 @@
-/**
- * @file nearest_neighbor_rules_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of NearestNeighborRules.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "neighbor_search_rules.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
- const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- MetricType& metric) :
- referenceSet(referenceSet),
- querySet(querySet),
- neighbors(neighbors),
- distances(distances),
- metric(metric)
-{ /* Nothing left to do. */ }
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline force_inline // Absolutely MUST be inline so optimizations can happen.
-double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
-{
- // If the datasets are the same, then this search is only using one dataset
- // and we should not return identical points.
- if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
- return 0.0;
-
- double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceIndex));
-
- // If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the position to insert it into.
- arma::vec queryDist = distances.unsafe_col(queryIndex);
- const size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
-
- // SortDistance() returns (size_t() - 1) if we shouldn't add it.
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
-
- return distance;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Prescore(
- TreeType& queryNode,
- TreeType& referenceNode,
- TreeType& referenceChildNode,
- const double baseCaseResult) const
-{
- const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
- &referenceNode, &referenceChildNode, baseCaseResult);
-
- // Update our bound.
- const double bestDistance = CalculateBound(queryNode);
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::PrescoreQ(
- TreeType& queryNode,
- TreeType& queryChildNode,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
- const double distance = SortPolicy::BestNodeToNodeDistance(&referenceNode,
- &queryNode, &queryChildNode, baseCaseResult);
-
- // Update our bound.
- const double bestDistance = CalculateBound(queryNode);
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
- const size_t queryIndex,
- TreeType& referenceNode) const
-{
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
- &referenceNode);
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
- const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
- &referenceNode, baseCaseResult);
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
- const size_t queryIndex,
- TreeType& /* referenceNode */,
- const double oldScore) const
-{
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
-
- // Just check the score again against the distances.
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
- return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
- TreeType& queryNode,
- TreeType& referenceNode) const
-{
- const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
- &referenceNode);
-
- // Update our bound.
- const double bestDistance = CalculateBound(queryNode);
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
- TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
- const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
- &referenceNode, baseCaseResult);
-
- // Update our bound.
- const double bestDistance = CalculateBound(queryNode);
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
- TreeType& queryNode,
- TreeType& /* referenceNode */,
- const double oldScore) const
-{
- if (oldScore == DBL_MAX)
- return oldScore;
-
- // Update our bound.
- const double bestDistance = CalculateBound(queryNode);
-
- return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-// Calculate the bound for a given query node in its current state and update
-// it.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
- CalculateBound(TreeType& queryNode) const
-{
- // We have five possible bounds, and we must take the best of them all. We
- // don't use min/max here, but instead "best/worst", because this is general
- // to the nearest-neighbors/furthest-neighbors cases. For nearest neighbors,
- // min = best, max = worst.
- //
- // (1) worst ( worst_{all points p in queryNode} D_p[k],
- // worst_{all children c in queryNode} B(c) );
- // (2) best_{all points p in queryNode} D_p[k] + worst child distance +
- // worst descendant distance;
- // (3) best_{all children c in queryNode} B(c) +
- // 2 ( worst descendant distance of queryNode -
- // worst descendant distance of c );
- // (4) B_1(parent of queryNode)
- // (5) B_2(parent of queryNode);
- //
- // D_p[k] is the current k'th candidate distance for point p.
- // So we will loop over the points in queryNode and the children in queryNode
- // to calculate all five of these quantities.
-
- double worstPointDistance = SortPolicy::BestDistance();
- double bestPointDistance = SortPolicy::WorstDistance();
-
- // Loop over all points in this node to find the best and worst distance
- // candidates (for (1) and (2)).
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
- if (SortPolicy::IsBetter(distance, bestPointDistance))
- bestPointDistance = distance;
- if (SortPolicy::IsBetter(worstPointDistance, distance))
- worstPointDistance = distance;
- }
-
- // Loop over all the children in this node to find the worst bound (for (1))
- // and the best bound with the correcting factor for descendant distances (for
- // (3)).
- double worstChildBound = SortPolicy::BestDistance();
- double bestAdjustedChildBound = SortPolicy::WorstDistance();
- const double queryMaxDescendantDistance =
- queryNode.FurthestDescendantDistance();
-
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- const double firstBound = queryNode.Child(i).Stat().FirstBound();
- const double secondBound = queryNode.Child(i).Stat().SecondBound();
- const double childMaxDescendantDistance =
- queryNode.Child(i).FurthestDescendantDistance();
-
- if (SortPolicy::IsBetter(worstChildBound, firstBound))
- worstChildBound = firstBound;
-
- // Now calculate adjustment for maximum descendant distances.
- const double adjustedBound = SortPolicy::CombineWorst(secondBound,
- 2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
- if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
- bestAdjustedChildBound = adjustedBound;
- }
-
- // This is bound (1).
- const double firstBound =
- (SortPolicy::IsBetter(worstPointDistance, worstChildBound)) ?
- worstChildBound : worstPointDistance;
-
- // This is bound (2).
- const double secondBound = SortPolicy::CombineWorst(bestPointDistance,
- 2 * queryMaxDescendantDistance);
-
- // Bound (3) is bestAdjustedChildBound.
-
- // Bounds (4) and (5) are the parent bounds.
- const double fourthBound = (queryNode.Parent() != NULL) ?
- queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
- const double fifthBound = (queryNode.Parent() != NULL) ?
- queryNode.Parent()->Stat().SecondBound() : SortPolicy::WorstDistance();
-
- // Now, we will take the best of these. Unfortunately due to the way
- // IsBetter() is defined, this sort of has to be a little ugly.
- // The variable interA represents the first bound (B_1), which is the worst
- // candidate distance of any descendants of this node.
- // The variable interC represents the second bound (B_2), which is a bound on
- // the worst distance of any descendants of this node assembled using the best
- // descendant candidate distance modified using the furthest descendant
- // distance.
- const double interA = (SortPolicy::IsBetter(firstBound, fourthBound)) ?
- firstBound : fourthBound;
- const double interB =
- (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
- bestAdjustedChildBound : secondBound;
- const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
- fifthBound;
-
- // Update the first and second bounds of the node.
- queryNode.Stat().FirstBound() = interA;
- queryNode.Stat().SecondBound() = interC;
-
- // Update the actual bound of the node.
- queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interC)) ? interA :
- interC;
-
- return queryNode.Stat().Bound();
-}
-
-/**
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
-{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (distances.n_rows - 1))
- {
- int len = (distances.n_rows - 1) - pos;
- memmove(distances.colptr(queryIndex) + (pos + 1),
- distances.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(neighbors.colptr(queryIndex) + (pos + 1),
- neighbors.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- distances(pos, queryIndex) = distance;
- neighbors(pos, queryIndex) = neighbor;
-}
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp (from rev 14999, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,316 @@
+/**
+ * @file neighbor_search_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of NeighborSearchRules.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "neighbor_search_rules.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
+ const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ MetricType& metric) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ neighbors(neighbors),
+ distances(distances),
+ metric(metric)
+{ /* Nothing left to do. */ }
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline force_inline // Absolutely MUST be inline so optimizations can happen.
+double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+BaseCase(const size_t queryIndex, const size_t referenceIndex)
+{
+ // If the datasets are the same, then this search is only using one dataset
+ // and we should not return identical points.
+ if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
+ return 0.0;
+
+ double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceIndex));
+
+ // If this distance is better than any of the current candidates, the
+ // SortDistance() function will give us the position to insert it into.
+ arma::vec queryDist = distances.unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+
+ // SortDistance() returns (size_t() - 1) if we shouldn't add it.
+ if (insertPosition != (size_t() - 1))
+ InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+
+ return distance;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Prescore(
+ TreeType& queryNode,
+ TreeType& referenceNode,
+ TreeType& referenceChildNode,
+ const double baseCaseResult) const
+{
+ const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+ &referenceNode, &referenceChildNode, baseCaseResult);
+
+ // Calculate the bound on the fly. This bound will be the minimum of
+ // pointBound (the bounds given by the points in this node) and childBound
+ // (the bounds given by the children of this node).
+ double pointBound = SortPolicy::WorstDistance();
+ double childBound = SortPolicy::WorstDistance();
+ const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+ // Find the bound of the points contained in this node.
+ for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+ {
+ // The bound for this point is the k-th best distance plus the maximum
+ // distance to a child of this node.
+ const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+ maxDescendantDistance;
+ if (SortPolicy::IsBetter(bound, pointBound))
+ pointBound = bound;
+ }
+
+ // Find the bound of the children.
+ for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+ {
+ const double bound = queryNode.Child(i).Stat().Bound();
+ if (SortPolicy::IsBetter(bound, childBound))
+ childBound = bound;
+ }
+
+ // Update our bound.
+ queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ const double bestDistance = queryNode.Stat().Bound();
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode) const
+{
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ &referenceNode);
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult) const
+{
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ &referenceNode, baseCaseResult);
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+ const size_t queryIndex,
+ TreeType& /* referenceNode */,
+ const double oldScore) const
+{
+ // If we are already pruning, still prune.
+ if (oldScore == DBL_MAX)
+ return oldScore;
+
+ // Just check the score again against the distances.
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+ return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+ TreeType& queryNode,
+ TreeType& referenceNode) const
+{
+ const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+ &referenceNode);
+
+ // Calculate the bound on the fly. This bound will be the minimum of
+ // pointBound (the bounds given by the points in this node) and childBound
+ // (the bounds given by the children of this node).
+ double pointBound = SortPolicy::WorstDistance();
+ double childBound = SortPolicy::WorstDistance();
+ const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+ // Find the bound of the points contained in this node.
+ for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+ {
+ // The bound for this point is the k-th best distance plus the maximum
+ // distance to a child of this node.
+ const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+ maxDescendantDistance;
+ if (SortPolicy::IsBetter(bound, pointBound))
+ pointBound = bound;
+ }
+
+ // Find the bound of the children.
+ for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+ {
+ const double bound = queryNode.Child(i).Stat().Bound();
+ if (SortPolicy::IsBetter(bound, childBound))
+ childBound = bound;
+ }
+
+ // Update our bound.
+ queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ const double bestDistance = queryNode.Stat().Bound();
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
+ TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const
+{
+ const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+ &referenceNode, baseCaseResult);
+
+ // Calculate the bound on the fly. This bound will be the minimum of
+ // pointBound (the bounds given by the points in this node) and childBound
+ // (the bounds given by the children of this node).
+ double pointBound = SortPolicy::WorstDistance();
+ double childBound = SortPolicy::WorstDistance();
+ const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+ // Find the bound of the points contained in this node.
+ for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+ {
+ // The bound for this point is the k-th best distance plus the maximum
+ // distance to a child of this node.
+ const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+ maxDescendantDistance;
+ if (SortPolicy::IsBetter(bound, pointBound))
+ pointBound = bound;
+ }
+
+ // Find the bound of the children.
+ for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+ {
+ const double bound = queryNode.Child(i).Stat().Bound();
+ if (SortPolicy::IsBetter(bound, childBound))
+ childBound = bound;
+ }
+
+ // Update our bound.
+ queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ const double bestDistance = queryNode.Stat().Bound();
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
+ TreeType& queryNode,
+ TreeType& /* referenceNode */,
+ const double oldScore) const
+{
+ if (oldScore == DBL_MAX)
+ return oldScore;
+
+ // Calculate the bound on the fly. This bound will be the minimum of
+ // pointBound (the bounds given by the points in this node) and childBound
+ // (the bounds given by the children of this node).
+ double pointBound = SortPolicy::WorstDistance();
+ double childBound = SortPolicy::WorstDistance();
+ const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+ // Find the bound of the points contained in this node.
+ for (size_t i = 0; i < queryNode.NumPoints(); ++i)
+ {
+ // The bound for this point is the k-th best distance plus the maximum
+ // distance to a child of this node.
+ const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
+ maxDescendantDistance;
+ if (SortPolicy::IsBetter(bound, pointBound))
+ pointBound = bound;
+ }
+
+ // Find the bound of the children.
+ for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+ {
+ const double bound = queryNode.Child(i).Stat().Bound();
+ if (SortPolicy::IsBetter(bound, childBound))
+ childBound = bound;
+ }
+
+ // Update our bound.
+ queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ const double bestDistance = queryNode.Stat().Bound();
+
+ return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
+}
+
+/**
+ * Helper function to insert a point into the neighbors and distances matrices.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+ const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance)
+{
+ // We only memmove() if there is actually a need to shift something.
+ if (pos < (distances.n_rows - 1))
+ {
+ int len = (distances.n_rows - 1) - pos;
+ memmove(distances.colptr(queryIndex) + (pos + 1),
+ distances.colptr(queryIndex) + pos,
+ sizeof(double) * len);
+ memmove(neighbors.colptr(queryIndex) + (pos + 1),
+ neighbors.colptr(queryIndex) + pos,
+ sizeof(size_t) * len);
+ }
+
+ // Now put the new information in the right index.
+ distances(pos, queryIndex) = distance;
+ neighbors(pos, queryIndex) = neighbor;
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,41 +0,0 @@
-/***
- * @file nearest_neighbor_sort.cpp
- * @author Ryan Curtin
- *
- * Implementation of the simple FurthestNeighborSort policy class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "furthest_neighbor_sort.hpp"
-
-using namespace mlpack::neighbor;
-
-size_t FurthestNeighborSort::SortDistance(const arma::vec& list,
- double newDistance)
-{
- // The first element in the list is the nearest neighbor. We only want to
- // insert if the new distance is greater than the last element in the list.
- if (newDistance < list[list.n_elem - 1])
- return (size_t() - 1); // Do not insert.
-
- // Search from the beginning. This may not be the best way.
- for (size_t i = 0; i < list.n_elem; i++)
- if (newDistance >= list[i])
- return i;
-
- // Control should never reach here.
- return (size_t() - 1);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,41 @@
+/***
+ * @file furthest_neighbor_sort.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the simple FurthestNeighborSort policy class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "furthest_neighbor_sort.hpp"
+
+using namespace mlpack::neighbor;
+
+size_t FurthestNeighborSort::SortDistance(const arma::vec& list,
+ double newDistance)
+{
+ // The first element in the list is the furthest neighbor. We only want to
+ // insert if the new distance is greater than the last element in the list.
+ if (newDistance < list[list.n_elem - 1])
+ return (size_t() - 1); // Do not insert.
+
+ // Search from the beginning. This may not be the best way.
+ for (size_t i = 0; i < list.n_elem; i++)
+ if (newDistance >= list[i])
+ return i;
+
+ // Control should never reach here.
+ return (size_t() - 1);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,159 +0,0 @@
-/**
- * @file furthest_neighbor_sort.hpp
- * @author Ryan Curtin
- *
- * Implementation of the SortPolicy class for NeighborSearch; in this case, the
- * furthest neighbors are those that are most important.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace neighbor {
-
-/**
- * This class implements the necessary methods for the SortPolicy template
- * parameter of the NeighborSearch class. The sorting policy here is that the
- * minimum distance is the best (so, when used with NeighborSearch, the output
- * is furthest neighbors).
- */
-class FurthestNeighborSort
-{
- public:
- /**
- * Return the index in the vector where the new distance should be inserted,
- * or size_t() - 1 if it should not be inserted (i.e. if it is not any better
- * than any of the existing points in the list). The list should be sorted
- * such that the best point is the first in the list. The actual insertion is
- * not performed.
- *
- * @param list Vector of existing distance points, sorted such that the best
- * point is the first in the list.
- * @param new_distance Distance to try to insert.
- *
- * @return size_t containing the position to insert into, or (size_t() - 1)
- * if the new distance should not be inserted.
- */
- static size_t SortDistance(const arma::vec& list, double newDistance);
-
- /**
- * Return whether or not value is "better" than ref. In this case, that means
- * that the value is greater than the reference.
- *
- * @param value Value to compare
- * @param ref Value to compare with
- *
- * @return bool indicating whether or not (value > ref).
- */
- static inline bool IsBetter(const double value, const double ref)
- {
- return (value > ref);
- }
-
- /**
- * Return the best possible distance between two nodes. In our case, this is
- * the maximum distance between the two tree nodes using the given distance
- * function.
- */
- template<typename TreeType>
- static double BestNodeToNodeDistance(const TreeType* queryNode,
- const TreeType* referenceNode);
-
- /**
- * Return the best possible distance between two nodes, given that the
- * distance between the centers of the two nodes has already been calculated.
- * This is used in conjunction with trees that have self-children (like cover
- * trees).
- */
- template<typename TreeType>
- static double BestNodeToNodeDistance(const TreeType* queryNode,
- const TreeType* referenceNode,
- const double centerToCenterDistance);
-
- /**
- * Return the best possible distance between the query node and the reference
- * child node given the base case distance between the query node and the
- * reference node. TreeType::ParentDistance() must be implemented to use
- * this.
- *
- * @param queryNode Query node.
- * @param referenceNode Reference node.
- * @param referenceChildNode Child of reference node which is being inspected.
- * @param centerToCenterDistance Distance between centers of query node and
- * reference node.
- */
- template<typename TreeType>
- static double BestNodeToNodeDistance(const TreeType* queryNode,
- const TreeType* referenceNode,
- const TreeType* referenceChildNode,
- const double centerToCenterDistance);
-
- /**
- * Return the best possible distance between a node and a point. In our case,
- * this is the maximum distance between the tree node and the point using the
- * given distance function.
- */
- template<typename TreeType>
- static double BestPointToNodeDistance(const arma::vec& queryPoint,
- const TreeType* referenceNode);
-
- /**
- * Return the best possible distance between a point and a node, given that
- * the distance between the point and the center of the node has already been
- * calculated. This is used in conjunction with trees that have
- * self-children (like cover trees).
- */
- template<typename TreeType>
- static double BestPointToNodeDistance(const arma::vec& queryPoint,
- const TreeType* referenceNode,
- const double pointToCenterDistance);
-
- /**
- * Return what should represent the worst possible distance with this
- * particular sort policy. In our case, this should be the minimum possible
- * distance, 0.
- *
- * @return 0
- */
- static inline double WorstDistance() { return 0; }
-
- /**
- * Return what should represent the best possible distance with this
- * particular sort policy. In our case, this should be the maximum possible
- * distance, DBL_MAX.
- *
- * @return DBL_MAX
- */
- static inline double BestDistance() { return DBL_MAX; }
-
- /**
- * Return the worst combination of the two distances.
- */
- static inline double CombineWorst(const double a, const double b)
- { return std::max(a - b, 0.0); }
-};
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-// Include implementation of templated functions.
-#include "furthest_neighbor_sort_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,159 @@
+/**
+ * @file furthest_neighbor_sort.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the SortPolicy class for NeighborSearch; in this case, the
+ * furthest neighbors are those that are most important.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * This class implements the necessary methods for the SortPolicy template
+ * parameter of the NeighborSearch class. The sorting policy here is that the
+ * maximum distance is the best (so, when used with NeighborSearch, the output
+ * is furthest neighbors).
+ */
+class FurthestNeighborSort
+{
+ public:
+ /**
+ * Return the index in the vector where the new distance should be inserted,
+ * or size_t() - 1 if it should not be inserted (i.e. if it is not any better
+ * than any of the existing points in the list). The list should be sorted
+ * such that the best point is the first in the list. The actual insertion is
+ * not performed.
+ *
+ * @param list Vector of existing distance points, sorted such that the best
+ * point is the first in the list.
+ * @param newDistance Distance to try to insert.
+ *
+ * @return size_t containing the position to insert into, or (size_t() - 1)
+ * if the new distance should not be inserted.
+ */
+ static size_t SortDistance(const arma::vec& list, double newDistance);
+
+ /**
+ * Return whether or not value is "better" than ref. In this case, that means
+ * that the value is greater than the reference.
+ *
+ * @param value Value to compare
+ * @param ref Value to compare with
+ *
+ * @return bool indicating whether or not (value > ref).
+ */
+ static inline bool IsBetter(const double value, const double ref)
+ {
+ return (value > ref);
+ }
+
+ /**
+ * Return the best possible distance between two nodes. In our case, this is
+ * the maximum distance between the two tree nodes using the given distance
+ * function.
+ */
+ template<typename TreeType>
+ static double BestNodeToNodeDistance(const TreeType* queryNode,
+ const TreeType* referenceNode);
+
+ /**
+ * Return the best possible distance between two nodes, given that the
+ * distance between the centers of the two nodes has already been calculated.
+ * This is used in conjunction with trees that have self-children (like cover
+ * trees).
+ */
+ template<typename TreeType>
+ static double BestNodeToNodeDistance(const TreeType* queryNode,
+ const TreeType* referenceNode,
+ const double centerToCenterDistance);
+
+ /**
+ * Return the best possible distance between the query node and the reference
+ * child node given the base case distance between the query node and the
+ * reference node. TreeType::ParentDistance() must be implemented to use
+ * this.
+ *
+ * @param queryNode Query node.
+ * @param referenceNode Reference node.
+ * @param referenceChildNode Child of reference node which is being inspected.
+ * @param centerToCenterDistance Distance between centers of query node and
+ * reference node.
+ */
+ template<typename TreeType>
+ static double BestNodeToNodeDistance(const TreeType* queryNode,
+ const TreeType* referenceNode,
+ const TreeType* referenceChildNode,
+ const double centerToCenterDistance);
+
+ /**
+ * Return the best possible distance between a node and a point. In our case,
+ * this is the maximum distance between the tree node and the point using the
+ * given distance function.
+ */
+ template<typename TreeType>
+ static double BestPointToNodeDistance(const arma::vec& queryPoint,
+ const TreeType* referenceNode);
+
+ /**
+ * Return the best possible distance between a point and a node, given that
+ * the distance between the point and the center of the node has already been
+ * calculated. This is used in conjunction with trees that have
+ * self-children (like cover trees).
+ */
+ template<typename TreeType>
+ static double BestPointToNodeDistance(const arma::vec& queryPoint,
+ const TreeType* referenceNode,
+ const double pointToCenterDistance);
+
+ /**
+ * Return what should represent the worst possible distance with this
+ * particular sort policy. In our case, this should be the minimum possible
+ * distance, 0.
+ *
+ * @return 0
+ */
+ static inline double WorstDistance() { return 0; }
+
+ /**
+ * Return what should represent the best possible distance with this
+ * particular sort policy. In our case, this should be the maximum possible
+ * distance, DBL_MAX.
+ *
+ * @return DBL_MAX
+ */
+ static inline double BestDistance() { return DBL_MAX; }
+
+ /**
+ * Return the worst combination of the two distances.
+ */
+ static inline double CombineWorst(const double a, const double b)
+ { return std::max(a - b, 0.0); }
+};
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation of templated functions.
+#include "furthest_neighbor_sort_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,81 +0,0 @@
-/***
- * @file furthest_neighbor_sort_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of templated methods for the FurthestNeighborSort SortPolicy
- * class for the NeighborSearch class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_IMPL_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_IMPL_HPP
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename TreeType>
-inline double FurthestNeighborSort::BestNodeToNodeDistance(
- const TreeType* queryNode,
- const TreeType* referenceNode)
-{
- // This is not implemented yet for the general case because the trees do not
- // accept arbitrary distance metrics.
- return queryNode->MaxDistance(referenceNode);
-}
-
-template<typename TreeType>
-inline double FurthestNeighborSort::BestNodeToNodeDistance(
- const TreeType* queryNode,
- const TreeType* referenceNode,
- const double centerToCenterDistance)
-{
- return queryNode->MaxDistance(referenceNode, centerToCenterDistance);
-}
-
-template<typename TreeType>
-inline double FurthestNeighborSort::BestNodeToNodeDistance(
- const TreeType* queryNode,
- const TreeType* referenceNode,
- const TreeType* referenceChildNode,
- const double centerToCenterDistance)
-{
- return queryNode->MaxDistance(referenceNode, centerToCenterDistance) +
- referenceChildNode->ParentDistance();
-}
-
-template<typename TreeType>
-inline double FurthestNeighborSort::BestPointToNodeDistance(
- const arma::vec& point,
- const TreeType* referenceNode)
-{
- // This is not implemented yet for the general case because the trees do not
- // accept arbitrary distance metrics.
- return referenceNode->MaxDistance(point);
-}
-
-template<typename TreeType>
-inline double FurthestNeighborSort::BestPointToNodeDistance(
- const arma::vec& point,
- const TreeType* referenceNode,
- const double pointToCenterDistance)
-{
- return referenceNode->MaxDistance(point, pointToCenterDistance);
-}
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,81 @@
+/***
+ * @file furthest_neighbor_sort_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of templated methods for the FurthestNeighborSort SortPolicy
+ * class for the NeighborSearch class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_FURTHEST_NEIGHBOR_SORT_IMPL_HPP
+
+namespace mlpack {
+namespace neighbor {
+
+template<typename TreeType>
+inline double FurthestNeighborSort::BestNodeToNodeDistance(
+ const TreeType* queryNode,
+ const TreeType* referenceNode)
+{
+ // This is not implemented yet for the general case because the trees do not
+ // accept arbitrary distance metrics.
+ return queryNode->MaxDistance(referenceNode);
+}
+
+template<typename TreeType>
+inline double FurthestNeighborSort::BestNodeToNodeDistance(
+ const TreeType* queryNode,
+ const TreeType* referenceNode,
+ const double centerToCenterDistance)
+{
+ return queryNode->MaxDistance(referenceNode, centerToCenterDistance);
+}
+
+template<typename TreeType>
+inline double FurthestNeighborSort::BestNodeToNodeDistance(
+ const TreeType* queryNode,
+ const TreeType* referenceNode,
+ const TreeType* referenceChildNode,
+ const double centerToCenterDistance)
+{
+ return queryNode->MaxDistance(referenceNode, centerToCenterDistance) +
+ referenceChildNode->ParentDistance();
+}
+
+template<typename TreeType>
+inline double FurthestNeighborSort::BestPointToNodeDistance(
+ const arma::vec& point,
+ const TreeType* referenceNode)
+{
+ // This is not implemented yet for the general case because the trees do not
+ // accept arbitrary distance metrics.
+ return referenceNode->MaxDistance(point);
+}
+
+template<typename TreeType>
+inline double FurthestNeighborSort::BestPointToNodeDistance(
+ const arma::vec& point,
+ const TreeType* referenceNode,
+ const double pointToCenterDistance)
+{
+ return referenceNode->MaxDistance(point, pointToCenterDistance);
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,41 +0,0 @@
-/**
- * @file nearest_neighbor_sort.cpp
- * @author Ryan Curtin
- *
- * Implementation of the simple NearestNeighborSort policy class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "nearest_neighbor_sort.hpp"
-
-using namespace mlpack::neighbor;
-
-size_t NearestNeighborSort::SortDistance(const arma::vec& list,
- double newDistance)
-{
- // The first element in the list is the nearest neighbor. We only want to
- // insert if the new distance is less than the last element in the list.
- if (newDistance > list[list.n_elem - 1])
- return (size_t() - 1); // Do not insert.
-
- // Search from the beginning. This may not be the best way.
- for (size_t i = 0; i < list.n_elem; i++)
- if (newDistance <= list[i])
- return i;
-
- // Control should never reach here.
- return (size_t() - 1);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,41 @@
+/**
+ * @file nearest_neighbor_sort.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the simple NearestNeighborSort policy class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "nearest_neighbor_sort.hpp"
+
+using namespace mlpack::neighbor;
+
+size_t NearestNeighborSort::SortDistance(const arma::vec& list,
+ double newDistance)
+{
+ // The first element in the list is the nearest neighbor. We only want to
+ // insert if the new distance is less than the last element in the list.
+ if (newDistance > list[list.n_elem - 1])
+ return (size_t() - 1); // Do not insert.
+
+ // Search from the beginning. This may not be the best way.
+ for (size_t i = 0; i < list.n_elem; i++)
+ if (newDistance <= list[i])
+ return i;
+
+ // Control should never reach here.
+ return (size_t() - 1);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,165 +0,0 @@
-/**
- * @file nearest_neighbor_sort.hpp
- * @author Ryan Curtin
- *
- * Implementation of the SortPolicy class for NeighborSearch; in this case, the
- * nearest neighbors are those that are most important.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_SORT_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_SORT_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace neighbor {
-
-/**
- * This class implements the necessary methods for the SortPolicy template
- * parameter of the NeighborSearch class. The sorting policy here is that the
- * minimum distance is the best (so, when used with NeighborSearch, the output
- * is nearest neighbors).
- *
- * This class is also meant to serve as a guide to implement a custom
- * SortPolicy. All of the methods implemented here must be implemented by any
- * other SortPolicy classes.
- */
-class NearestNeighborSort
-{
- public:
- /**
- * Return the index in the vector where the new distance should be inserted,
- * or (size_t() - 1) if it should not be inserted (i.e. if it is not any
- * better than any of the existing points in the list). The list should be
- * sorted such that the best point is the first in the list. The actual
- * insertion is not performed.
- *
- * @param list Vector of existing distance points, sorted such that the best
- * point is first in the list.
- * @param new_distance Distance to try to insert
- *
- * @return size_t containing the position to insert into, or (size_t() - 1)
- * if the new distance should not be inserted.
- */
- static size_t SortDistance(const arma::vec& list, double newDistance);
-
- /**
- * Return whether or not value is "better" than ref. In this case, that means
- * that the value is less than the reference.
- *
- * @param value Value to compare
- * @param ref Value to compare with
- *
- * @return bool indicating whether or not (value < ref).
- */
- static inline bool IsBetter(const double value, const double ref)
- {
- return (value < ref);
- }
-
- /**
- * Return the best possible distance between two nodes. In our case, this is
- * the minimum distance between the two tree nodes using the given distance
- * function.
- */
- template<typename TreeType>
- static double BestNodeToNodeDistance(const TreeType* queryNode,
- const TreeType* referenceNode);
-
- /**
- * Return the best possible distance between two nodes, given that the
- * distance between the centers of the two nodes has already been calculated.
- * This is used in conjunction with trees that have self-children (like cover
- * trees).
- */
- template<typename TreeType>
- static double BestNodeToNodeDistance(const TreeType* queryNode,
- const TreeType* referenceNode,
- const double centerToCenterDistance);
-
- /**
- * Return the best possible distance between the query node and the reference
- * child node given the base case distance between the query node and the
- * reference node. TreeType::ParentDistance() must be implemented to use
- * this.
- *
- * @param queryNode Query node.
- * @param referenceNode Reference node.
- * @param referenceChildNode Child of reference node which is being inspected.
- * @param centerToCenterDistance Distance between centers of query node and
- * reference node.
- */
- template<typename TreeType>
- static double BestNodeToNodeDistance(const TreeType* queryNode,
- const TreeType* referenceNode,
- const TreeType* referenceChildNode,
- const double centerToCenterDistance);
- /**
- * Return the best possible distance between a node and a point. In our case,
- * this is the minimum distance between the tree node and the point using the
- * given distance function.
- */
- template<typename TreeType>
- static double BestPointToNodeDistance(const arma::vec& queryPoint,
- const TreeType* referenceNode);
-
- /**
- * Return the best possible distance between a point and a node, given that
- * the distance between the point and the center of the node has already been
- * calculated. This is used in conjunction with trees that have
- * self-children (like cover trees).
- */
- template<typename TreeType>
- static double BestPointToNodeDistance(const arma::vec& queryPoint,
- const TreeType* referenceNode,
- const double pointToCenterDistance);
-
- /**
- * Return what should represent the worst possible distance with this
- * particular sort policy. In our case, this should be the maximum possible
- * distance, DBL_MAX.
- *
- * @return DBL_MAX
- */
- static inline double WorstDistance() { return DBL_MAX; }
-
- /**
- * Return what should represent the best possible distance with this
- * particular sort policy. In our case, this should be the minimum possible
- * distance, 0.0.
- *
- * @return 0.0
- */
- static inline double BestDistance() { return 0.0; }
-
- /**
- * Return the worst combination of the two distances.
- */
- static inline double CombineWorst(const double a, const double b)
- {
- if (a == DBL_MAX || b == DBL_MAX)
- return DBL_MAX;
- return a + b; }
-};
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-// Include implementation of templated functions.
-#include "nearest_neighbor_sort_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,165 @@
+/**
+ * @file nearest_neighbor_sort.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the SortPolicy class for NeighborSearch; in this case, the
+ * nearest neighbors are those that are most important.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_SORT_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_SORT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * This class implements the necessary methods for the SortPolicy template
+ * parameter of the NeighborSearch class. The sorting policy here is that the
+ * minimum distance is the best (so, when used with NeighborSearch, the output
+ * is nearest neighbors).
+ *
+ * This class is also meant to serve as a guide to implement a custom
+ * SortPolicy. All of the methods implemented here must be implemented by any
+ * other SortPolicy classes.
+ */
+class NearestNeighborSort
+{
+ public:
+ /**
+ * Return the index in the vector where the new distance should be inserted,
+ * or (size_t() - 1) if it should not be inserted (i.e. if it is not any
+ * better than any of the existing points in the list). The list should be
+ * sorted such that the best point is the first in the list. The actual
+ * insertion is not performed.
+ *
+ * @param list Vector of existing distance points, sorted such that the best
+ * point is first in the list.
+ * @param new_distance Distance to try to insert
+ *
+ * @return size_t containing the position to insert into, or (size_t() - 1)
+ * if the new distance should not be inserted.
+ */
+ static size_t SortDistance(const arma::vec& list, double newDistance);
+
+ /**
+ * Return whether or not value is "better" than ref. In this case, that means
+ * that the value is less than the reference.
+ *
+ * @param value Value to compare
+ * @param ref Value to compare with
+ *
+ * @return bool indicating whether or not (value < ref).
+ */
+ static inline bool IsBetter(const double value, const double ref)
+ {
+ return (value < ref);
+ }
+
+ /**
+ * Return the best possible distance between two nodes. In our case, this is
+ * the minimum distance between the two tree nodes using the given distance
+ * function.
+ */
+ template<typename TreeType>
+ static double BestNodeToNodeDistance(const TreeType* queryNode,
+ const TreeType* referenceNode);
+
+ /**
+ * Return the best possible distance between two nodes, given that the
+ * distance between the centers of the two nodes has already been calculated.
+ * This is used in conjunction with trees that have self-children (like cover
+ * trees).
+ */
+ template<typename TreeType>
+ static double BestNodeToNodeDistance(const TreeType* queryNode,
+ const TreeType* referenceNode,
+ const double centerToCenterDistance);
+
+ /**
+ * Return the best possible distance between the query node and the reference
+ * child node given the base case distance between the query node and the
+ * reference node. TreeType::ParentDistance() must be implemented to use
+ * this.
+ *
+ * @param queryNode Query node.
+ * @param referenceNode Reference node.
+ * @param referenceChildNode Child of reference node which is being inspected.
+ * @param centerToCenterDistance Distance between centers of query node and
+ * reference node.
+ */
+ template<typename TreeType>
+ static double BestNodeToNodeDistance(const TreeType* queryNode,
+ const TreeType* referenceNode,
+ const TreeType* referenceChildNode,
+ const double centerToCenterDistance);
+ /**
+ * Return the best possible distance between a node and a point. In our case,
+ * this is the minimum distance between the tree node and the point using the
+ * given distance function.
+ */
+ template<typename TreeType>
+ static double BestPointToNodeDistance(const arma::vec& queryPoint,
+ const TreeType* referenceNode);
+
+ /**
+ * Return the best possible distance between a point and a node, given that
+ * the distance between the point and the center of the node has already been
+ * calculated. This is used in conjunction with trees that have
+ * self-children (like cover trees).
+ */
+ template<typename TreeType>
+ static double BestPointToNodeDistance(const arma::vec& queryPoint,
+ const TreeType* referenceNode,
+ const double pointToCenterDistance);
+
+ /**
+ * Return what should represent the worst possible distance with this
+ * particular sort policy. In our case, this should be the maximum possible
+ * distance, DBL_MAX.
+ *
+ * @return DBL_MAX
+ */
+ static inline double WorstDistance() { return DBL_MAX; }
+
+ /**
+ * Return what should represent the best possible distance with this
+ * particular sort policy. In our case, this should be the minimum possible
+ * distance, 0.0.
+ *
+ * @return 0.0
+ */
+ static inline double BestDistance() { return 0.0; }
+
+ /**
+ * Return the worst combination of the two distances.
+ */
+ static inline double CombineWorst(const double a, const double b)
+ {
+ if (a == DBL_MAX || b == DBL_MAX)
+ return DBL_MAX;
+ return a + b; }
+};
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation of templated functions.
+#include "nearest_neighbor_sort_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,81 +0,0 @@
-/**
- * @file nearest_neighbor_sort_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of templated methods for the NearestNeighborSort SortPolicy
- * class for the NeighborSearch class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_NEIGHBOR_NEAREST_NEIGHBOR_SORT_IMPL_HPP
-#define __MLPACK_NEIGHBOR_NEAREST_NEIGHBOR_SORT_IMPL_HPP
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename TreeType>
-inline double NearestNeighborSort::BestNodeToNodeDistance(
- const TreeType* queryNode,
- const TreeType* referenceNode)
-{
- // This is not implemented yet for the general case because the trees do not
- // accept arbitrary distance metrics.
- return queryNode->MinDistance(referenceNode);
-}
-
-template<typename TreeType>
-inline double NearestNeighborSort::BestNodeToNodeDistance(
- const TreeType* queryNode,
- const TreeType* referenceNode,
- const double centerToCenterDistance)
-{
- return queryNode->MinDistance(referenceNode, centerToCenterDistance);
-}
-
-template<typename TreeType>
-inline double NearestNeighborSort::BestNodeToNodeDistance(
- const TreeType* queryNode,
- const TreeType* /* referenceNode */,
- const TreeType* referenceChildNode,
- const double centerToCenterDistance)
-{
- return queryNode->MinDistance(referenceChildNode, centerToCenterDistance) -
- referenceChildNode->ParentDistance();
-}
-
-template<typename TreeType>
-inline double NearestNeighborSort::BestPointToNodeDistance(
- const arma::vec& point,
- const TreeType* referenceNode)
-{
- // This is not implemented yet for the general case because the trees do not
- // accept arbitrary distance metrics.
- return referenceNode->MinDistance(point);
-}
-
-template<typename TreeType>
-inline double NearestNeighborSort::BestPointToNodeDistance(
- const arma::vec& point,
- const TreeType* referenceNode,
- const double pointToCenterDistance)
-{
- return referenceNode->MinDistance(point, pointToCenterDistance);
-}
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,81 @@
+/**
+ * @file nearest_neighbor_sort_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of templated methods for the NearestNeighborSort SortPolicy
+ * class for the NeighborSearch class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_NEIGHBOR_NEAREST_NEIGHBOR_SORT_IMPL_HPP
+#define __MLPACK_NEIGHBOR_NEAREST_NEIGHBOR_SORT_IMPL_HPP
+
+namespace mlpack {
+namespace neighbor {
+
+template<typename TreeType>
+inline double NearestNeighborSort::BestNodeToNodeDistance(
+ const TreeType* queryNode,
+ const TreeType* referenceNode)
+{
+ // This is not implemented yet for the general case because the trees do not
+ // accept arbitrary distance metrics.
+ return queryNode->MinDistance(referenceNode);
+}
+
+template<typename TreeType>
+inline double NearestNeighborSort::BestNodeToNodeDistance(
+ const TreeType* queryNode,
+ const TreeType* referenceNode,
+ const double centerToCenterDistance)
+{
+ return queryNode->MinDistance(referenceNode, centerToCenterDistance);
+}
+
+template<typename TreeType>
+inline double NearestNeighborSort::BestNodeToNodeDistance(
+ const TreeType* queryNode,
+ const TreeType* /* referenceNode */,
+ const TreeType* referenceChildNode,
+ const double centerToCenterDistance)
+{
+ return queryNode->MinDistance(referenceChildNode, centerToCenterDistance) -
+ referenceChildNode->ParentDistance();
+}
+
+template<typename TreeType>
+inline double NearestNeighborSort::BestPointToNodeDistance(
+ const arma::vec& point,
+ const TreeType* referenceNode)
+{
+ // This is not implemented yet for the general case because the trees do not
+ // accept arbitrary distance metrics.
+ return referenceNode->MinDistance(point);
+}
+
+template<typename TreeType>
+inline double NearestNeighborSort::BestPointToNodeDistance(
+ const arma::vec& point,
+ const TreeType* referenceNode,
+ const double pointToCenterDistance)
+{
+ return referenceNode->MinDistance(point, pointToCenterDistance);
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/typedef.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/typedef.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/typedef.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,53 +0,0 @@
-/**
- * @file typedef.hpp
- * @author Ryan Curtin
- *
- * Simple typedefs describing template instantiations of the NeighborSearch
- * class which are commonly used. This is meant to be included by
- * neighbor_search.h but is a separate file for simplicity.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_NEIGHBOR_SEARCH_TYPEDEF_H
-#define __MLPACK_NEIGHBOR_SEARCH_TYPEDEF_H
-
-// In case someone included this directly.
-#include "neighbor_search.hpp"
-
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include "sort_policies/nearest_neighbor_sort.hpp"
-#include "sort_policies/furthest_neighbor_sort.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-/**
- * The AllkNN class is the all-k-nearest-neighbors method. It returns L2
- * distances (Euclidean distances) for each of the k nearest neighbors.
- */
-typedef NeighborSearch<NearestNeighborSort, metric::EuclideanDistance> AllkNN;
-
-/**
- * The AllkFN class is the all-k-furthest-neighbors method. It returns L2
- * distances (Euclidean distances) for each of the k furthest neighbors.
- */
-typedef NeighborSearch<FurthestNeighborSort, metric::EuclideanDistance> AllkFN;
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/typedef.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/typedef.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/typedef.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/typedef.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,53 @@
+/**
+ * @file typedef.hpp
+ * @author Ryan Curtin
+ *
+ * Simple typedefs describing template instantiations of the NeighborSearch
+ * class which are commonly used. This is meant to be included by
+ * neighbor_search.h but is a separate file for simplicity.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_NEIGHBOR_SEARCH_TYPEDEF_H
+#define __MLPACK_NEIGHBOR_SEARCH_TYPEDEF_H
+
+// In case someone included this directly.
+#include "neighbor_search.hpp"
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include "sort_policies/nearest_neighbor_sort.hpp"
+#include "sort_policies/furthest_neighbor_sort.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * The AllkNN class is the all-k-nearest-neighbors method. It returns L2
+ * distances (Euclidean distances) for each of the k nearest neighbors.
+ */
+typedef NeighborSearch<NearestNeighborSort, metric::EuclideanDistance> AllkNN;
+
+/**
+ * The AllkFN class is the all-k-furthest-neighbors method. It returns L2
+ * distances (Euclidean distances) for each of the k furthest neighbors.
+ */
+typedef NeighborSearch<FurthestNeighborSort, metric::EuclideanDistance> AllkFN;
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/unmap.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,79 +0,0 @@
-/**
- * @file unmap.cpp
- * @author Ryan Curtin
- *
- * Auxiliary function to unmap neighbor search results.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "unmap.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-// Useful in the dual-tree setting.
-void Unmap(const arma::Mat<size_t>& neighbors,
- const arma::mat& distances,
- const std::vector<size_t>& referenceMap,
- const std::vector<size_t>& queryMap,
- arma::Mat<size_t>& neighborsOut,
- arma::mat& distancesOut,
- const bool squareRoot)
-{
- // Set matrices to correct size.
- neighborsOut.set_size(neighbors.n_rows, neighbors.n_cols);
- distancesOut.set_size(distances.n_rows, distances.n_cols);
-
- // Unmap distances.
- for (size_t i = 0; i < distances.n_cols; ++i)
- {
- // Map columns to the correct place. The ternary operator does not work
- // here...
- if (squareRoot)
- distancesOut.col(queryMap[i]) = sqrt(distances.col(i));
- else
- distancesOut.col(queryMap[i]) = distances.col(i);
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; ++j)
- neighborsOut(j, queryMap[i]) = referenceMap[neighbors(j, i)];
- }
-}
-
-// Useful in the single-tree setting.
-void Unmap(const arma::Mat<size_t>& neighbors,
- const arma::mat& distances,
- const std::vector<size_t>& referenceMap,
- arma::Mat<size_t>& neighborsOut,
- arma::mat& distancesOut,
- const bool squareRoot)
-{
- // Set matrices to correct size.
- neighborsOut.set_size(neighbors.n_rows, neighbors.n_cols);
-
- // Take square root of distances, if necessary.
- if (squareRoot)
- distancesOut = sqrt(distances);
- else
- distancesOut = distances;
-
- // Map neighbors back to original locations.
- for (size_t j = 0; j < neighbors.n_elem; ++j)
- neighborsOut[j] = referenceMap[neighbors[j]];
-}
-
-}; // namespace neighbor
-}; // namespace mlpack
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/unmap.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,79 @@
+/**
+ * @file unmap.cpp
+ * @author Ryan Curtin
+ *
+ * Auxiliary function to unmap neighbor search results.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "unmap.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+// Useful in the dual-tree setting.
+void Unmap(const arma::Mat<size_t>& neighbors,
+ const arma::mat& distances,
+ const std::vector<size_t>& referenceMap,
+ const std::vector<size_t>& queryMap,
+ arma::Mat<size_t>& neighborsOut,
+ arma::mat& distancesOut,
+ const bool squareRoot)
+{
+ // Set matrices to correct size.
+ neighborsOut.set_size(neighbors.n_rows, neighbors.n_cols);
+ distancesOut.set_size(distances.n_rows, distances.n_cols);
+
+ // Unmap distances.
+ for (size_t i = 0; i < distances.n_cols; ++i)
+ {
+ // Map columns to the correct place. The ternary operator does not work
+ // here...
+ if (squareRoot)
+ distancesOut.col(queryMap[i]) = sqrt(distances.col(i));
+ else
+ distancesOut.col(queryMap[i]) = distances.col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; ++j)
+ neighborsOut(j, queryMap[i]) = referenceMap[neighbors(j, i)];
+ }
+}
+
+// Useful in the single-tree setting.
+void Unmap(const arma::Mat<size_t>& neighbors,
+ const arma::mat& distances,
+ const std::vector<size_t>& referenceMap,
+ arma::Mat<size_t>& neighborsOut,
+ arma::mat& distancesOut,
+ const bool squareRoot)
+{
+ // Set matrices to correct size.
+ neighborsOut.set_size(neighbors.n_rows, neighbors.n_cols);
+
+ // Take square root of distances, if necessary.
+ if (squareRoot)
+ distancesOut = sqrt(distances);
+ else
+ distancesOut = distances;
+
+ // Map neighbors back to original locations.
+ for (size_t j = 0; j < neighbors.n_elem; ++j)
+ neighborsOut[j] = referenceMap[neighbors[j]];
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/unmap.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,76 +0,0 @@
-/**
- * @file unmap.hpp
- * @author Ryan Curtin
- *
- * Convenience methods to unmap results.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_UNMAP_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_UNMAP_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace neighbor {
-
-/**
- * Assuming that the datasets have been mapped using the referenceMap and the
- * queryMap (such as during kd-tree construction), unmap the columns of the
- * distances and neighbors matrices into neighborsOut and distancesOut, and also
- * unmap the entries in each row of neighbors. This is useful for the dual-tree
- * case.
- *
- * @param neighbors Matrix of neighbors resulting from neighbor search.
- * @param distances Matrix of distances resulting from neighbor search.
- * @param referenceMap Mapping of reference set to old points.
- * @param queryMap Mapping of query set to old points.
- * @param neighborsOut Matrix to store unmapped neighbors into.
- * @param distancesOut Matrix to store unmapped distances into.
- * @param squareRoot If true, take the square root of the distances.
- */
-void Unmap(const arma::Mat<size_t>& neighbors,
- const arma::mat& distances,
- const std::vector<size_t>& referenceMap,
- const std::vector<size_t>& queryMap,
- arma::Mat<size_t>& neighborsOut,
- arma::mat& distancesOut,
- const bool squareRoot = false);
-
-/**
- * Assuming that the datasets have been mapped using referenceMap (such as
- * during kd-tree construction), unmap the columns of the distances and
- * neighbors matrices into neighborsOut and distancesOut, and also unmap the
- * entries in each row of neighbors. This is useful for the single-tree case.
- *
- * @param neighbors Matrix of neighbors resulting from neighbor search.
- * @param distances Matrix of distances resulting from neighbor search.
- * @param referenceMap Mapping of reference set to old points.
- * @param neighborsOut Matrix to store unmapped neighbors into.
- * @param distancesOut Matrix to store unmapped distances into.
- * @param squareRoot If true, take the square root of the distances.
- */
-void Unmap(const arma::Mat<size_t>& neighbors,
- const arma::mat& distances,
- const std::vector<size_t>& referenceMap,
- arma::Mat<size_t>& neighborsOut,
- arma::mat& distancesOut,
- const bool squareRoot = false);
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/neighbor_search/unmap.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/neighbor_search/unmap.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,76 @@
+/**
+ * @file unmap.hpp
+ * @author Ryan Curtin
+ *
+ * Convenience methods to unmap results.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_UNMAP_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_UNMAP_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Assuming that the datasets have been mapped using the referenceMap and the
+ * queryMap (such as during kd-tree construction), unmap the columns of the
+ * distances and neighbors matrices into neighborsOut and distancesOut, and also
+ * unmap the entries in each row of neighbors. This is useful for the dual-tree
+ * case.
+ *
+ * @param neighbors Matrix of neighbors resulting from neighbor search.
+ * @param distances Matrix of distances resulting from neighbor search.
+ * @param referenceMap Mapping of reference set to old points.
+ * @param queryMap Mapping of query set to old points.
+ * @param neighborsOut Matrix to store unmapped neighbors into.
+ * @param distancesOut Matrix to store unmapped distances into.
+ * @param squareRoot If true, take the square root of the distances.
+ */
+void Unmap(const arma::Mat<size_t>& neighbors,
+ const arma::mat& distances,
+ const std::vector<size_t>& referenceMap,
+ const std::vector<size_t>& queryMap,
+ arma::Mat<size_t>& neighborsOut,
+ arma::mat& distancesOut,
+ const bool squareRoot = false);
+
+/**
+ * Assuming that the datasets have been mapped using referenceMap (such as
+ * during kd-tree construction), unmap the columns of the distances and
+ * neighbors matrices into neighborsOut and distancesOut, and also unmap the
+ * entries in each row of neighbors. This is useful for the single-tree case.
+ *
+ * @param neighbors Matrix of neighbors resulting from neighbor search.
+ * @param distances Matrix of distances resulting from neighbor search.
+ * @param referenceMap Mapping of reference set to old points.
+ * @param neighborsOut Matrix to store unmapped neighbors into.
+ * @param distancesOut Matrix to store unmapped distances into.
+ * @param squareRoot If true, take the square root of the distances.
+ */
+void Unmap(const arma::Mat<size_t>& neighbors,
+ const arma::mat& distances,
+ const std::vector<size_t>& referenceMap,
+ arma::Mat<size_t>& neighborsOut,
+ arma::mat& distancesOut,
+ const bool squareRoot = false);
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/als_update_rules.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/als_update_rules.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/als_update_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,115 +0,0 @@
-/**
- * @file als_update_rules.hpp
- * @author Mohan Rajendran
- *
- * Update rules for the Non-negative Matrix Factorization. This follows a method
- * titled 'Alternating Least Squares' described in the paper 'Positive Matrix
- * Factorization: A Non-negative Factor Model with Optimal Utilization of
- * Error Estimates of Data Values' by P. Paatero and U. Tapper. It uses least
- * squares projection formula to reduce the error value of
- * \f$ \sqrt{\sum_i \sum_j(V-WH)^2} \f$ by alternately calculating W and H
- * respectively while holding the other matrix constant.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
-#define __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace nmf {
-
-/**
- * The update rule for the basis matrix W. The formula used is
- * \f[
- * W^T = \frac{HV^T}{HH^T}
- * \f]
- */
-class WAlternatingLeastSquaresRule
-{
- public:
- // Empty constructor required for the WUpdateRule template.
- WAlternatingLeastSquaresRule() { }
-
- /**
- * The update function that actually updates the W matrix. The function takes
- * in all the matrices and only changes the value of the W matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix to be updated.
- * @param H Encoding matrix.
- */
- inline static void Update(const arma::mat& V,
- arma::mat& W,
- const arma::mat& H)
- {
- // The call to inv() sometimes fails; so we are using the psuedoinverse.
- // W = (inv(H * H.t()) * H * V.t()).t();
- W = V * H.t() * pinv(H * H.t());
-
- // Set all negative numbers to machine epsilon
- for (size_t i = 0; i < W.n_elem; i++)
- {
- if (W(i) < 0.0)
- {
- W(i) = 0.0;
- }
- }
- }
-};
-
-/**
- * The update rule for the encoding matrix H. The formula used is
- * \f[
- * H = \frac{W^TV}{W^TW}
- * \f]
- */
-class HAlternatingLeastSquaresRule
-{
- public:
- // Empty constructor required for the HUpdateRule template.
- HAlternatingLeastSquaresRule() { }
-
- /**
- * The update function that actually updates the H matrix. The function takes
- * in all the matrices and only changes the value of the H matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix.
- * @param H Encoding matrix to be updated.
- */
- inline static void Update(const arma::mat& V,
- const arma::mat& W,
- arma::mat& H)
- {
- H = pinv(W.t() * W) * W.t() * V;
-
- // Set all negative numbers to 0.
- for (size_t i = 0; i < H.n_elem; i++)
- {
- if (H(i) < 0.0)
- {
- H(i) = 0.0;
- }
- }
- }
-};
-
-}; // namespace nmf
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/als_update_rules.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/als_update_rules.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/als_update_rules.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/als_update_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,115 @@
+/**
+ * @file als_update_rules.hpp
+ * @author Mohan Rajendran
+ *
+ * Update rules for the Non-negative Matrix Factorization. This follows a method
+ * titled 'Alternating Least Squares' described in the paper 'Positive Matrix
+ * Factorization: A Non-negative Factor Model with Optimal Utilization of
+ * Error Estimates of Data Values' by P. Paatero and U. Tapper. It uses least
+ * squares projection formula to reduce the error value of
+ * \f$ \sqrt{\sum_i \sum_j(V-WH)^2} \f$ by alternately calculating W and H
+ * respectively while holding the other matrix constant.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
+#define __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nmf {
+
+/**
+ * The update rule for the basis matrix W. The formula used is
+ * \f[
+ * W^T = \frac{HV^T}{HH^T}
+ * \f]
+ */
+class WAlternatingLeastSquaresRule
+{
+ public:
+ // Empty constructor required for the WUpdateRule template.
+ WAlternatingLeastSquaresRule() { }
+
+ /**
+ * The update function that actually updates the W matrix. The function takes
+ * in all the matrices and only changes the value of the W matrix.
+ *
+ * @param V Input matrix to be factorized.
+ * @param W Basis matrix to be updated.
+ * @param H Encoding matrix.
+ */
+ inline static void Update(const arma::mat& V,
+ arma::mat& W,
+ const arma::mat& H)
+ {
+ // The call to inv() sometimes fails; so we are using the pseudoinverse.
+ // W = (inv(H * H.t()) * H * V.t()).t();
+ W = V * H.t() * pinv(H * H.t());
+
+ // Set all negative numbers to 0.
+ for (size_t i = 0; i < W.n_elem; i++)
+ {
+ if (W(i) < 0.0)
+ {
+ W(i) = 0.0;
+ }
+ }
+ }
+};
+
+/**
+ * The update rule for the encoding matrix H. The formula used is
+ * \f[
+ * H = \frac{W^TV}{W^TW}
+ * \f]
+ */
+class HAlternatingLeastSquaresRule
+{
+ public:
+ // Empty constructor required for the HUpdateRule template.
+ HAlternatingLeastSquaresRule() { }
+
+ /**
+ * The update function that actually updates the H matrix. The function takes
+ * in all the matrices and only changes the value of the H matrix.
+ *
+ * @param V Input matrix to be factorized.
+ * @param W Basis matrix.
+ * @param H Encoding matrix to be updated.
+ */
+ inline static void Update(const arma::mat& V,
+ const arma::mat& W,
+ arma::mat& H)
+ {
+ H = pinv(W.t() * W) * W.t() * V;
+
+ // Set all negative numbers to 0.
+ for (size_t i = 0; i < H.n_elem; i++)
+ {
+ if (H(i) < 0.0)
+ {
+ H(i) = 0.0;
+ }
+ }
+ }
+};
+
+}; // namespace nmf
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_dist_update_rules.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/mult_dist_update_rules.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_dist_update_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,96 +0,0 @@
-/**
- * @file mult_dist_update_rules.hpp
- * @author Mohan Rajendran
- *
- * Update rules for the Non-negative Matrix Factorization. This follows a method
- * described in the paper 'Algorithms for Non-negative Matrix Factorization'
- * by D. D. Lee and H. S. Seung. This is a multiplicative rule that ensures
- * that the Frobenius norm \f$ \sqrt{\sum_i \sum_j(V-WH)^2} \f$ is
- * non-increasing between subsequent iterations. Both of the update rules
- * for W and H are defined in this file.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NMF_MULT_DIST_UPDATE_RULES_HPP
-#define __MLPACK_METHODS_NMF_MULT_DIST_UPDATE_RULES_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace nmf {
-
-/**
- * The update rule for the basis matrix W. The formula used is
- * \f[
- * W_{ia} \leftarrow W_{ia} \frac{(VH^T)_{ia}}{(WHH^T)_{ia}}
- * \f]
- */
-class WMultiplicativeDistanceRule
-{
- public:
- // Empty constructor required for the WUpdateRule template.
- WMultiplicativeDistanceRule() { }
-
- /**
- * The update function that actually updates the W matrix. The function takes
- * in all the matrices and only changes the value of the W matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix to be updated.
- * @param H Encoding matrix.
- */
-
- inline static void Update(const arma::mat& V,
- arma::mat& W,
- const arma::mat& H)
- {
- W = (W % (V * H.t())) / (W * H * H.t());
- }
-};
-
-/**
- * The update rule for the encoding matrix H. The formula used is
- * \f[
- * H_{a\mu} \leftarrow H_{a\mu} \frac{(W^T V)_{a\mu}}{(W^T WH)_{a\mu}}
- * \f]
- */
-class HMultiplicativeDistanceRule
-{
- public:
- // Empty constructor required for the HUpdateRule template.
- HMultiplicativeDistanceRule() { }
-
- /**
- * The update function that actually updates the H matrix. The function takes
- * in all the matrices and only changes the value of the H matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix.
- * @param H Encoding matrix to be updated.
- */
-
- inline static void Update(const arma::mat& V,
- const arma::mat& W,
- arma::mat& H)
- {
- H = (H % (W.t() * V)) / (W.t() * W * H);
- }
-};
-
-}; // namespace nmf
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_dist_update_rules.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/mult_dist_update_rules.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_dist_update_rules.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_dist_update_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,96 @@
+/**
+ * @file mult_dist_update_rules.hpp
+ * @author Mohan Rajendran
+ *
+ * Update rules for the Non-negative Matrix Factorization. This follows a method
+ * described in the paper 'Algorithms for Non-negative Matrix Factorization'
+ * by D. D. Lee and H. S. Seung. This is a multiplicative rule that ensures
+ * that the Frobenius norm \f$ \sqrt{\sum_i \sum_j(V-WH)^2} \f$ is
+ * non-increasing between subsequent iterations. Both of the update rules
+ * for W and H are defined in this file.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NMF_MULT_DIST_UPDATE_RULES_HPP
+#define __MLPACK_METHODS_NMF_MULT_DIST_UPDATE_RULES_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nmf {
+
+/**
+ * The update rule for the basis matrix W. The formula used is
+ * \f[
+ * W_{ia} \leftarrow W_{ia} \frac{(VH^T)_{ia}}{(WHH^T)_{ia}}
+ * \f]
+ */
+class WMultiplicativeDistanceRule
+{
+ public:
+ // Empty constructor required for the WUpdateRule template.
+ WMultiplicativeDistanceRule() { }
+
+ /**
+ * The update function that actually updates the W matrix. The function takes
+ * in all the matrices and only changes the value of the W matrix.
+ *
+ * @param V Input matrix to be factorized.
+ * @param W Basis matrix to be updated.
+ * @param H Encoding matrix.
+ */
+
+ inline static void Update(const arma::mat& V,
+ arma::mat& W,
+ const arma::mat& H)
+ {
+ W = (W % (V * H.t())) / (W * H * H.t());
+ }
+};
+
+/**
+ * The update rule for the encoding matrix H. The formula used is
+ * \f[
+ * H_{a\mu} \leftarrow H_{a\mu} \frac{(W^T V)_{a\mu}}{(W^T WH)_{a\mu}}
+ * \f]
+ */
+class HMultiplicativeDistanceRule
+{
+ public:
+ // Empty constructor required for the HUpdateRule template.
+ HMultiplicativeDistanceRule() { }
+
+ /**
+ * The update function that actually updates the H matrix. The function takes
+ * in all the matrices and only changes the value of the H matrix.
+ *
+ * @param V Input matrix to be factorized.
+ * @param W Basis matrix.
+ * @param H Encoding matrix to be updated.
+ */
+
+ inline static void Update(const arma::mat& V,
+ const arma::mat& W,
+ arma::mat& H)
+ {
+ H = (H % (W.t() * V)) / (W.t() * W * H);
+ }
+};
+
+}; // namespace nmf
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_div_update_rules.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/mult_div_update_rules.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_div_update_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,121 +0,0 @@
-/**
- * @file mult_div_update_rules.hpp
- * @author Mohan Rajendran
- *
- * Update rules for the Non-negative Matrix Factorization. This follows a method
- * described in the paper 'Algorithms for Non-negative Matrix Factorization'
- * by D. D. Lee and H. S. Seung. This is a multiplicative rule that ensures
- * that the Kullback–Leibler divergence
- * \f$ \sum_i \sum_j (V_{ij} log\frac{V_{ij}}{(WH)_{ij}}-V_{ij}+(WH)_{ij}) \f$is
- * non-increasing between subsequent iterations. Both of the update rules
- * for W and H are defined in this file.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
-#define __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace nmf {
-
-/**
- * The update rule for the basis matrix W. The formula used is
- * \f[
- * W_{ia} \leftarrow W_{ia} \frac{\sum_{\mu} H_{a\mu} V_{i\mu}/(WH)_{i\mu}}
- * {\sum_{\nu} H_{a\nu}}
- * \f]
- */
-class WMultiplicativeDivergenceRule
-{
- public:
- // Empty constructor required for the WUpdateRule template.
- WMultiplicativeDivergenceRule() { }
-
- /**
- * The update function that actually updates the W matrix. The function takes
- * in all the matrices and only changes the value of the W matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix to be updated.
- * @param H Encoding matrix.
- */
- inline static void Update(const arma::mat& V,
- arma::mat& W,
- const arma::mat& H)
- {
- // Simple implementation left in the header file.
- arma::mat t1;
- arma::rowvec t2;
-
- t1 = W * H;
- for (size_t i = 0; i < W.n_rows; ++i)
- {
- for (size_t j = 0; j < W.n_cols; ++j)
- {
- t2 = H.row(j) % V.row(i) / t1.row(i);
- W(i, j) = W(i, j) * sum(t2) / sum(H.row(j));
- }
- }
- }
-};
-
-/**
- * The update rule for the encoding matrix H. The formula used is
- * \f[
- * H_{a\mu} \leftarrow H_{a\mu} \frac{\sum_{i} W_{ia} V_{i\mu}/(WH)_{i\mu}}
- * {\sum_{k} H_{ka}}
- * \f]
- */
-class HMultiplicativeDivergenceRule
-{
- public:
- // Empty constructor required for the HUpdateRule template.
- HMultiplicativeDivergenceRule() { }
-
- /**
- * The update function that actually updates the H matrix. The function takes
- * in all the matrices and only changes the value of the H matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix.
- * @param H Encoding matrix to updated.
- */
- inline static void Update(const arma::mat& V,
- const arma::mat& W,
- arma::mat& H)
- {
- // Simple implementation left in the header file.
- arma::mat t1;
- arma::colvec t2;
-
- t1 = W * H;
- for (size_t i = 0; i < H.n_rows; i++)
- {
- for (size_t j = 0; j < H.n_cols; j++)
- {
- t2 = W.col(i) % V.col(j) / t1.col(j);
- H(i,j) = H(i,j) * sum(t2) / sum(W.col(i));
- }
- }
- }
-};
-
-}; // namespace nmf
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_div_update_rules.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/mult_div_update_rules.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_div_update_rules.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/mult_div_update_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,121 @@
+/**
+ * @file mult_div_update_rules.hpp
+ * @author Mohan Rajendran
+ *
+ * Update rules for the Non-negative Matrix Factorization. This follows a method
+ * described in the paper 'Algorithms for Non-negative Matrix Factorization'
+ * by D. D. Lee and H. S. Seung. This is a multiplicative rule that ensures
+ * that the Kullback–Leibler divergence
+ * \f$ \sum_i \sum_j (V_{ij} \log\frac{V_{ij}}{(WH)_{ij}}-V_{ij}+(WH)_{ij}) \f$ is
+ * non-increasing between subsequent iterations. Both of the update rules
+ * for W and H are defined in this file.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
+#define __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nmf {
+
+/**
+ * The update rule for the basis matrix W. The formula used is
+ * \f[
+ * W_{ia} \leftarrow W_{ia} \frac{\sum_{\mu} H_{a\mu} V_{i\mu}/(WH)_{i\mu}}
+ * {\sum_{\nu} H_{a\nu}}
+ * \f]
+ */
+class WMultiplicativeDivergenceRule
+{
+ public:
+ // Empty constructor required for the WUpdateRule template.
+ WMultiplicativeDivergenceRule() { }
+
+ /**
+ * The update function that actually updates the W matrix. The function takes
+ * in all the matrices and only changes the value of the W matrix.
+ *
+ * @param V Input matrix to be factorized.
+ * @param W Basis matrix to be updated.
+ * @param H Encoding matrix.
+ */
+ inline static void Update(const arma::mat& V,
+ arma::mat& W,
+ const arma::mat& H)
+ {
+ // Simple implementation left in the header file.
+ arma::mat t1;
+ arma::rowvec t2;
+
+ t1 = W * H;
+ for (size_t i = 0; i < W.n_rows; ++i)
+ {
+ for (size_t j = 0; j < W.n_cols; ++j)
+ {
+ t2 = H.row(j) % V.row(i) / t1.row(i);
+ W(i, j) = W(i, j) * sum(t2) / sum(H.row(j));
+ }
+ }
+ }
+};
+
+/**
+ * The update rule for the encoding matrix H. The formula used is
+ * \f[
+ * H_{a\mu} \leftarrow H_{a\mu} \frac{\sum_{i} W_{ia} V_{i\mu}/(WH)_{i\mu}}
+ * {\sum_{k} W_{ka}}
+ * \f]
+ */
+class HMultiplicativeDivergenceRule
+{
+ public:
+ // Empty constructor required for the HUpdateRule template.
+ HMultiplicativeDivergenceRule() { }
+
+ /**
+ * The update function that actually updates the H matrix. The function takes
+ * in all the matrices and only changes the value of the H matrix.
+ *
+ * @param V Input matrix to be factorized.
+ * @param W Basis matrix.
+ * @param H Encoding matrix to be updated.
+ */
+ inline static void Update(const arma::mat& V,
+ const arma::mat& W,
+ arma::mat& H)
+ {
+ // Simple implementation left in the header file.
+ arma::mat t1;
+ arma::colvec t2;
+
+ t1 = W * H;
+ for (size_t i = 0; i < H.n_rows; i++)
+ {
+ for (size_t j = 0; j < H.n_cols; j++)
+ {
+ t2 = W.col(i) % V.col(j) / t1.col(j);
+ H(i,j) = H(i,j) * sum(t2) / sum(W.col(i));
+ }
+ }
+ }
+};
+
+}; // namespace nmf
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/nmf.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,167 +0,0 @@
-/**
- * @file nmf.hpp
- * @author Mohan Rajendran
- *
- * Defines the NMF class to perform Non-negative Matrix Factorization
- * on the given matrix.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NMF_NMF_HPP
-#define __MLPACK_METHODS_NMF_NMF_HPP
-
-#include <mlpack/core.hpp>
-#include "mult_dist_update_rules.hpp"
-#include "random_init.hpp"
-
-namespace mlpack {
-namespace nmf {
-
-/**
- * This class implements the NMF on the given matrix V. Non-negative Matrix
- * Factorization decomposes V in the form \f$ V \approx WH \f$ where W is
- * called the basis matrix and H is called the encoding matrix. V is taken
- * to be of size n x m and the obtained W is n x r and H is r x m. The size r is
- * called the rank of the factorization.
- *
- * The implementation requires two template types; the first contains the update
- * rule for the W matrix during each iteration and the other contains the update
- * rule for the H matrix during each iteration. This templatization allows the
- * user to try various update rules (including ones not supplied with MLPACK)
- * for factorization.
- *
- * A simple example of how to run NMF is shown below.
- *
- * @code
- * extern arma::mat V; // Matrix that we want to perform NMF on.
- * size_t r = 10; // Rank of decomposition
- * arma::mat W; // Basis matrix
- * arma::mat H; // Encoding matrix
- *
- * NMF<> nmf(); // Default options
- * nmf.Apply(V, W, H, r);
- * @endcode
- *
- * For more information on non-negative matrix factorization, see the following
- * paper:
- *
- * @code
- * @article{
- * title = {{Learning the parts of objects by non-negative matrix
- * factorization}},
- * author = {Lee, Daniel D. and Seung, H. Sebastian},
- * journal = {Nature},
- * month = {Oct},
- * year = {1999},
- * number = {6755},
- * pages = {788--791},
- * publisher = {Nature Publishing Group},
- * url = {http://dx.doi.org/10.1038/44565}
- * }
- * @endcode
- *
- * @tparam WUpdateRule The update rule for calculating W matrix at each
- * iteration.
- * @tparam HUpdateRule The update rule for calculating H matrix at each
- * iteration.
- *
- * @see WMultiplicativeDistanceRule, HMultiplicativeDistanceRule
- */
-template<typename InitializationRule = RandomInitialization,
- typename WUpdateRule = WMultiplicativeDistanceRule,
- typename HUpdateRule = HMultiplicativeDistanceRule>
-class NMF
-{
- public:
- /**
- * Create the NMF object and (optionally) set the parameters which NMF will
- * run with. The minimum residue refers to the root mean square of the
- * difference between two subsequent iterations of the product W * H. A low
- * residue indicates that subsequent iterations are not producing much change
- * in W and H. Once the residue goes below the specified minimum residue, the
- * algorithm terminates.
- *
- * @param maxIterations Maximum number of iterations allowed before giving up.
- * A value of 0 indicates no limit.
- * @param minResidue The minimum allowed residue before the algorithm
- * terminates.
- * @param Initialize Optional Initialization object for initializing the
- * W and H matrices
- * @param WUpdate Optional WUpdateRule object; for when the update rule for
- * the W vector has states that it needs to store.
- * @param HUpdate Optional HUpdateRule object; for when the update rule for
- * the H vector has states that it needs to store.
- */
- NMF(const size_t maxIterations = 10000,
- const double minResidue = 1e-10,
- const InitializationRule initializeRule = InitializationRule(),
- const WUpdateRule wUpdate = WUpdateRule(),
- const HUpdateRule hUpdate = HUpdateRule());
-
- /**
- * Apply Non-Negative Matrix Factorization to the provided matrix.
- *
- * @param V Input matrix to be factorized.
- * @param W Basis matrix to be output.
- * @param H Encoding matrix to output.
- * @param r Rank r of the factorization.
- */
- void Apply(const arma::mat& V, const size_t r, arma::mat& W, arma::mat& H)
- const;
-
- private:
- //! The maximum number of iterations allowed before giving up.
- size_t maxIterations;
- //! The minimum residue, below which iteration is considered converged.
- double minResidue;
- //! Instantiated initialization Rule.
- InitializationRule initializeRule;
- //! Instantiated W update rule.
- WUpdateRule wUpdate;
- //! Instantiated H update rule.
- HUpdateRule hUpdate;
-
- public:
- //! Access the maximum number of iterations.
- size_t MaxIterations() const { return maxIterations; }
- //! Modify the maximum number of iterations.
- size_t& MaxIterations() { return maxIterations; }
- //! Access the minimum residue before termination.
- double MinResidue() const { return minResidue; }
- //! Modify the minimum residue before termination.
- double& MinResidue() { return minResidue; }
- //! Access the initialization rule.
- const InitializationRule& InitializeRule() const { return initializeRule; }
- //! Modify the initialization rule.
- InitializationRule& InitializeRule() { return initializeRule; }
- //! Access the W update rule.
- const WUpdateRule& WUpdate() const { return wUpdate; }
- //! Modify the W update rule.
- WUpdateRule& WUpdate() { return wUpdate; }
- //! Access the H update rule.
- const HUpdateRule& HUpdate() const { return hUpdate; }
- //! Modify the H update rule.
- HUpdateRule& HUpdate() { return hUpdate; }
-
-}; // class NMF
-
-}; // namespace nmf
-}; // namespace mlpack
-
-// Include implementation.
-#include "nmf_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/nmf.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,167 @@
+/**
+ * @file nmf.hpp
+ * @author Mohan Rajendran
+ *
+ * Defines the NMF class to perform Non-negative Matrix Factorization
+ * on the given matrix.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NMF_NMF_HPP
+#define __MLPACK_METHODS_NMF_NMF_HPP
+
+#include <mlpack/core.hpp>
+#include "mult_dist_update_rules.hpp"
+#include "random_init.hpp"
+
+namespace mlpack {
+namespace nmf {
+
+/**
+ * This class implements the NMF on the given matrix V. Non-negative Matrix
+ * Factorization decomposes V in the form \f$ V \approx WH \f$ where W is
+ * called the basis matrix and H is called the encoding matrix. V is taken
+ * to be of size n x m and the obtained W is n x r and H is r x m. The size r is
+ * called the rank of the factorization.
+ *
+ * The implementation requires two template types; the first contains the update
+ * rule for the W matrix during each iteration and the other contains the update
+ * rule for the H matrix during each iteration. This templatization allows the
+ * user to try various update rules (including ones not supplied with MLPACK)
+ * for factorization.
+ *
+ * A simple example of how to run NMF is shown below.
+ *
+ * @code
+ * extern arma::mat V; // Matrix that we want to perform NMF on.
+ * size_t r = 10; // Rank of decomposition
+ * arma::mat W; // Basis matrix
+ * arma::mat H; // Encoding matrix
+ *
+ * NMF<> nmf; // Default options.  (Note: "NMF<> nmf();" would declare a function.)
+ * nmf.Apply(V, r, W, H);
+ * @endcode
+ *
+ * For more information on non-negative matrix factorization, see the following
+ * paper:
+ *
+ * @code
+ * @article{
+ * title = {{Learning the parts of objects by non-negative matrix
+ * factorization}},
+ * author = {Lee, Daniel D. and Seung, H. Sebastian},
+ * journal = {Nature},
+ * month = {Oct},
+ * year = {1999},
+ * number = {6755},
+ * pages = {788--791},
+ * publisher = {Nature Publishing Group},
+ * url = {http://dx.doi.org/10.1038/44565}
+ * }
+ * @endcode
+ *
+ * @tparam WUpdateRule The update rule for calculating W matrix at each
+ * iteration.
+ * @tparam HUpdateRule The update rule for calculating H matrix at each
+ * iteration.
+ *
+ * @see WMultiplicativeDistanceRule, HMultiplicativeDistanceRule
+ */
+template<typename InitializationRule = RandomInitialization,
+ typename WUpdateRule = WMultiplicativeDistanceRule,
+ typename HUpdateRule = HMultiplicativeDistanceRule>
+class NMF
+{
+ public:
+ /**
+ * Create the NMF object and (optionally) set the parameters which NMF will
+ * run with. The minimum residue refers to the root mean square of the
+ * difference between two subsequent iterations of the product W * H. A low
+ * residue indicates that subsequent iterations are not producing much change
+ * in W and H. Once the residue goes below the specified minimum residue, the
+ * algorithm terminates.
+ *
+ * @param maxIterations Maximum number of iterations allowed before giving up.
+ * A value of 0 indicates no limit.
+ * @param minResidue The minimum allowed residue before the algorithm
+ * terminates.
+   * @param initializeRule Optional InitializationRule object for initializing
+   *     the W and H matrices.
+   * @param wUpdate Optional WUpdateRule object; for when the update rule for
+   *     the W matrix has state that it needs to store.
+   * @param hUpdate Optional HUpdateRule object; for when the update rule for
+   *     the H matrix has state that it needs to store.
+ */
+ NMF(const size_t maxIterations = 10000,
+ const double minResidue = 1e-10,
+ const InitializationRule initializeRule = InitializationRule(),
+ const WUpdateRule wUpdate = WUpdateRule(),
+ const HUpdateRule hUpdate = HUpdateRule());
+
+ /**
+ * Apply Non-Negative Matrix Factorization to the provided matrix.
+ *
+   * @param V Input matrix to be factorized.
+   * @param r Rank r of the factorization.
+   * @param W Basis matrix to be output.
+   * @param H Encoding matrix to output.
+ */
+ void Apply(const arma::mat& V, const size_t r, arma::mat& W, arma::mat& H)
+ const;
+
+ private:
+ //! The maximum number of iterations allowed before giving up.
+ size_t maxIterations;
+ //! The minimum residue, below which iteration is considered converged.
+ double minResidue;
+ //! Instantiated initialization Rule.
+ InitializationRule initializeRule;
+ //! Instantiated W update rule.
+ WUpdateRule wUpdate;
+ //! Instantiated H update rule.
+ HUpdateRule hUpdate;
+
+ public:
+ //! Access the maximum number of iterations.
+ size_t MaxIterations() const { return maxIterations; }
+ //! Modify the maximum number of iterations.
+ size_t& MaxIterations() { return maxIterations; }
+ //! Access the minimum residue before termination.
+ double MinResidue() const { return minResidue; }
+ //! Modify the minimum residue before termination.
+ double& MinResidue() { return minResidue; }
+ //! Access the initialization rule.
+ const InitializationRule& InitializeRule() const { return initializeRule; }
+ //! Modify the initialization rule.
+ InitializationRule& InitializeRule() { return initializeRule; }
+ //! Access the W update rule.
+ const WUpdateRule& WUpdate() const { return wUpdate; }
+ //! Modify the W update rule.
+ WUpdateRule& WUpdate() { return wUpdate; }
+ //! Access the H update rule.
+ const HUpdateRule& HUpdate() const { return hUpdate; }
+ //! Modify the H update rule.
+ HUpdateRule& HUpdate() { return hUpdate; }
+
+}; // class NMF
+
+}; // namespace nmf
+}; // namespace mlpack
+
+// Include implementation.
+#include "nmf_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/nmf_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,112 +0,0 @@
-/**
- * @file nmf.cpp
- * @author Mohan Rajendran
- *
- * Implementation of NMF class to perform Non-Negative Matrix Factorization
- * on the given matrix.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-namespace mlpack {
-namespace nmf {
-
-/**
- * Construct the NMF object.
- */
-template<typename InitializationRule,
- typename WUpdateRule,
- typename HUpdateRule>
-NMF<InitializationRule, WUpdateRule, HUpdateRule>::NMF(
- const size_t maxIterations,
- const double minResidue,
- const InitializationRule initializeRule,
- const WUpdateRule wUpdate,
- const HUpdateRule hUpdate) :
- maxIterations(maxIterations),
- minResidue(minResidue),
- initializeRule(initializeRule),
- wUpdate(wUpdate),
- hUpdate(hUpdate)
-{
- if (minResidue < 0.0)
- {
- Log::Warn << "NMF::NMF(): minResidue must be a positive value ("
- << minResidue << " given). Setting to the default value of 1e-10.\n";
- this->minResidue = 1e-10;
- }
-}
-
-/**
- * Apply Non-Negative Matrix Factorization to the provided matrix.
- *
- * @param V Input matrix to be factorized
- * @param W Basis matrix to be output
- * @param H Encoding matrix to output
- * @param r Rank r of the factorization
- */
-template<typename InitializationRule,
- typename WUpdateRule,
- typename HUpdateRule>
-void NMF<InitializationRule, WUpdateRule, HUpdateRule>::Apply(
- const arma::mat& V,
- const size_t r,
- arma::mat& W,
- arma::mat& H) const
-{
- const size_t n = V.n_rows;
- const size_t m = V.n_cols;
-
- // Initialize W and H.
- initializeRule.Initialize(V, r, W, H);
-
- Log::Info << "Initialized W and H." << std::endl;
-
- size_t iteration = 1;
- const size_t nm = n * m;
- double residue = minResidue;
- double normOld = 0;
- double norm = 0;
- arma::mat WH;
-
- while (residue >= minResidue && iteration != maxIterations)
- {
- // Update step.
- // Update the value of W and H based on the Update Rules provided
- wUpdate.Update(V, W, H);
- hUpdate.Update(V, W, H);
-
- // Calculate norm of WH after each iteration.
- WH = W * H;
- norm = sqrt(accu(WH % WH) / nm);
-
- if (iteration != 0)
- {
- residue = fabs(normOld - norm);
- residue /= normOld;
- }
-
- normOld = norm;
-
- iteration++;
- }
-
- Log::Info << "NMF converged to residue of " << sqrt(residue) << " in "
- << iteration << " iterations." << std::endl;
-}
-
-}; // namespace nmf
-}; // namespace mlpack
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/nmf_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,112 @@
+/**
+ * @file nmf_impl.hpp
+ * @author Mohan Rajendran
+ *
+ * Implementation of NMF class to perform Non-Negative Matrix Factorization
+ * on the given matrix.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+namespace mlpack {
+namespace nmf {
+
+/**
+ * Construct the NMF object.
+ */
+template<typename InitializationRule,
+ typename WUpdateRule,
+ typename HUpdateRule>
+NMF<InitializationRule, WUpdateRule, HUpdateRule>::NMF(
+ const size_t maxIterations,
+ const double minResidue,
+ const InitializationRule initializeRule,
+ const WUpdateRule wUpdate,
+ const HUpdateRule hUpdate) :
+ maxIterations(maxIterations),
+ minResidue(minResidue),
+ initializeRule(initializeRule),
+ wUpdate(wUpdate),
+ hUpdate(hUpdate)
+{
+ if (minResidue < 0.0)
+ {
+ Log::Warn << "NMF::NMF(): minResidue must be a positive value ("
+ << minResidue << " given). Setting to the default value of 1e-10.\n";
+ this->minResidue = 1e-10;
+ }
+}
+
+/**
+ * Apply Non-Negative Matrix Factorization to the provided matrix.
+ *
+ * @param V Input matrix to be factorized
+ * @param r Rank r of the factorization
+ * @param W Basis matrix to be output
+ * @param H Encoding matrix to output
+ */
+template<typename InitializationRule,
+ typename WUpdateRule,
+ typename HUpdateRule>
+void NMF<InitializationRule, WUpdateRule, HUpdateRule>::Apply(
+ const arma::mat& V,
+ const size_t r,
+ arma::mat& W,
+ arma::mat& H) const
+{
+ const size_t n = V.n_rows;
+ const size_t m = V.n_cols;
+
+ // Initialize W and H.
+ initializeRule.Initialize(V, r, W, H);
+
+ Log::Info << "Initialized W and H." << std::endl;
+
+ size_t iteration = 1;
+ const size_t nm = n * m;
+ double residue = minResidue;
+ double normOld = 0;
+ double norm = 0;
+ arma::mat WH;
+
+ while (residue >= minResidue && iteration != maxIterations)
+ {
+ // Update step.
+ // Update the value of W and H based on the Update Rules provided
+ wUpdate.Update(V, W, H);
+ hUpdate.Update(V, W, H);
+
+ // Calculate norm of WH after each iteration.
+ WH = W * H;
+ norm = sqrt(accu(WH % WH) / nm);
+
+ if (iteration != 0)
+ {
+ residue = fabs(normOld - norm);
+ residue /= normOld;
+ }
+
+ normOld = norm;
+
+ iteration++;
+ }
+
+ Log::Info << "NMF converged to residue of " << sqrt(residue) << " in "
+ << iteration << " iterations." << std::endl;
+}
+
+}; // namespace nmf
+}; // namespace mlpack
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/nmf_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,147 +0,0 @@
-/**
- * @file nmf_main.cpp
- * @author Mohan Rajendran
- *
- * Main executable to run NMF.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "nmf.hpp"
-
-#include "random_init.hpp"
-#include "mult_dist_update_rules.hpp"
-#include "mult_div_update_rules.hpp"
-#include "als_update_rules.hpp"
-
-using namespace mlpack;
-using namespace mlpack::nmf;
-using namespace std;
-
-// Document program.
-PROGRAM_INFO("Non-negative Matrix Factorization", "This program performs "
- "non-negative matrix factorization on the given dataset, storing the "
- "resulting decomposed matrices in the specified files. For an input "
- "dataset V, NMF decomposes V into two matrices W and H such that "
- "\n\n"
- "V = W * H"
- "\n\n"
- "where all elements in W and H are non-negative. If V is of size (n x m),"
- " then W will be of size (n x r) and H will be of size (r x m), where r is "
- "the rank of the factorization (specified by --rank)."
- "\n\n"
- "Optionally, the desired update rules for each NMF iteration can be chosen "
- "from the following list:"
- "\n\n"
- " - multdist: multiplicative distance-based update rules (Lee and Seung "
- "1999)\n"
- " - multdiv: multiplicative divergence-based update rules (Lee and Seung "
- "1999)\n"
- " - als: alternating least squares update rules (Paatero and Tapper 1994)"
- "\n\n"
- "The maximum number of iterations is specified with --max_iterations, and "
- "the minimum residue required for algorithm termination is specified with "
- "--min_residue.");
-
-// Parameters for program.
-PARAM_STRING_REQ("input_file", "Input dataset to perform NMF on.", "i");
-PARAM_STRING_REQ("w_file", "File to save the calculated W matrix to.", "W");
-PARAM_STRING_REQ("h_file", "File to save the calculated H matrix to.", "H");
-PARAM_INT_REQ("rank", "Rank of the factorization.", "r");
-
-PARAM_INT("max_iterations", "Number of iterations before NMF terminates (0 runs"
- " until convergence.", "m", 10000);
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-PARAM_DOUBLE("min_residue", "The minimum root mean square residue allowed for "
- "each iteration, below which the program terminates.", "e", 1e-5);
-
-PARAM_STRING("update_rules", "Update rules for each iteration; ( multdist | "
- "multdiv | als ).", "u", "multdist");
-
-int main(int argc, char** argv)
-{
- // Parse command line.
- CLI::ParseCommandLine(argc, argv);
-
- // Initialize random seed.
- if (CLI::GetParam<int>("seed") != 0)
- math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- math::RandomSeed((size_t) std::time(NULL));
-
- // Gather parameters.
- const string inputFile = CLI::GetParam<string>("input_file");
- const string hOutputFile = CLI::GetParam<string>("h_file");
- const string wOutputFile = CLI::GetParam<string>("w_file");
- const size_t r = CLI::GetParam<int>("rank");
- const size_t maxIterations = CLI::GetParam<int>("max_iterations");
- const double minResidue = CLI::GetParam<double>("min_residue");
- const string updateRules = CLI::GetParam<string>("update_rules");
-
- // Validate rank.
- if (r < 1)
- {
- Log::Fatal << "The rank of the factorization cannot be less than 1."
- << std::endl;
- }
-
- if ((updateRules != "multdist") &&
- (updateRules != "multdiv") &&
- (updateRules != "als"))
- {
- Log::Fatal << "Invalid update rules ('" << updateRules << "'); must be '"
- << "multdist', 'multdiv', or 'als'." << std::endl;
- }
-
- // Load input dataset.
- arma::mat V;
- data::Load(inputFile, V, true);
-
- arma::mat W;
- arma::mat H;
-
- // Perform NMF with the specified update rules.
- if (updateRules == "multdist")
- {
- Log::Info << "Performing NMF with multiplicative distance-based update "
- << "rules." << std::endl;
- NMF<> nmf(maxIterations, minResidue);
- nmf.Apply(V, r, W, H);
- }
- else if (updateRules == "multdiv")
- {
- Log::Info << "Performing NMF with multiplicative divergence-based update "
- << "rules." << std::endl;
- NMF<RandomInitialization,
- WMultiplicativeDivergenceRule,
- HMultiplicativeDivergenceRule> nmf(maxIterations, minResidue);
- nmf.Apply(V, r, W, H);
- }
- else if (updateRules == "als")
- {
- Log::Info << "Performing NMF with alternating least squared update rules."
- << std::endl;
- NMF<RandomInitialization,
- WAlternatingLeastSquaresRule,
- HAlternatingLeastSquaresRule> nmf(maxIterations, minResidue);
- nmf.Apply(V, r, W, H);
- }
-
- // Save results.
- data::Save(wOutputFile, W, false);
- data::Save(hOutputFile, H, false);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/nmf_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/nmf_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,147 @@
+/**
+ * @file nmf_main.cpp
+ * @author Mohan Rajendran
+ *
+ * Main executable to run NMF.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "nmf.hpp"
+
+#include "random_init.hpp"
+#include "mult_dist_update_rules.hpp"
+#include "mult_div_update_rules.hpp"
+#include "als_update_rules.hpp"
+
+using namespace mlpack;
+using namespace mlpack::nmf;
+using namespace std;
+
+// Document program.
+PROGRAM_INFO("Non-negative Matrix Factorization", "This program performs "
+ "non-negative matrix factorization on the given dataset, storing the "
+ "resulting decomposed matrices in the specified files. For an input "
+ "dataset V, NMF decomposes V into two matrices W and H such that "
+ "\n\n"
+ "V = W * H"
+ "\n\n"
+ "where all elements in W and H are non-negative. If V is of size (n x m),"
+ " then W will be of size (n x r) and H will be of size (r x m), where r is "
+ "the rank of the factorization (specified by --rank)."
+ "\n\n"
+ "Optionally, the desired update rules for each NMF iteration can be chosen "
+ "from the following list:"
+ "\n\n"
+ " - multdist: multiplicative distance-based update rules (Lee and Seung "
+ "1999)\n"
+ " - multdiv: multiplicative divergence-based update rules (Lee and Seung "
+ "1999)\n"
+ " - als: alternating least squares update rules (Paatero and Tapper 1994)"
+ "\n\n"
+ "The maximum number of iterations is specified with --max_iterations, and "
+ "the minimum residue required for algorithm termination is specified with "
+ "--min_residue.");
+
+// Parameters for program.
+PARAM_STRING_REQ("input_file", "Input dataset to perform NMF on.", "i");
+PARAM_STRING_REQ("w_file", "File to save the calculated W matrix to.", "W");
+PARAM_STRING_REQ("h_file", "File to save the calculated H matrix to.", "H");
+PARAM_INT_REQ("rank", "Rank of the factorization.", "r");
+
+PARAM_INT("max_iterations", "Number of iterations before NMF terminates (0 runs"
+ " until convergence.", "m", 10000);
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+PARAM_DOUBLE("min_residue", "The minimum root mean square residue allowed for "
+ "each iteration, below which the program terminates.", "e", 1e-5);
+
+PARAM_STRING("update_rules", "Update rules for each iteration; ( multdist | "
+ "multdiv | als ).", "u", "multdist");
+
+int main(int argc, char** argv)
+{
+ // Parse command line.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Initialize random seed.
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) std::time(NULL));
+
+ // Gather parameters.
+ const string inputFile = CLI::GetParam<string>("input_file");
+ const string hOutputFile = CLI::GetParam<string>("h_file");
+ const string wOutputFile = CLI::GetParam<string>("w_file");
+ const size_t r = CLI::GetParam<int>("rank");
+ const size_t maxIterations = CLI::GetParam<int>("max_iterations");
+ const double minResidue = CLI::GetParam<double>("min_residue");
+ const string updateRules = CLI::GetParam<string>("update_rules");
+
+ // Validate rank.
+ if (r < 1)
+ {
+ Log::Fatal << "The rank of the factorization cannot be less than 1."
+ << std::endl;
+ }
+
+ if ((updateRules != "multdist") &&
+ (updateRules != "multdiv") &&
+ (updateRules != "als"))
+ {
+ Log::Fatal << "Invalid update rules ('" << updateRules << "'); must be '"
+ << "multdist', 'multdiv', or 'als'." << std::endl;
+ }
+
+ // Load input dataset.
+ arma::mat V;
+ data::Load(inputFile, V, true);
+
+ arma::mat W;
+ arma::mat H;
+
+ // Perform NMF with the specified update rules.
+ if (updateRules == "multdist")
+ {
+ Log::Info << "Performing NMF with multiplicative distance-based update "
+ << "rules." << std::endl;
+ NMF<> nmf(maxIterations, minResidue);
+ nmf.Apply(V, r, W, H);
+ }
+ else if (updateRules == "multdiv")
+ {
+ Log::Info << "Performing NMF with multiplicative divergence-based update "
+ << "rules." << std::endl;
+ NMF<RandomInitialization,
+ WMultiplicativeDivergenceRule,
+ HMultiplicativeDivergenceRule> nmf(maxIterations, minResidue);
+ nmf.Apply(V, r, W, H);
+ }
+ else if (updateRules == "als")
+ {
+ Log::Info << "Performing NMF with alternating least squared update rules."
+ << std::endl;
+ NMF<RandomInitialization,
+ WAlternatingLeastSquaresRule,
+ HAlternatingLeastSquaresRule> nmf(maxIterations, minResidue);
+ nmf.Apply(V, r, W, H);
+ }
+
+ // Save results.
+ data::Save(wOutputFile, W, false);
+ data::Save(hOutputFile, H, false);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_acol_init.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/random_acol_init.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_acol_init.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,85 +0,0 @@
-/**
- * @file random_acol_init.hpp
- * @author Mohan Rajendran
- *
- * Intialization rule for Non-Negative Matrix Factorization. This simple
- * initialization is performed by the random Acol initialization introduced in
- * the paper 'Algorithms, Initializations and Convergence' by Langville et al.
- * This method sets each of the columns of W by averaging p randomly chosen
- * columns of V.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NMF_RANDOM_ACOL_INIT_HPP
-#define __MLPACK_METHODS_NMF_RANDOM_ACOL_INIT_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace nmf {
-
-/**
- * This class initializes the W matrix of the NMF algorithm by averaging p
- * randomly chosen columns of V. In this case, p is a template parameter. H is
- * then set randomly.
- *
- * @tparam The number of random columns to average for each column of W.
- */
-template<int p = 5>
-class RandomAcolInitialization
-{
- public:
- // Empty constructor required for the InitializeRule template
- RandomAcolInitialization()
- { }
-
- inline static void Initialize(const arma::mat& V,
- const size_t r,
- arma::mat& W,
- arma::mat& H)
- {
- const size_t n = V.n_rows;
- const size_t m = V.n_cols;
-
- if (p > m)
- {
- Log::Warn << "Number of random columns is more than the number of columns"
- << "available in the V matrix; weird results may ensue!" << std::endl;
- }
-
- W.zeros(n, r);
-
- // Initialize W matrix with random columns.
- for (size_t col = 0; col < r; col++)
- {
- for (size_t randCol = 0; randCol < p; randCol++)
- {
- W.col(col) += V.col(math::RandInt(0, m));
- }
- }
-
- // Now divide by p.
- W /= p;
-
- // Initialize H to random values.
- H.randu(r, m);
- }
-}; // Class RandomAcolInitialization
-
-}; // namespace nmf
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_acol_init.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/random_acol_init.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_acol_init.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_acol_init.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,85 @@
+/**
+ * @file random_acol_init.hpp
+ * @author Mohan Rajendran
+ *
+ * Intialization rule for Non-Negative Matrix Factorization. This simple
+ * initialization is performed by the random Acol initialization introduced in
+ * the paper 'Algorithms, Initializations and Convergence' by Langville et al.
+ * This method sets each of the columns of W by averaging p randomly chosen
+ * columns of V.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NMF_RANDOM_ACOL_INIT_HPP
+#define __MLPACK_METHODS_NMF_RANDOM_ACOL_INIT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nmf {
+
+/**
+ * This class initializes the W matrix of the NMF algorithm by averaging p
+ * randomly chosen columns of V. In this case, p is a template parameter. H is
+ * then set randomly.
+ *
+ * @tparam The number of random columns to average for each column of W.
+ */
+template<int p = 5>
+class RandomAcolInitialization
+{
+ public:
+ // Empty constructor required for the InitializeRule template
+ RandomAcolInitialization()
+ { }
+
+ inline static void Initialize(const arma::mat& V,
+ const size_t r,
+ arma::mat& W,
+ arma::mat& H)
+ {
+ const size_t n = V.n_rows;
+ const size_t m = V.n_cols;
+
+ if (p > m)
+ {
+ Log::Warn << "Number of random columns is more than the number of columns"
+ << "available in the V matrix; weird results may ensue!" << std::endl;
+ }
+
+ W.zeros(n, r);
+
+ // Initialize W matrix with random columns.
+ for (size_t col = 0; col < r; col++)
+ {
+ for (size_t randCol = 0; randCol < p; randCol++)
+ {
+ W.col(col) += V.col(math::RandInt(0, m));
+ }
+ }
+
+ // Now divide by p.
+ W /= p;
+
+ // Initialize H to random values.
+ H.randu(r, m);
+ }
+}; // Class RandomAcolInitialization
+
+}; // namespace nmf
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_init.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/random_init.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_init.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,55 +0,0 @@
-/**
- * @file random_init.hpp
- * @author Mohan Rajendran
- *
- * Intialization rule for Non-Negative Matrix Factorization (NMF). This simple
- * initialization is performed by assigning a random matrix to W and H.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NMF_RANDOM_INIT_HPP
-#define __MLPACK_METHODS_NMF_RANDOM_INIT_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace nmf {
-
-class RandomInitialization
-{
- public:
- // Empty constructor required for the InitializeRule template
- RandomInitialization() { }
-
- inline static void Initialize(const arma::mat& V,
- const size_t r,
- arma::mat& W,
- arma::mat& H)
- {
- // Simple implementation (left in the header file due to its simplicity).
- size_t n = V.n_rows;
- size_t m = V.n_cols;
-
- // Intialize to random values.
- W.randu(n, r);
- H.randu(r, m);
- }
-};
-
-}; // namespace nmf
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_init.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/nmf/random_init.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_init.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/nmf/random_init.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,55 @@
+/**
+ * @file random_init.hpp
+ * @author Mohan Rajendran
+ *
+ * Intialization rule for Non-Negative Matrix Factorization (NMF). This simple
+ * initialization is performed by assigning a random matrix to W and H.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NMF_RANDOM_INIT_HPP
+#define __MLPACK_METHODS_NMF_RANDOM_INIT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace nmf {
+
+class RandomInitialization
+{
+ public:
+ // Empty constructor required for the InitializeRule template
+ RandomInitialization() { }
+
+ inline static void Initialize(const arma::mat& V,
+ const size_t r,
+ arma::mat& W,
+ arma::mat& H)
+ {
+ // Simple implementation (left in the header file due to its simplicity).
+ size_t n = V.n_rows;
+ size_t m = V.n_cols;
+
+ // Intialize to random values.
+ W.randu(n, r);
+ H.randu(r, m);
+ }
+};
+
+}; // namespace nmf
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/pca/pca.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,112 +0,0 @@
-/**
- * @file pca.cpp
- * @author Ajinkya Kale
- *
- * Implementation of PCA class to perform Principal Components Analysis on the
- * specified data set.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include "pca.hpp"
-#include <mlpack/core.hpp>
-#include <iostream>
-
-using namespace std;
-namespace mlpack {
-namespace pca {
-
-PCA::PCA(const bool scaleData) :
- scaleData(scaleData)
-{ }
-
-/**
- * Apply Principal Component Analysis to the provided data set.
- *
- * @param data - Data matrix
- * @param transformedData - Data with PCA applied
- * @param eigVal - contains eigen values in a column vector
- * @param coeff - PCA Loadings/Coeffs/EigenVectors
- */
-void PCA::Apply(const arma::mat& data,
- arma::mat& transformedData,
- arma::vec& eigVal,
- arma::mat& coeffs) const
-{
- //Original transpose op goes here.
- arma::mat covMat = ccov(data);
-
- //Centering is built into ccov
- if (scaleData)
- {
- covMat = covMat / (arma::ones<arma::colvec>(covMat.n_rows))
- * stddev(covMat, 0, 0);
- }
-
- arma::eig_sym(eigVal, coeffs, covMat);
-
- int nEigVal = eigVal.n_elem;
- for (int i = 0; i < floor(nEigVal / 2.0); i++)
- eigVal.swap_rows(i, (nEigVal - 1) - i);
-
- coeffs = arma::fliplr(coeffs);
- transformedData = trans(coeffs) * data;
- arma::colvec transformedDataMean = arma::mean(transformedData, 1);
- transformedData = transformedData - (transformedDataMean *
- arma::ones<arma::rowvec>(transformedData.n_cols));
-}
-
-/**
- * Apply Principal Component Analysis to the provided data set.
- *
- * @param data - Data matrix
- * @param transformedData - Data with PCA applied
- * @param eigVal - contains eigen values in a column vector
- */
-void PCA::Apply(const arma::mat& data,
- arma::mat& transformedData,
- arma::vec& eigVal) const
-{
- arma::mat coeffs;
- Apply(data, transformedData, eigVal, coeffs);
-}
-
-/**
- * Apply Dimensionality Reduction using Principal Component Analysis
- * to the provided data set.
- *
- * @param data - M x N Data matrix
- * @param newDimension - matrix consisting of N column vectors,
- * where each vector is the projection of the corresponding data vector
- * from data matrix onto the basis vectors contained in the columns of
- * coeff/eigen vector matrix with only newDimension number of columns chosen.
- */
-void PCA::Apply(arma::mat& data, const size_t newDimension) const
-{
- arma::mat coeffs;
- arma::vec eigVal;
-
- Apply(data, data, eigVal, coeffs);
-
- if (newDimension < coeffs.n_rows && newDimension > 0)
- data.shed_rows(newDimension, data.n_rows - 1);
-}
-
-PCA::~PCA()
-{
-}
-
-}; // namespace mlpack
-}; // namespace pca
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/pca/pca.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,112 @@
+/**
+ * @file pca.cpp
+ * @author Ajinkya Kale
+ *
+ * Implementation of PCA class to perform Principal Components Analysis on the
+ * specified data set.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include "pca.hpp"
+#include <mlpack/core.hpp>
+#include <iostream>
+
+using namespace std;
+namespace mlpack {
+namespace pca {
+
+PCA::PCA(const bool scaleData) :
+ scaleData(scaleData)
+{ }
+
+/**
+ * Apply Principal Component Analysis to the provided data set.
+ *
+ * @param data - Data matrix
+ * @param transformedData - Data with PCA applied
+ * @param eigVal - contains eigen values in a column vector
+ * @param coeff - PCA Loadings/Coeffs/EigenVectors
+ */
+void PCA::Apply(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigVal,
+ arma::mat& coeffs) const
+{
+ //Original transpose op goes here.
+ arma::mat covMat = ccov(data);
+
+ //Centering is built into ccov
+ if (scaleData)
+ {
+ covMat = covMat / (arma::ones<arma::colvec>(covMat.n_rows))
+ * stddev(covMat, 0, 0);
+ }
+
+ arma::eig_sym(eigVal, coeffs, covMat);
+
+ int nEigVal = eigVal.n_elem;
+ for (int i = 0; i < floor(nEigVal / 2.0); i++)
+ eigVal.swap_rows(i, (nEigVal - 1) - i);
+
+ coeffs = arma::fliplr(coeffs);
+ transformedData = trans(coeffs) * data;
+ arma::colvec transformedDataMean = arma::mean(transformedData, 1);
+ transformedData = transformedData - (transformedDataMean *
+ arma::ones<arma::rowvec>(transformedData.n_cols));
+}
+
+/**
+ * Apply Principal Component Analysis to the provided data set.
+ *
+ * @param data - Data matrix
+ * @param transformedData - Data with PCA applied
+ * @param eigVal - contains eigen values in a column vector
+ */
+void PCA::Apply(const arma::mat& data,
+ arma::mat& transformedData,
+ arma::vec& eigVal) const
+{
+ arma::mat coeffs;
+ Apply(data, transformedData, eigVal, coeffs);
+}
+
+/**
+ * Apply Dimensionality Reduction using Principal Component Analysis
+ * to the provided data set.
+ *
+ * @param data - M x N Data matrix
+ * @param newDimension - matrix consisting of N column vectors,
+ * where each vector is the projection of the corresponding data vector
+ * from data matrix onto the basis vectors contained in the columns of
+ * coeff/eigen vector matrix with only newDimension number of columns chosen.
+ */
+void PCA::Apply(arma::mat& data, const size_t newDimension) const
+{
+ arma::mat coeffs;
+ arma::vec eigVal;
+
+ Apply(data, data, eigVal, coeffs);
+
+ if (newDimension < coeffs.n_rows && newDimension > 0)
+ data.shed_rows(newDimension, data.n_rows - 1);
+}
+
+PCA::~PCA()
+{
+}
+
+}; // namespace mlpack
+}; // namespace pca
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/pca/pca.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,88 +0,0 @@
-/**
- * @file pca.hpp
- * @author Ajinkya Kale
- *
- * Defines the PCA class to perform Principal Components Analysis on the
- * specified data set.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_PCA_PCA_HPP
-#define __MLPACK_METHODS_PCA_PCA_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace pca {
-
-class PCA
-{
- public:
- PCA(const bool scaleData = false);
-
- /**
- * Apply Principal Component Analysis to the provided data set.
- *
- * @param data - Data matrix
- * @param transformedData - Data with PCA applied
- * @param eigVal - contains eigen values in a column vector
- * @param coeff - PCA Loadings/Coeffs/EigenVectors
- */
- void Apply(const arma::mat& data, arma::mat& transformedData, arma::vec&
- eigVal, arma::mat& coeff) const;
-
- /**
- * Apply Principal Component Analysis to the provided data set.
- *
- * @param data - Data matrix
- * @param transformedData - Data with PCA applied
- * @param eigVal - contains eigen values in a column vector
- */
- void Apply(const arma::mat& data, arma::mat& transformedData,
- arma::vec& eigVal) const;
-
- /**
- * Apply Dimensionality Reduction using Principal Component Analysis
- * to the provided data set.
- *
- * @param data - M x N Data matrix
- * @param newDimension - matrix consisting of N column vectors,
- * where each vector is the projection of the corresponding data vector
- * from data matrix onto the basis vectors contained in the columns of
- * coeff/eigen vector matrix with only newDimension number of columns chosen.
- */
- void Apply(arma::mat& data, const size_t newDimension) const;
-
- /**
- * Delete PCA object
- */
- ~PCA();
-
- //! Get whether or not this PCA object will scale (by standard deviation) the
- //! data when PCA is performed.
- bool ScaleData() const { return scaleData; }
- //! Modify whether or not this PCA object will scale (by standard deviation)
- //! the data when PCA is performed.
- bool& ScaleData() { return scaleData; }
-
- private:
- bool scaleData;
-}; // class PCA
-
-}; // namespace pca
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/pca/pca.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,88 @@
+/**
+ * @file pca.hpp
+ * @author Ajinkya Kale
+ *
+ * Defines the PCA class to perform Principal Components Analysis on the
+ * specified data set.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_PCA_PCA_HPP
+#define __MLPACK_METHODS_PCA_PCA_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace pca {
+
+class PCA
+{
+ public:
+ PCA(const bool scaleData = false);
+
+ /**
+ * Apply Principal Component Analysis to the provided data set.
+ *
+ * @param data - Data matrix
+ * @param transformedData - Data with PCA applied
+ * @param eigVal - contains eigen values in a column vector
+ * @param coeff - PCA Loadings/Coeffs/EigenVectors
+ */
+ void Apply(const arma::mat& data, arma::mat& transformedData, arma::vec&
+ eigVal, arma::mat& coeff) const;
+
+ /**
+ * Apply Principal Component Analysis to the provided data set.
+ *
+ * @param data - Data matrix
+ * @param transformedData - Data with PCA applied
+ * @param eigVal - contains eigen values in a column vector
+ */
+ void Apply(const arma::mat& data, arma::mat& transformedData,
+ arma::vec& eigVal) const;
+
+ /**
+ * Apply Dimensionality Reduction using Principal Component Analysis
+ * to the provided data set.
+ *
+ * @param data - M x N Data matrix
+ * @param newDimension - matrix consisting of N column vectors,
+ * where each vector is the projection of the corresponding data vector
+ * from data matrix onto the basis vectors contained in the columns of
+ * coeff/eigen vector matrix with only newDimension number of columns chosen.
+ */
+ void Apply(arma::mat& data, const size_t newDimension) const;
+
+ /**
+ * Delete PCA object
+ */
+ ~PCA();
+
+ //! Get whether or not this PCA object will scale (by standard deviation) the
+ //! data when PCA is performed.
+ bool ScaleData() const { return scaleData; }
+ //! Modify whether or not this PCA object will scale (by standard deviation)
+ //! the data when PCA is performed.
+ bool& ScaleData() { return scaleData; }
+
+ private:
+ bool scaleData;
+}; // class PCA
+
+}; // namespace pca
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/pca/pca_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,81 +0,0 @@
-/**
- * @file pca_main.cpp
- * @author Ryan Curtin
- *
- * Main executable to run PCA.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include "pca.hpp"
-
-using namespace mlpack;
-using namespace mlpack::pca;
-using namespace std;
-
-// Document program.
-PROGRAM_INFO("Principal Components Analysis", "This program performs principal "
- "components analysis on the given dataset. It will transform the data "
- "onto its principal components, optionally performing dimensionality "
- "reduction by ignoring the principal components with the smallest "
- "eigenvalues.");
-
-// Parameters for program.
-PARAM_STRING_REQ("input_file", "Input dataset to perform PCA on.", "i");
-PARAM_STRING_REQ("output_file", "File to save modified dataset to.", "o");
-PARAM_INT("new_dimensionality", "Desired dimensionality of output dataset. If "
- "0, no dimensionality reduction is performed.", "d", 0);
-
-PARAM_FLAG("scale", "If set, the data will be scaled before running PCA, such "
- "that the variance of each feature is 1.", "s");
-
-int main(int argc, char** argv)
-{
- // Parse commandline.
- CLI::ParseCommandLine(argc, argv);
-
- // Load input dataset.
- string inputFile = CLI::GetParam<string>("input_file");
- arma::mat dataset;
- data::Load(inputFile.c_str(), dataset);
-
- // Find out what dimension we want.
- size_t newDimension = dataset.n_rows; // No reduction, by default.
- if (CLI::GetParam<int>("new_dimensionality") != 0)
- {
- // Validate the parameter.
- newDimension = (size_t) CLI::GetParam<int>("new_dimensionality");
- if (newDimension > dataset.n_rows)
- {
- Log::Fatal << "New dimensionality (" << newDimension
- << ") cannot be greater than existing dimensionality ("
- << dataset.n_rows << ")!" << endl;
- }
- }
-
- // Get the options for running PCA.
- const size_t scale = CLI::HasParam("scale");
-
- // Perform PCA.
- PCA p(scale);
- Log::Info << "Performing PCA on dataset..." << endl;
- p.Apply(dataset, newDimension);
-
- // Now save the results.
- string outputFile = CLI::GetParam<string>("output_file");
- data::Save(outputFile, dataset);
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/pca/pca_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/pca/pca_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,81 @@
+/**
+ * @file pca_main.cpp
+ * @author Ryan Curtin
+ *
+ * Main executable to run PCA.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include "pca.hpp"
+
+using namespace mlpack;
+using namespace mlpack::pca;
+using namespace std;
+
+// Document program.
+PROGRAM_INFO("Principal Components Analysis", "This program performs principal "
+ "components analysis on the given dataset. It will transform the data "
+ "onto its principal components, optionally performing dimensionality "
+ "reduction by ignoring the principal components with the smallest "
+ "eigenvalues.");
+
+// Parameters for program.
+PARAM_STRING_REQ("input_file", "Input dataset to perform PCA on.", "i");
+PARAM_STRING_REQ("output_file", "File to save modified dataset to.", "o");
+PARAM_INT("new_dimensionality", "Desired dimensionality of output dataset. If "
+ "0, no dimensionality reduction is performed.", "d", 0);
+
+PARAM_FLAG("scale", "If set, the data will be scaled before running PCA, such "
+ "that the variance of each feature is 1.", "s");
+
+int main(int argc, char** argv)
+{
+ // Parse commandline.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Load input dataset.
+ string inputFile = CLI::GetParam<string>("input_file");
+ arma::mat dataset;
+ data::Load(inputFile.c_str(), dataset);
+
+ // Find out what dimension we want.
+ size_t newDimension = dataset.n_rows; // No reduction, by default.
+ if (CLI::GetParam<int>("new_dimensionality") != 0)
+ {
+ // Validate the parameter.
+ newDimension = (size_t) CLI::GetParam<int>("new_dimensionality");
+ if (newDimension > dataset.n_rows)
+ {
+ Log::Fatal << "New dimensionality (" << newDimension
+ << ") cannot be greater than existing dimensionality ("
+ << dataset.n_rows << ")!" << endl;
+ }
+ }
+
+ // Get the options for running PCA.
+ const size_t scale = CLI::HasParam("scale");
+
+ // Perform PCA.
+ PCA p(scale);
+ Log::Info << "Performing PCA on dataset..." << endl;
+ p.Apply(dataset, newDimension);
+
+ // Now save the results.
+ string outputFile = CLI::GetParam<string>("output_file");
+ data::Save(outputFile, dataset);
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/radical/radical.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,201 +0,0 @@
-/**
- * @file radical.cpp
- * @author Nishant Mehta
- *
- * Implementation of Radical class
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include "radical.hpp"
-
-using namespace std;
-using namespace arma;
-using namespace mlpack;
-using namespace mlpack::radical;
-
-// Set the parameters to RADICAL.
-Radical::Radical(const double noiseStdDev,
- const size_t replicates,
- const size_t angles,
- const size_t sweeps,
- const size_t m) :
- noiseStdDev(noiseStdDev),
- replicates(replicates),
- angles(angles),
- sweeps(sweeps),
- m(m)
-{
- // Nothing to do here.
-}
-
-void Radical::CopyAndPerturb(mat& xNew, const mat& x) const
-{
- Timer::Start("radical_copy_and_perturb");
- xNew = repmat(x, replicates, 1) + noiseStdDev * randn(replicates * x.n_rows,
- x.n_cols);
- Timer::Stop("radical_copy_and_perturb");
-}
-
-
-double Radical::Vasicek(vec& z) const
-{
- z = sort(z);
-
- // Apparently slower.
- /*
- vec logs = log(z.subvec(m, z.n_elem - 1) - z.subvec(0, z.n_elem - 1 - m));
- //vec val = sum(log(z.subvec(m, nPoints - 1) - z.subvec(0, nPoints - 1 - m)));
- return (double) sum(logs);
- */
-
- // Apparently faster.
- double sum = 0;
- uword range = z.n_elem - m;
- for (uword i = 0; i < range; i++)
- {
- sum += log(z(i + m) - z(i));
- }
-
- return sum;
-}
-
-
-double Radical::DoRadical2D(const mat& matX)
-{
- CopyAndPerturb(perturbed, matX);
-
- mat::fixed<2, 2> matJacobi;
-
- vec values(angles);
-
- for (size_t i = 0; i < angles; i++)
- {
- const double theta = (i / (double) angles) * M_PI / 2.0;
- const double cosTheta = cos(theta);
- const double sinTheta = sin(theta);
-
- matJacobi(0, 0) = cosTheta;
- matJacobi(1, 0) = -sinTheta;
- matJacobi(0, 1) = sinTheta;
- matJacobi(1, 1) = cosTheta;
-
- candidate = perturbed * matJacobi;
- vec candidateY1 = candidate.unsafe_col(0);
- vec candidateY2 = candidate.unsafe_col(1);
-
- values(i) = Vasicek(candidateY1) + Vasicek(candidateY2);
- }
-
- uword indOpt;
- values.min(indOpt); // we ignore the return value; we don't care about it
- return (indOpt / (double) angles) * M_PI / 2.0;
-}
-
-
-void Radical::DoRadical(const mat& matXT, mat& matY, mat& matW)
-{
- // matX is nPoints by nDims (although less intuitive than columns being
- // points, and although this is the transpose of the ICA literature, this
- // choice is for computational efficiency when repeatedly generating
- // two-dimensional coordinate projections for Radical2D).
- Timer::Start("radical_transpose_data");
- mat matX = trans(matXT);
- Timer::Stop("radical_transpose_data");
-
- // If m was not specified, initialize m as recommended in
- // (Learned-Miller and Fisher, 2003).
- if (m < 1)
- m = floor(sqrt((double) matX.n_rows));
-
- const size_t nDims = matX.n_cols;
- const size_t nPoints = matX.n_rows;
-
- Timer::Start("radical_whiten_data");
- mat matXWhitened;
- mat matWhitening;
- WhitenFeatureMajorMatrix(matX, matY, matWhitening);
- Timer::Stop("radical_whiten_data");
- // matY is now the whitened form of matX.
-
- // In the RADICAL code, they do not copy and perturb initially, although the
- // paper does. We follow the code as it should match their reported results
- // and likely does a better job bouncing out of local optima.
- //GeneratePerturbedX(X, X);
-
- // Initialize the unmixing matrix to the whitening matrix.
- Timer::Start("radical_do_radical");
- matW = matWhitening;
-
- mat matYSubspace(nPoints, 2);
-
- mat matJ = eye(nDims, nDims);
-
- for (size_t sweepNum = 0; sweepNum < sweeps; sweepNum++)
- {
- Log::Info << "RADICAL: sweep " << sweepNum << "." << std::endl;
-
- for (size_t i = 0; i < nDims - 1; i++)
- {
- for (size_t j = i + 1; j < nDims; j++)
- {
- Log::Debug << "RADICAL 2D on dimensions " << i << " and " << j << "."
- << std::endl;
-
- matYSubspace.col(0) = matY.col(i);
- matYSubspace.col(1) = matY.col(j);
-
- const double thetaOpt = DoRadical2D(matYSubspace);
-
- const double cosThetaOpt = cos(thetaOpt);
- const double sinThetaOpt = sin(thetaOpt);
-
- // Set elements of transformation matrix.
- matJ(i, i) = cosThetaOpt;
- matJ(j, i) = -sinThetaOpt;
- matJ(i, j) = sinThetaOpt;
- matJ(j, j) = cosThetaOpt;
-
- matY *= matJ;
-
- // Unset elements of transformation matrix, so matJ = eye() again.
- matJ(i, i) = 1;
- matJ(j, i) = 0;
- matJ(i, j) = 0;
- matJ(j, j) = 1;
- }
- }
- }
- Timer::Stop("radical_do_radical");
-
- // The final transposes provide W and Y in the typical form from the ICA
- // literature.
- Timer::Start("radical_transpose_data");
- matW = trans(matW);
- matY = trans(matY);
- Timer::Stop("radical_transpose_data");
-}
-
-void mlpack::radical::WhitenFeatureMajorMatrix(const mat& matX,
- mat& matXWhitened,
- mat& matWhitening)
-{
- mat matU, matV;
- vec s;
- svd(matU, s, matV, cov(matX));
- matWhitening = matU * diagmat(1 / sqrt(s)) * trans(matV);
- matXWhitened = matX * matWhitening;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/radical/radical.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,201 @@
+/**
+ * @file radical.cpp
+ * @author Nishant Mehta
+ *
+ * Implementation of Radical class
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "radical.hpp"
+
+using namespace std;
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::radical;
+
+// Set the parameters to RADICAL.
+Radical::Radical(const double noiseStdDev,
+ const size_t replicates,
+ const size_t angles,
+ const size_t sweeps,
+ const size_t m) :
+ noiseStdDev(noiseStdDev),
+ replicates(replicates),
+ angles(angles),
+ sweeps(sweeps),
+ m(m)
+{
+ // Nothing to do here.
+}
+
+void Radical::CopyAndPerturb(mat& xNew, const mat& x) const
+{
+ Timer::Start("radical_copy_and_perturb");
+ xNew = repmat(x, replicates, 1) + noiseStdDev * randn(replicates * x.n_rows,
+ x.n_cols);
+ Timer::Stop("radical_copy_and_perturb");
+}
+
+
+double Radical::Vasicek(vec& z) const
+{
+ z = sort(z);
+
+ // Apparently slower.
+ /*
+ vec logs = log(z.subvec(m, z.n_elem - 1) - z.subvec(0, z.n_elem - 1 - m));
+ //vec val = sum(log(z.subvec(m, nPoints - 1) - z.subvec(0, nPoints - 1 - m)));
+ return (double) sum(logs);
+ */
+
+ // Apparently faster.
+ double sum = 0;
+ uword range = z.n_elem - m;
+ for (uword i = 0; i < range; i++)
+ {
+ sum += log(z(i + m) - z(i));
+ }
+
+ return sum;
+}
+
+
+double Radical::DoRadical2D(const mat& matX)
+{
+ CopyAndPerturb(perturbed, matX);
+
+ mat::fixed<2, 2> matJacobi;
+
+ vec values(angles);
+
+ for (size_t i = 0; i < angles; i++)
+ {
+ const double theta = (i / (double) angles) * M_PI / 2.0;
+ const double cosTheta = cos(theta);
+ const double sinTheta = sin(theta);
+
+ matJacobi(0, 0) = cosTheta;
+ matJacobi(1, 0) = -sinTheta;
+ matJacobi(0, 1) = sinTheta;
+ matJacobi(1, 1) = cosTheta;
+
+ candidate = perturbed * matJacobi;
+ vec candidateY1 = candidate.unsafe_col(0);
+ vec candidateY2 = candidate.unsafe_col(1);
+
+ values(i) = Vasicek(candidateY1) + Vasicek(candidateY2);
+ }
+
+ uword indOpt;
+ values.min(indOpt); // we ignore the return value; we don't care about it
+ return (indOpt / (double) angles) * M_PI / 2.0;
+}
+
+
+void Radical::DoRadical(const mat& matXT, mat& matY, mat& matW)
+{
+ // matX is nPoints by nDims (although less intuitive than columns being
+ // points, and although this is the transpose of the ICA literature, this
+ // choice is for computational efficiency when repeatedly generating
+ // two-dimensional coordinate projections for Radical2D).
+ Timer::Start("radical_transpose_data");
+ mat matX = trans(matXT);
+ Timer::Stop("radical_transpose_data");
+
+ // If m was not specified, initialize m as recommended in
+ // (Learned-Miller and Fisher, 2003).
+ if (m < 1)
+ m = floor(sqrt((double) matX.n_rows));
+
+ const size_t nDims = matX.n_cols;
+ const size_t nPoints = matX.n_rows;
+
+ Timer::Start("radical_whiten_data");
+ mat matXWhitened;
+ mat matWhitening;
+ WhitenFeatureMajorMatrix(matX, matY, matWhitening);
+ Timer::Stop("radical_whiten_data");
+ // matY is now the whitened form of matX.
+
+ // In the RADICAL code, they do not copy and perturb initially, although the
+ // paper does. We follow the code as it should match their reported results
+ // and likely does a better job bouncing out of local optima.
+ //GeneratePerturbedX(X, X);
+
+ // Initialize the unmixing matrix to the whitening matrix.
+ Timer::Start("radical_do_radical");
+ matW = matWhitening;
+
+ mat matYSubspace(nPoints, 2);
+
+ mat matJ = eye(nDims, nDims);
+
+ for (size_t sweepNum = 0; sweepNum < sweeps; sweepNum++)
+ {
+ Log::Info << "RADICAL: sweep " << sweepNum << "." << std::endl;
+
+ for (size_t i = 0; i < nDims - 1; i++)
+ {
+ for (size_t j = i + 1; j < nDims; j++)
+ {
+ Log::Debug << "RADICAL 2D on dimensions " << i << " and " << j << "."
+ << std::endl;
+
+ matYSubspace.col(0) = matY.col(i);
+ matYSubspace.col(1) = matY.col(j);
+
+ const double thetaOpt = DoRadical2D(matYSubspace);
+
+ const double cosThetaOpt = cos(thetaOpt);
+ const double sinThetaOpt = sin(thetaOpt);
+
+ // Set elements of transformation matrix.
+ matJ(i, i) = cosThetaOpt;
+ matJ(j, i) = -sinThetaOpt;
+ matJ(i, j) = sinThetaOpt;
+ matJ(j, j) = cosThetaOpt;
+
+ matY *= matJ;
+
+ // Unset elements of transformation matrix, so matJ = eye() again.
+ matJ(i, i) = 1;
+ matJ(j, i) = 0;
+ matJ(i, j) = 0;
+ matJ(j, j) = 1;
+ }
+ }
+ }
+ Timer::Stop("radical_do_radical");
+
+ // The final transposes provide W and Y in the typical form from the ICA
+ // literature.
+ Timer::Start("radical_transpose_data");
+ matW = trans(matW);
+ matY = trans(matY);
+ Timer::Stop("radical_transpose_data");
+}
+
+void mlpack::radical::WhitenFeatureMajorMatrix(const mat& matX,
+ mat& matXWhitened,
+ mat& matWhitening)
+{
+ mat matU, matV;
+ vec s;
+ svd(matU, s, matV, cov(matX));
+ matWhitening = matU * diagmat(1 / sqrt(s)) * trans(matV);
+ matXWhitened = matX * matWhitening;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/radical/radical.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,155 +0,0 @@
-/**
- * @file radical.hpp
- * @author Nishant Mehta
- *
- * Declaration of Radical class (RADICAL is Robust, Accurate, Direct ICA
- * aLgorithm).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef __MLPACK_METHODS_RADICAL_RADICAL_HPP
-#define __MLPACK_METHODS_RADICAL_RADICAL_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace radical {
-
-/**
- * An implementation of RADICAL, an algorithm for independent component
- * analysis (ICA).
- *
- * Let X be a matrix where each column is a point and each row a dimension.
- * The goal is to find a square unmixing matrix W such that Y = W X and
- * the rows of Y are independent components.
- *
- * For more details, see the following paper:
- *
- * @code
- * @article{learned2003ica,
- * title = {ICA Using Spacings Estimates of Entropy},
- * author = {Learned-Miller, E.G. and Fisher III, J.W.},
- * journal = {Journal of Machine Learning Research},
- * volume = {4},
- * pages = {1271--1295},
- * year = {2003}
- * }
- * @endcode
- */
-class Radical
-{
- public:
- /**
- * Set the parameters to RADICAL.
- *
- * @param noiseStdDev Standard deviation of the Gaussian noise added to the
- * replicates of the data points during Radical2D
- * @param replicates Number of Gaussian-perturbed replicates to use (per
- * point) in Radical2D
- * @param angles Number of angles to consider in brute-force search during
- * Radical2D
- * @param sweeps Number of sweeps. Each sweep calls Radical2D once for each
- * pair of dimensions
- * @param m The variable m from Vasicek's m-spacing estimator of entropy.
- */
- Radical(const double noiseStdDev = 0.175,
- const size_t replicates = 30,
- const size_t angles = 150,
- const size_t sweeps = 0,
- const size_t m = 0);
-
- /**
- * Run RADICAL.
- *
- * @param matX Input data into the algorithm - a matrix where each column is
- * a point and each row is a dimension.
- * @param matY Estimated independent components - a matrix where each column
- * is a point and each row is an estimated independent component.
- * @param matW Estimated unmixing matrix, where matY = matW * matX.
- */
- void DoRadical(const arma::mat& matX, arma::mat& matY, arma::mat& matW);
-
- /**
- * Vasicek's m-spacing estimator of entropy, with overlap modification from
- * (Learned-Miller and Fisher, 2003).
- *
- * @param x Empirical sample (one-dimensional) over which to estimate entropy.
- */
- double Vasicek(arma::vec& x) const;
-
- /**
- * Make replicates of each data point (the number of replicates is set in
- * either the constructor or with Replicates()) and perturb data with Gaussian
- * noise with standard deviation noiseStdDev.
- */
- void CopyAndPerturb(arma::mat& xNew, const arma::mat& x) const;
-
- //! Two-dimensional version of RADICAL.
- double DoRadical2D(const arma::mat& matX);
-
- //! Get the standard deviation of the additive Gaussian noise.
- double NoiseStdDev() const { return noiseStdDev; }
- //! Modify the standard deviation of the additive Gaussian noise.
- double& NoiseStdDev() { return noiseStdDev; }
-
- //! Get the number of Gaussian-perturbed replicates used per point.
- size_t Replicates() const { return replicates; }
- //! Modify the number of Gaussian-perturbed replicates used per point.
- size_t& Replicates() { return replicates; }
-
- //! Get the number of angles considered during brute-force search.
- size_t Angles() const { return angles; }
- //! Modify the number of angles considered during brute-force search.
- size_t& Angles() { return angles; }
-
- //! Get the number of sweeps.
- size_t Sweeps() const { return sweeps; }
- //! Modify the number of sweeps.
- size_t& Sweeps() { return sweeps; }
-
- private:
- //! Standard deviation of the Gaussian noise added to the replicates of
- //! the data points during Radical2D.
- double noiseStdDev;
-
- //! Number of Gaussian-perturbed replicates to use (per point) in Radical2D.
- size_t replicates;
-
- //! Number of angles to consider in brute-force search during Radical2D.
- size_t angles;
-
- //! Number of sweeps; each sweep calls Radical2D once for each pair of
- //! dimensions.
- size_t sweeps;
-
- //! Value of m to use for Vasicek's m-spacing estimator of entropy.
- size_t m;
-
- //! Internal matrix, held as member variable to prevent memory reallocations.
- arma::mat perturbed;
- //! Internal matrix, held as member variable to prevent memory reallocations.
- arma::mat candidate;
-};
-
-void WhitenFeatureMajorMatrix(const arma::mat& matX,
- arma::mat& matXWhitened,
- arma::mat& matWhitening);
-
-}; // namespace radical
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/radical/radical.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,155 @@
+/**
+ * @file radical.hpp
+ * @author Nishant Mehta
+ *
+ * Declaration of Radical class (RADICAL is Robust, Accurate, Direct ICA
+ * aLgorithm).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef __MLPACK_METHODS_RADICAL_RADICAL_HPP
+#define __MLPACK_METHODS_RADICAL_RADICAL_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace radical {
+
+/**
+ * An implementation of RADICAL, an algorithm for independent component
+ * analysis (ICA).
+ *
+ * Let X be a matrix where each column is a point and each row a dimension.
+ * The goal is to find a square unmixing matrix W such that Y = W X and
+ * the rows of Y are independent components.
+ *
+ * For more details, see the following paper:
+ *
+ * @code
+ * @article{learned2003ica,
+ * title = {ICA Using Spacings Estimates of Entropy},
+ * author = {Learned-Miller, E.G. and Fisher III, J.W.},
+ * journal = {Journal of Machine Learning Research},
+ * volume = {4},
+ * pages = {1271--1295},
+ * year = {2003}
+ * }
+ * @endcode
+ */
+class Radical
+{
+ public:
+ /**
+ * Set the parameters to RADICAL.
+ *
+ * @param noiseStdDev Standard deviation of the Gaussian noise added to the
+ * replicates of the data points during Radical2D
+ * @param replicates Number of Gaussian-perturbed replicates to use (per
+ * point) in Radical2D
+ * @param angles Number of angles to consider in brute-force search during
+ * Radical2D
+ * @param sweeps Number of sweeps. Each sweep calls Radical2D once for each
+ * pair of dimensions
+ * @param m The variable m from Vasicek's m-spacing estimator of entropy.
+ */
+ Radical(const double noiseStdDev = 0.175,
+ const size_t replicates = 30,
+ const size_t angles = 150,
+ const size_t sweeps = 0,
+ const size_t m = 0);
+
+ /**
+ * Run RADICAL.
+ *
+ * @param matX Input data into the algorithm - a matrix where each column is
+ * a point and each row is a dimension.
+ * @param matY Estimated independent components - a matrix where each column
+ * is a point and each row is an estimated independent component.
+ * @param matW Estimated unmixing matrix, where matY = matW * matX.
+ */
+ void DoRadical(const arma::mat& matX, arma::mat& matY, arma::mat& matW);
+
+ /**
+ * Vasicek's m-spacing estimator of entropy, with overlap modification from
+ * (Learned-Miller and Fisher, 2003).
+ *
+ * @param x Empirical sample (one-dimensional) over which to estimate entropy.
+ */
+ double Vasicek(arma::vec& x) const;
+
+ /**
+ * Make replicates of each data point (the number of replicates is set in
+ * either the constructor or with Replicates()) and perturb data with Gaussian
+ * noise with standard deviation noiseStdDev.
+ */
+ void CopyAndPerturb(arma::mat& xNew, const arma::mat& x) const;
+
+ //! Two-dimensional version of RADICAL.
+ double DoRadical2D(const arma::mat& matX);
+
+ //! Get the standard deviation of the additive Gaussian noise.
+ double NoiseStdDev() const { return noiseStdDev; }
+ //! Modify the standard deviation of the additive Gaussian noise.
+ double& NoiseStdDev() { return noiseStdDev; }
+
+ //! Get the number of Gaussian-perturbed replicates used per point.
+ size_t Replicates() const { return replicates; }
+ //! Modify the number of Gaussian-perturbed replicates used per point.
+ size_t& Replicates() { return replicates; }
+
+ //! Get the number of angles considered during brute-force search.
+ size_t Angles() const { return angles; }
+ //! Modify the number of angles considered during brute-force search.
+ size_t& Angles() { return angles; }
+
+ //! Get the number of sweeps.
+ size_t Sweeps() const { return sweeps; }
+ //! Modify the number of sweeps.
+ size_t& Sweeps() { return sweeps; }
+
+ private:
+ //! Standard deviation of the Gaussian noise added to the replicates of
+ //! the data points during Radical2D.
+ double noiseStdDev;
+
+ //! Number of Gaussian-perturbed replicates to use (per point) in Radical2D.
+ size_t replicates;
+
+ //! Number of angles to consider in brute-force search during Radical2D.
+ size_t angles;
+
+ //! Number of sweeps; each sweep calls Radical2D once for each pair of
+ //! dimensions.
+ size_t sweeps;
+
+ //! Value of m to use for Vasicek's m-spacing estimator of entropy.
+ size_t m;
+
+ //! Internal matrix, held as member variable to prevent memory reallocations.
+ arma::mat perturbed;
+ //! Internal matrix, held as member variable to prevent memory reallocations.
+ arma::mat candidate;
+};
+
+void WhitenFeatureMajorMatrix(const arma::mat& matX,
+ arma::mat& matXWhitened,
+ arma::mat& matWhitening);
+
+}; // namespace radical
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/radical/radical_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,114 +0,0 @@
-/**
- * @file radical_main.cpp
- * @author Nishant Mehta
- *
- * Executable for RADICAL.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include <mlpack/core.hpp>
-#include "radical.hpp"
-
-PROGRAM_INFO("RADICAL", "An implementation of RADICAL, a method for independent"
- "component analysis (ICA). Assuming that we have an input matrix X, the"
- "goal is to find a square unmixing matrix W such that Y = W * X and the "
- "dimensions of Y are independent components. If the algorithm is running"
- "particularly slowly, try reducing the number of replicates.");
-
-PARAM_STRING_REQ("input_file", "Input dataset filename for ICA.", "i");
-
-PARAM_STRING("output_ic", "File to save independent components to.", "o",
- "output_ic.csv");
-PARAM_STRING("output_unmixing", "File to save unmixing matrix to.", "u",
- "output_unmixing.csv");
-
-PARAM_DOUBLE("noise_std_dev", "Standard deviation of Gaussian noise.", "n",
- 0.175);
-PARAM_INT("replicates", "Number of Gaussian-perturbed replicates to use "
- "(per point) in Radical2D.", "r", 30);
-PARAM_INT("angles", "Number of angles to consider in brute-force search "
- "during Radical2D.", "a", 150);
-PARAM_INT("sweeps", "Number of sweeps; each sweep calls Radical2D once for "
- "each pair of dimensions.", "S", 0);
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-PARAM_FLAG("objective", "If set, an estimate of the final objective function "
- "is printed.", "O");
-
-using namespace mlpack;
-using namespace mlpack::radical;
-using namespace mlpack::math;
-using namespace std;
-using namespace arma;
-
-int main(int argc, char* argv[])
-{
- // Handle parameters.
- CLI::ParseCommandLine(argc, argv);
-
- // Set random seed.
- if (CLI::GetParam<int>("seed") != 0)
- RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- RandomSeed((size_t) std::time(NULL));
-
- // Load the data.
- const string matXFilename = CLI::GetParam<string>("input_file");
- mat matX;
- data::Load(matXFilename, matX);
-
- // Load parameters.
- double noiseStdDev = CLI::GetParam<double>("noise_std_dev");
- size_t nReplicates = CLI::GetParam<int>("replicates");
- size_t nAngles = CLI::GetParam<int>("angles");
- size_t nSweeps = CLI::GetParam<int>("sweeps");
-
- if (nSweeps == 0)
- {
- nSweeps = matX.n_rows - 1;
- }
-
- // Run RADICAL.
- Radical rad(noiseStdDev, nReplicates, nAngles, nSweeps);
- mat matY;
- mat matW;
- rad.DoRadical(matX, matY, matW);
-
- // Save results.
- const string matYFilename = CLI::GetParam<string>("output_ic");
- data::Save(matYFilename, matY);
-
- const string matWFilename = CLI::GetParam<string>("output_unmixing");
- data::Save(matWFilename, matW);
-
- if (CLI::HasParam("objective"))
- {
- // Compute and print objective.
- mat matYT = trans(matY);
- double valEst = 0;
- for (size_t i = 0; i < matYT.n_cols; i++)
- {
- vec y = vec(matYT.col(i));
- valEst += rad.Vasicek(y);
- }
-
- // Force output even if --verbose is not given.
- const bool ignoring = Log::Info.ignoreInput;
- Log::Info.ignoreInput = false;
- Log::Info << "Objective (estimate): " << valEst << "." << endl;
- Log::Info.ignoreInput = ignoring;
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/radical/radical_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/radical/radical_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,114 @@
+/**
+ * @file radical_main.cpp
+ * @author Nishant Mehta
+ *
+ * Executable for RADICAL.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <mlpack/core.hpp>
+#include "radical.hpp"
+
+PROGRAM_INFO("RADICAL", "An implementation of RADICAL, a method for independent "
+ "component analysis (ICA). Assuming that we have an input matrix X, the "
+ "goal is to find a square unmixing matrix W such that Y = W * X and the "
+ "dimensions of Y are independent components. If the algorithm is running "
+ "particularly slowly, try reducing the number of replicates.");
+
+PARAM_STRING_REQ("input_file", "Input dataset filename for ICA.", "i");
+
+PARAM_STRING("output_ic", "File to save independent components to.", "o",
+ "output_ic.csv");
+PARAM_STRING("output_unmixing", "File to save unmixing matrix to.", "u",
+ "output_unmixing.csv");
+
+PARAM_DOUBLE("noise_std_dev", "Standard deviation of Gaussian noise.", "n",
+ 0.175);
+PARAM_INT("replicates", "Number of Gaussian-perturbed replicates to use "
+ "(per point) in Radical2D.", "r", 30);
+PARAM_INT("angles", "Number of angles to consider in brute-force search "
+ "during Radical2D.", "a", 150);
+PARAM_INT("sweeps", "Number of sweeps; each sweep calls Radical2D once for "
+ "each pair of dimensions.", "S", 0);
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+PARAM_FLAG("objective", "If set, an estimate of the final objective function "
+ "is printed.", "O");
+
+using namespace mlpack;
+using namespace mlpack::radical;
+using namespace mlpack::math;
+using namespace std;
+using namespace arma;
+
+int main(int argc, char* argv[])
+{
+ // Handle parameters.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Set random seed.
+ if (CLI::GetParam<int>("seed") != 0)
+ RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ RandomSeed((size_t) std::time(NULL));
+
+ // Load the data.
+ const string matXFilename = CLI::GetParam<string>("input_file");
+ mat matX;
+ data::Load(matXFilename, matX);
+
+ // Load parameters.
+ double noiseStdDev = CLI::GetParam<double>("noise_std_dev");
+ size_t nReplicates = CLI::GetParam<int>("replicates");
+ size_t nAngles = CLI::GetParam<int>("angles");
+ size_t nSweeps = CLI::GetParam<int>("sweeps");
+
+ if (nSweeps == 0)
+ {
+ nSweeps = matX.n_rows - 1;
+ }
+
+ // Run RADICAL.
+ Radical rad(noiseStdDev, nReplicates, nAngles, nSweeps);
+ mat matY;
+ mat matW;
+ rad.DoRadical(matX, matY, matW);
+
+ // Save results.
+ const string matYFilename = CLI::GetParam<string>("output_ic");
+ data::Save(matYFilename, matY);
+
+ const string matWFilename = CLI::GetParam<string>("output_unmixing");
+ data::Save(matWFilename, matW);
+
+ if (CLI::HasParam("objective"))
+ {
+ // Compute and print objective.
+ mat matYT = trans(matY);
+ double valEst = 0;
+ for (size_t i = 0; i < matYT.n_cols; i++)
+ {
+ vec y = vec(matYT.col(i));
+ valEst += rad.Vasicek(y);
+ }
+
+ // Force output even if --verbose is not given.
+ const bool ignoring = Log::Info.ignoreInput;
+ Log::Info.ignoreInput = false;
+ Log::Info << "Objective (estimate): " << valEst << "." << endl;
+ Log::Info.ignoreInput = ignoring;
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/range_search/range_search.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,294 +0,0 @@
-/**
- * @file range_search.hpp
- * @author Ryan Curtin
- *
- * Defines the RangeSearch class, which performs a generalized range search on
- * points.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
-#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
-
-#include <mlpack/core.hpp>
-
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include <mlpack/core/tree/binary_space_tree.hpp>
-
-namespace mlpack {
-namespace range /** Range-search routines. */ {
-
-/**
- * The RangeSearch class is a template class for performing range searches.
- */
-template<typename MetricType = mlpack::metric::SquaredEuclideanDistance,
- typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
- tree::EmptyStatistic> >
-class RangeSearch
-{
- public:
- /**
- * Initialize the RangeSearch object with a different reference set and a
- * query set. Optionally, perform the computation in naive mode or
- * single-tree mode, and set the leaf size used for tree-building.
- * Additionally, an instantiated metric can be given, for cases where the
- * distance metric holds data.
- *
- * This method will copy the matrices to internal copies, which are rearranged
- * during tree-building. You can avoid this extra copy by pre-constructing
- * the trees and passing them using a different constructor.
- *
- * @param referenceSet Reference dataset.
- * @param querySet Query dataset.
- * @param naive Whether the computation should be done in O(n^2) naive mode.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param leafSize The leaf size to be used during tree construction.
- * @param metric Instantiated distance metric.
- */
- RangeSearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool naive = false,
- const bool singleMode = false,
- const size_t leafSize = 20,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the RangeSearch object with only a reference set, which will
- * also be used as a query set. Optionally, perform the computation in naive
- * mode or single-tree mode, and set the leaf size used for tree-building.
- * Additionally an instantiated metric can be given, for cases where the
- * distance metric holds data.
- *
- * This method will copy the reference matrix to an internal copy, which is
- * rearranged during tree-building. You can avoid this extra copy by
- * pre-constructing the reference tree and passing it using a different
- * constructor.
- *
- * @param referenceSet Reference dataset.
- * @param naive Whether the computation should be done in O(n^2) naive mode.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param leafSize The leaf size to be used during tree construction.
- * @param metric Instantiated distance metric.
- */
- RangeSearch(const typename TreeType::Mat& referenceSet,
- const bool naive = false,
- const bool singleMode = false,
- const size_t leafSize = 20,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the RangeSearch object with the given datasets and
- * pre-constructed trees. It is assumed that the points in referenceSet and
- * querySet correspond to the points in referenceTree and queryTree,
- * respectively. Optionally, choose to use single-tree mode. Naive
- * mode is not available as an option for this constructor; instead, to run
- * naive computation, construct a tree with all the points in one leaf (i.e.
- * leafSize = number of points). Additionally, an instantiated distance
- * metric can be given, for cases where the distance metric holds data.
- *
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
- *
- * @note
- * Because tree-building (at least with BinarySpaceTree) modifies the ordering
- * of a matrix, be sure you pass the modified matrix to this object! In
- * addition, mapping the points of the matrix back to their original indices
- * is not done when this constructor is used.
- * @endnote
- *
- * @param referenceTree Pre-built tree for reference points.
- * @param queryTree Pre-built tree for query points.
- * @param referenceSet Set of reference points corresponding to referenceTree.
- * @param querySet Set of query points corresponding to queryTree.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param metric Instantiated distance metric.
- */
- RangeSearch(TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool singleMode = false,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the RangeSearch object with the given reference dataset and
- * pre-constructed tree. It is assumed that the points in referenceSet
- * correspond to the points in referenceTree. Optionally, choose to use
- * single-tree mode. Naive mode is not available as an option for this
- * constructor; instead, to run naive computation, construct a tree with all
- * the points in one leaf (i.e. leafSize = number of points). Additionally,
- * an instantiated distance metric can be given, for the case where the
- * distance metric holds data.
- *
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
- *
- * @note
- * Because tree-building (at least with BinarySpaceTree) modifies the ordering
- * of a matrix, be sure you pass the modified matrix to this object! In
- * addition, mapping the points of the matrix back to their original indices
- * is not done when this constructor is used.
- * @endnote
- *
- * @param referenceTree Pre-built tree for reference points.
- * @param referenceSet Set of reference points corresponding to referenceTree.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param metric Instantiated distance metric.
- */
- RangeSearch(TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode = false,
- const MetricType metric = MetricType());
-
- /**
- * Destroy the RangeSearch object. If trees were created, they will be
- * deleted.
- */
- ~RangeSearch();
-
- /**
- * Search for all points in the given range, returning the results in the
- * neighbors and distances objects. Each entry in the external vector
- * corresponds to a query point. Each of these entries holds a vector which
- * contains the indices and distances of the reference points falling into the
- * given range.
- *
- * That is:
- *
- * - neighbors.size() and distances.size() both equal the number of query
- * points.
- *
- * - neighbors[i] contains the indices of all the points in the reference set
- * which have distances inside the given range to query point i.
- *
- * - distances[i] contains all of the distances corresponding to the indices
- * contained in neighbors[i].
- *
- * - neighbors[i] and distances[i] are not sorted in any particular order.
- *
- * @param range Range of distances in which to search.
- * @param neighbors Object which will hold the list of neighbors for each
- * point which fell into the given range, for each query point.
- * @param distances Object which will hold the list of distances for each
- * point which fell into the given range, for each query point.
- */
- void Search(const math::Range& range,
- std::vector<std::vector<size_t> >& neighbors,
- std::vector<std::vector<double> >& distances);
-
- private:
- /**
- * Compute the base case, when both referenceNode and queryNode are leaves
- * containing points.
- *
- * @param referenceNode Reference node (must be a leaf).
- * @param queryNode Query node (must be a leaf).
- * @param range Range of distances to search for.
- * @param neighbors Object holding list of neighbors.
- * @param distances Object holding list of distances.
- */
- void ComputeBaseCase(const TreeType* referenceNode,
- const TreeType* queryNode,
- const math::Range& range,
- std::vector<std::vector<size_t> >& neighbors,
- std::vector<std::vector<double> >& distances) const;
-
- /**
- * Perform the dual-tree recursion, which will recurse until the base case is
- * necessary.
- *
- * @param referenceNode Reference node.
- * @param queryNode Query node.
- * @param range Range of distances to search for.
- * @param neighbors Object holding list of neighbors.
- * @param distances Object holding list of distances.
- */
- void DualTreeRecursion(const TreeType* referenceNode,
- const TreeType* queryNode,
- const math::Range& range,
- std::vector<std::vector<size_t> >& neighbors,
- std::vector<std::vector<double> >& distances);
-
- /**
- * Perform the single-tree recursion, which will recurse down the reference
- * tree to get the results for a single point.
- *
- * @param referenceNode Reference node.
- * @param queryPoint Point to query for.
- * @param queryIndex Index of query node.
- * @param range Range of distances to search for.
- * @param neighbors Object holding list of neighbors.
- * @param distances Object holding list of distances.
- */
- template<typename VecType>
- void SingleTreeRecursion(const TreeType* referenceNode,
- const VecType& queryPoint,
- const size_t queryIndex,
- const math::Range& range,
- std::vector<size_t>& neighbors,
- std::vector<double>& distances);
-
- //! Copy of reference matrix; used when a tree is built internally.
- typename TreeType::Mat referenceCopy;
- //! Copy of query matrix; used when a tree is built internally.
- typename TreeType::Mat queryCopy;
-
- //! Reference set (data should be accessed using this).
- const typename TreeType::Mat& referenceSet;
- //! Query set (data should be accessed using this).
- const typename TreeType::Mat& querySet;
-
- //! Reference tree.
- TreeType* referenceTree;
- //! Query tree (may be NULL).
- TreeType* queryTree;
-
- //! Mappings to old reference indices (used when this object builds trees).
- std::vector<size_t> oldFromNewReferences;
- //! Mappings to old query indices (used when this object builds trees).
- std::vector<size_t> oldFromNewQueries;
-
- //! Indicates ownership of the reference tree (meaning we need to delete it).
- bool ownReferenceTree;
- //! Indicates ownership of the query tree (meaning we need to delete it).
- bool ownQueryTree;
-
- //! If true, O(n^2) naive computation is used.
- bool naive;
- //! If true, single-tree computation is used.
- bool singleMode;
-
- //! Instantiated distance metric.
- MetricType metric;
-
- //! The number of pruned nodes during computation.
- size_t numberOfPrunes;
-};
-
-}; // namespace range
-}; // namespace mlpack
-
-// Include implementation.
-#include "range_search_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/range_search/range_search.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,294 @@
+/**
+ * @file range_search.hpp
+ * @author Ryan Curtin
+ *
+ * Defines the RangeSearch class, which performs a generalized range search on
+ * points.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
+
+#include <mlpack/core.hpp>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
+namespace mlpack {
+namespace range /** Range-search routines. */ {
+
+/**
+ * The RangeSearch class is a template class for performing range searches.
+ */
+template<typename MetricType = mlpack::metric::SquaredEuclideanDistance,
+ typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
+ tree::EmptyStatistic> >
+class RangeSearch
+{
+ public:
+ /**
+ * Initialize the RangeSearch object with a different reference set and a
+ * query set. Optionally, perform the computation in naive mode or
+ * single-tree mode, and set the leaf size used for tree-building.
+ * Additionally, an instantiated metric can be given, for cases where the
+ * distance metric holds data.
+ *
+ * This method will copy the matrices to internal copies, which are rearranged
+ * during tree-building. You can avoid this extra copy by pre-constructing
+ * the trees and passing them using a different constructor.
+ *
+ * @param referenceSet Reference dataset.
+ * @param querySet Query dataset.
+ * @param naive Whether the computation should be done in O(n^2) naive mode.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param leafSize The leaf size to be used during tree construction.
+ * @param metric Instantiated distance metric.
+ */
+ RangeSearch(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const size_t leafSize = 20,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the RangeSearch object with only a reference set, which will
+ * also be used as a query set. Optionally, perform the computation in naive
+ * mode or single-tree mode, and set the leaf size used for tree-building.
+ * Additionally an instantiated metric can be given, for cases where the
+ * distance metric holds data.
+ *
+ * This method will copy the reference matrix to an internal copy, which is
+ * rearranged during tree-building. You can avoid this extra copy by
+ * pre-constructing the reference tree and passing it using a different
+ * constructor.
+ *
+ * @param referenceSet Reference dataset.
+ * @param naive Whether the computation should be done in O(n^2) naive mode.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param leafSize The leaf size to be used during tree construction.
+ * @param metric Instantiated distance metric.
+ */
+ RangeSearch(const typename TreeType::Mat& referenceSet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const size_t leafSize = 20,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the RangeSearch object with the given datasets and
+ * pre-constructed trees. It is assumed that the points in referenceSet and
+ * querySet correspond to the points in referenceTree and queryTree,
+ * respectively. Optionally, choose to use single-tree mode. Naive
+ * mode is not available as an option for this constructor; instead, to run
+ * naive computation, construct a tree with all the points in one leaf (i.e.
+ * leafSize = number of points). Additionally, an instantiated distance
+ * metric can be given, for cases where the distance metric holds data.
+ *
+ * There is no copying of the data matrices in this constructor (because
+ * tree-building is not necessary), so this is the constructor to use when
+ * copies absolutely must be avoided.
+ *
+ * @note
+ * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+ * of a matrix, be sure you pass the modified matrix to this object! In
+ * addition, mapping the points of the matrix back to their original indices
+ * is not done when this constructor is used.
+ * @endnote
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ * @param queryTree Pre-built tree for query points.
+ * @param referenceSet Set of reference points corresponding to referenceTree.
+ * @param querySet Set of query points corresponding to queryTree.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param metric Instantiated distance metric.
+ */
+ RangeSearch(TreeType* referenceTree,
+ TreeType* queryTree,
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the RangeSearch object with the given reference dataset and
+ * pre-constructed tree. It is assumed that the points in referenceSet
+ * correspond to the points in referenceTree. Optionally, choose to use
+ * single-tree mode. Naive mode is not available as an option for this
+ * constructor; instead, to run naive computation, construct a tree with all
+ * the points in one leaf (i.e. leafSize = number of points). Additionally,
+ * an instantiated distance metric can be given, for the case where the
+ * distance metric holds data.
+ *
+ * There is no copying of the data matrices in this constructor (because
+ * tree-building is not necessary), so this is the constructor to use when
+ * copies absolutely must be avoided.
+ *
+ * @note
+ * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+ * of a matrix, be sure you pass the modified matrix to this object! In
+ * addition, mapping the points of the matrix back to their original indices
+ * is not done when this constructor is used.
+ * @endnote
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ * @param referenceSet Set of reference points corresponding to referenceTree.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param metric Instantiated distance metric.
+ */
+ RangeSearch(TreeType* referenceTree,
+ const typename TreeType::Mat& referenceSet,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+ /**
+ * Destroy the RangeSearch object. If trees were created, they will be
+ * deleted.
+ */
+ ~RangeSearch();
+
+ /**
+ * Search for all points in the given range, returning the results in the
+ * neighbors and distances objects. Each entry in the external vector
+ * corresponds to a query point. Each of these entries holds a vector which
+ * contains the indices and distances of the reference points falling into the
+ * given range.
+ *
+ * That is:
+ *
+ * - neighbors.size() and distances.size() both equal the number of query
+ * points.
+ *
+ * - neighbors[i] contains the indices of all the points in the reference set
+ * which have distances inside the given range to query point i.
+ *
+ * - distances[i] contains all of the distances corresponding to the indices
+ * contained in neighbors[i].
+ *
+ * - neighbors[i] and distances[i] are not sorted in any particular order.
+ *
+ * @param range Range of distances in which to search.
+ * @param neighbors Object which will hold the list of neighbors for each
+ * point which fell into the given range, for each query point.
+ * @param distances Object which will hold the list of distances for each
+ * point which fell into the given range, for each query point.
+ */
+ void Search(const math::Range& range,
+ std::vector<std::vector<size_t> >& neighbors,
+ std::vector<std::vector<double> >& distances);
+
+ private:
+ /**
+ * Compute the base case, when both referenceNode and queryNode are leaves
+ * containing points.
+ *
+ * @param referenceNode Reference node (must be a leaf).
+ * @param queryNode Query node (must be a leaf).
+ * @param range Range of distances to search for.
+ * @param neighbors Object holding list of neighbors.
+ * @param distances Object holding list of distances.
+ */
+ void ComputeBaseCase(const TreeType* referenceNode,
+ const TreeType* queryNode,
+ const math::Range& range,
+ std::vector<std::vector<size_t> >& neighbors,
+ std::vector<std::vector<double> >& distances) const;
+
+ /**
+ * Perform the dual-tree recursion, which will recurse until the base case is
+ * necessary.
+ *
+ * @param referenceNode Reference node.
+ * @param queryNode Query node.
+ * @param range Range of distances to search for.
+ * @param neighbors Object holding list of neighbors.
+ * @param distances Object holding list of distances.
+ */
+ void DualTreeRecursion(const TreeType* referenceNode,
+ const TreeType* queryNode,
+ const math::Range& range,
+ std::vector<std::vector<size_t> >& neighbors,
+ std::vector<std::vector<double> >& distances);
+
+ /**
+ * Perform the single-tree recursion, which will recurse down the reference
+ * tree to get the results for a single point.
+ *
+ * @param referenceNode Reference node.
+ * @param queryPoint Point to query for.
+ * @param queryIndex Index of query node.
+ * @param range Range of distances to search for.
+ * @param neighbors Object holding list of neighbors.
+ * @param distances Object holding list of distances.
+ */
+ template<typename VecType>
+ void SingleTreeRecursion(const TreeType* referenceNode,
+ const VecType& queryPoint,
+ const size_t queryIndex,
+ const math::Range& range,
+ std::vector<size_t>& neighbors,
+ std::vector<double>& distances);
+
+ //! Copy of reference matrix; used when a tree is built internally.
+ typename TreeType::Mat referenceCopy;
+ //! Copy of query matrix; used when a tree is built internally.
+ typename TreeType::Mat queryCopy;
+
+ //! Reference set (data should be accessed using this).
+ const typename TreeType::Mat& referenceSet;
+ //! Query set (data should be accessed using this).
+ const typename TreeType::Mat& querySet;
+
+ //! Reference tree.
+ TreeType* referenceTree;
+ //! Query tree (may be NULL).
+ TreeType* queryTree;
+
+ //! Mappings to old reference indices (used when this object builds trees).
+ std::vector<size_t> oldFromNewReferences;
+ //! Mappings to old query indices (used when this object builds trees).
+ std::vector<size_t> oldFromNewQueries;
+
+ //! Indicates ownership of the reference tree (meaning we need to delete it).
+ bool ownReferenceTree;
+ //! Indicates ownership of the query tree (meaning we need to delete it).
+ bool ownQueryTree;
+
+ //! If true, O(n^2) naive computation is used.
+ bool naive;
+ //! If true, single-tree computation is used.
+ bool singleMode;
+
+ //! Instantiated distance metric.
+ MetricType metric;
+
+ //! The number of pruned nodes during computation.
+ size_t numberOfPrunes;
+};
+
+}; // namespace range
+}; // namespace mlpack
+
+// Include implementation.
+#include "range_search_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/range_search/range_search_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,470 +0,0 @@
-/**
- * @file range_search_impl.hpp
- * @author Ryan Curtin
- *
- * Implementation of the RangeSearch class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
-#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
-
-// Just in case it hasn't been included.
-#include "range_search.hpp"
-
-namespace mlpack {
-namespace range {
-
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool naive,
- const bool singleMode,
- const size_t leafSize,
- const MetricType metric) :
- referenceCopy(referenceSet),
- queryCopy(querySet),
- referenceSet(referenceCopy),
- querySet(queryCopy),
- ownReferenceTree(true),
- ownQueryTree(true),
- naive(naive),
- singleMode(!naive && singleMode), // Naive overrides single mode.
- metric(metric),
- numberOfPrunes(0)
-{
- // Build the trees.
- Timer::Start("range_search/tree_building");
-
- // Naive sets the leaf size such that the entire tree is one node.
- referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
- (naive ? referenceCopy.n_cols : leafSize));
-
- queryTree = new TreeType(queryCopy, oldFromNewQueries,
- (naive ? queryCopy.n_cols : leafSize));
-
- Timer::Stop("range_search/tree_building");
-}
-
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
- const typename TreeType::Mat& referenceSet,
- const bool naive,
- const bool singleMode,
- const size_t leafSize,
- const MetricType metric) :
- referenceCopy(referenceSet),
- referenceSet(referenceCopy),
- querySet(referenceCopy),
- queryTree(NULL),
- ownReferenceTree(true),
- ownQueryTree(false),
- naive(naive),
- singleMode(!naive && singleMode), // Naive overrides single mode.
- metric(metric),
- numberOfPrunes(0)
-{
- // Build the trees.
- Timer::Start("range_search/tree_building");
-
- // Naive sets the leaf size such that the entire tree is one node.
- referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
- (naive ? referenceCopy.n_cols : leafSize));
-
- Timer::Stop("range_search/tree_building");
-}
-
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
- TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(referenceSet),
- querySet(querySet),
- referenceTree(referenceTree),
- queryTree(queryTree),
- ownReferenceTree(false),
- ownQueryTree(false),
- naive(false),
- singleMode(singleMode),
- metric(metric),
- numberOfPrunes(0)
-{
- // Nothing else to initialize.
-}
-
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::RangeSearch(
- TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- referenceTree(referenceTree),
- queryTree(NULL),
- ownReferenceTree(false),
- ownQueryTree(false),
- naive(false),
- singleMode(singleMode),
- metric(metric),
- numberOfPrunes(0)
-{
- // Nothing else to initialize.
-}
-
-template<typename MetricType, typename TreeType>
-RangeSearch<MetricType, TreeType>::~RangeSearch()
-{
- if (ownReferenceTree)
- delete referenceTree;
- if (ownQueryTree)
- delete queryTree;
-}
-
-template<typename MetricType, typename TreeType>
-void RangeSearch<MetricType, TreeType>::Search(
- const math::Range& range,
- std::vector<std::vector<size_t> >& neighbors,
- std::vector<std::vector<double> >& distances)
-{
- Timer::Start("range_search/computing_neighbors");
-
- // Set size of prunes to 0.
- numberOfPrunes = 0;
-
- // If we have built the trees ourselves, then we will have to map all the
- // indices back to their original indices when this computation is finished.
- // To avoid extra copies, we will store the unmapped neighbors and distances
- // in a separate object.
- std::vector<std::vector<size_t> >* neighborPtr = &neighbors;
- std::vector<std::vector<double> >* distancePtr = &distances;
-
- if (ownQueryTree || (ownReferenceTree && !queryTree))
- distancePtr = new std::vector<std::vector<double> >;
- if (ownReferenceTree || ownQueryTree)
- neighborPtr = new std::vector<std::vector<size_t> >;
-
- // Resize each vector.
- neighborPtr->clear(); // Just in case there was anything in it.
- neighborPtr->resize(querySet.n_cols);
- distancePtr->clear();
- distancePtr->resize(querySet.n_cols);
-
- if (naive)
- {
- // Run the base case.
- if (!queryTree)
- ComputeBaseCase(referenceTree, referenceTree, range, *neighborPtr,
- *distancePtr);
- else
- ComputeBaseCase(referenceTree, queryTree, range, *neighborPtr,
- *distancePtr);
- }
- else if (singleMode)
- {
- // Loop over each of the query points.
- for (size_t i = 0; i < querySet.n_cols; i++)
- {
- SingleTreeRecursion(referenceTree, querySet.col(i), i, range,
- (*neighborPtr)[i], (*distancePtr)[i]);
- }
- }
- else
- {
- if (!queryTree) // References are the same as queries.
- DualTreeRecursion(referenceTree, referenceTree, range, *neighborPtr,
- *distancePtr);
- else
- DualTreeRecursion(referenceTree, queryTree, range, *neighborPtr,
- *distancePtr);
- }
-
- Timer::Stop("range_search/computing_neighbors");
-
- // Output number of prunes.
- Log::Info << "Number of pruned nodes during computation: " << numberOfPrunes
- << "." << std::endl;
-
- // Map points back to original indices, if necessary.
- if (!ownReferenceTree && !ownQueryTree)
- {
- // No mapping needed. We are done.
- return;
- }
- else if (ownReferenceTree && ownQueryTree) // Map references and queries.
- {
- neighbors.clear();
- neighbors.resize(querySet.n_cols);
- distances.clear();
- distances.resize(querySet.n_cols);
-
- for (size_t i = 0; i < distances.size(); i++)
- {
- // Map distances (copy a column).
- size_t queryMapping = oldFromNewQueries[i];
- distances[queryMapping] = (*distancePtr)[i];
-
- // Copy each neighbor individually, because we need to map it.
- neighbors[queryMapping].resize(distances[queryMapping].size());
- for (size_t j = 0; j < distances[queryMapping].size(); j++)
- {
- neighbors[queryMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
- }
- }
-
- // Finished with temporary objects.
- delete neighborPtr;
- delete distancePtr;
- }
- else if (ownReferenceTree)
- {
- if (!queryTree) // No query tree -- map both references and queries.
- {
- neighbors.clear();
- neighbors.resize(querySet.n_cols);
- distances.clear();
- distances.resize(querySet.n_cols);
-
- for (size_t i = 0; i < distances.size(); i++)
- {
- // Map distances (copy a column).
- size_t refMapping = oldFromNewReferences[i];
- distances[refMapping] = (*distancePtr)[i];
-
- // Copy each neighbor individually, because we need to map it.
- neighbors[refMapping].resize(distances[refMapping].size());
- for (size_t j = 0; j < distances[refMapping].size(); j++)
- {
- neighbors[refMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
- }
- }
-
- // Finished with temporary objects.
- delete neighborPtr;
- delete distancePtr;
- }
- else // Map only references.
- {
- neighbors.clear();
- neighbors.resize(querySet.n_cols);
-
- // Map indices of neighbors.
- for (size_t i = 0; i < neighbors.size(); i++)
- {
- neighbors[i].resize((*neighborPtr)[i].size());
- for (size_t j = 0; j < neighbors[i].size(); j++)
- {
- neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
- }
- }
-
- // Finished with temporary object.
- delete neighborPtr;
- }
- }
- else if (ownQueryTree)
- {
- neighbors.clear();
- neighbors.resize(querySet.n_cols);
- distances.clear();
- distances.resize(querySet.n_cols);
-
- for (size_t i = 0; i < distances.size(); i++)
- {
- // Map distances (copy a column).
- distances[oldFromNewQueries[i]] = (*distancePtr)[i];
-
- // Map neighbors.
- neighbors[oldFromNewQueries[i]] = (*neighborPtr)[i];
- }
-
- // Finished with temporary objects.
- delete neighborPtr;
- delete distancePtr;
- }
-}
-
-template<typename MetricType, typename TreeType>
-void RangeSearch<MetricType, TreeType>::ComputeBaseCase(
- const TreeType* referenceNode,
- const TreeType* queryNode,
- const math::Range& range,
- std::vector<std::vector<size_t> >& neighbors,
- std::vector<std::vector<double> >& distances) const
-{
- // node->Begin() is the index of the first point in the node,
- // node->End() is one past the last index.
- for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
- queryIndex++)
- {
- double minDistance =
- referenceNode->Bound().MinDistance(querySet.col(queryIndex));
- double maxDistance =
- referenceNode->Bound().MaxDistance(querySet.col(queryIndex));
-
- // Now see if any points could fall into the range.
- if (range.Contains(math::Range(minDistance, maxDistance)))
- {
- // Loop through the reference points and see which fall into the range.
- for (size_t referenceIndex = referenceNode->Begin();
- referenceIndex < referenceNode->End(); referenceIndex++)
- {
- // We can't add points that are ourselves.
- if (referenceNode != queryNode || referenceIndex != queryIndex)
- {
- double distance = metric.Evaluate(querySet.col(queryIndex),
- referenceSet.col(referenceIndex));
-
- // If this lies in the range, add it.
- if (range.Contains(distance))
- {
- neighbors[queryIndex].push_back(referenceIndex);
- distances[queryIndex].push_back(distance);
- }
- }
- }
- }
- }
-}
-
-template<typename MetricType, typename TreeType>
-void RangeSearch<MetricType, TreeType>::DualTreeRecursion(
- const TreeType* referenceNode,
- const TreeType* queryNode,
- const math::Range& range,
- std::vector<std::vector<size_t> >& neighbors,
- std::vector<std::vector<double> >& distances)
-{
- // See if we can prune this node.
- math::Range distance =
- referenceNode->Bound().RangeDistance(queryNode->Bound());
-
- if (!range.Contains(distance))
- {
- numberOfPrunes++; // Don't recurse. These nodes can't contain anything.
- return;
- }
-
- // If both nodes are leaves, then we compute the base case.
- if (referenceNode->IsLeaf() && queryNode->IsLeaf())
- {
- ComputeBaseCase(referenceNode, queryNode, range, neighbors, distances);
- }
- else if (referenceNode->IsLeaf())
- {
- // We must descend down the query node to get a leaf.
- DualTreeRecursion(referenceNode, queryNode->Left(), range, neighbors,
- distances);
- DualTreeRecursion(referenceNode, queryNode->Right(), range, neighbors,
- distances);
- }
- else if (queryNode->IsLeaf())
- {
- // We must descend down the reference node to get a leaf.
- DualTreeRecursion(referenceNode->Left(), queryNode, range, neighbors,
- distances);
- DualTreeRecursion(referenceNode->Right(), queryNode, range, neighbors,
- distances);
- }
- else
- {
- // First descend the left reference node.
- DualTreeRecursion(referenceNode->Left(), queryNode->Left(), range,
- neighbors, distances);
- DualTreeRecursion(referenceNode->Left(), queryNode->Right(), range,
- neighbors, distances);
-
- // Now descend the right reference node.
- DualTreeRecursion(referenceNode->Right(), queryNode->Left(), range,
- neighbors, distances);
- DualTreeRecursion(referenceNode->Right(), queryNode->Right(), range,
- neighbors, distances);
- }
-}
-
-template<typename MetricType, typename TreeType>
-template<typename VecType>
-void RangeSearch<MetricType, TreeType>::SingleTreeRecursion(
- const TreeType* referenceNode,
- const VecType& queryPoint,
- const size_t queryIndex,
- const math::Range& range,
- std::vector<size_t>& neighbors,
- std::vector<double>& distances)
-{
- // See if we need to recurse or if we can perform base-case computations.
- if (referenceNode->IsLeaf())
- {
- // Base case: reference node is a leaf.
- for (size_t referenceIndex = referenceNode->Begin(); referenceIndex !=
- referenceNode->End(); referenceIndex++)
- {
- // Don't add this point if it is the same as the query point.
- if (!queryTree && !(referenceIndex == queryIndex))
- {
- double distance = metric.Evaluate(queryPoint,
- referenceSet.col(referenceIndex));
-
- // See if the point is in the range we are looking for.
- if (range.Contains(distance))
- {
- neighbors.push_back(referenceIndex);
- distances.push_back(distance);
- }
- }
- }
- }
- else
- {
- // Recurse down the tree.
- math::Range distanceLeft =
- referenceNode->Left()->Bound().RangeDistance(queryPoint);
- math::Range distanceRight =
- referenceNode->Right()->Bound().RangeDistance(queryPoint);
-
- if (range.Contains(distanceLeft))
- {
- // The left may have points we want to recurse to.
- SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
- range, neighbors, distances);
- }
- else
- {
- numberOfPrunes++;
- }
-
- if (range.Contains(distanceRight))
- {
- // The right may have points we want to recurse to.
- SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
- range, neighbors, distances);
- }
- else
- {
- numberOfPrunes++;
- }
- }
-}
-
-}; // namespace range
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/range_search/range_search_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,470 @@
+/**
+ * @file range_search_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the RangeSearch class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
+#define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
+
+// Just in case it hasn't been included.
+#include "range_search.hpp"
+
+namespace mlpack {
+namespace range {
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ const MetricType metric) :
+ referenceCopy(referenceSet),
+ queryCopy(querySet),
+ referenceSet(referenceCopy),
+ querySet(queryCopy),
+ ownReferenceTree(true),
+ ownQueryTree(true),
+ naive(naive),
+ singleMode(!naive && singleMode), // Naive overrides single mode.
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // Build the trees.
+ Timer::Start("range_search/tree_building");
+
+ // Naive sets the leaf size such that the entire tree is one node.
+ referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ (naive ? referenceCopy.n_cols : leafSize));
+
+ queryTree = new TreeType(queryCopy, oldFromNewQueries,
+ (naive ? queryCopy.n_cols : leafSize));
+
+ Timer::Stop("range_search/tree_building");
+}
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(
+ const typename TreeType::Mat& referenceSet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ const MetricType metric) :
+ referenceCopy(referenceSet),
+ referenceSet(referenceCopy),
+ querySet(referenceCopy),
+ queryTree(NULL),
+ ownReferenceTree(true),
+ ownQueryTree(false),
+ naive(naive),
+ singleMode(!naive && singleMode), // Naive overrides single mode.
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // Build the trees.
+ Timer::Start("range_search/tree_building");
+
+ // Naive sets the leaf size such that the entire tree is one node.
+ referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ (naive ? referenceCopy.n_cols : leafSize));
+
+ Timer::Stop("range_search/tree_building");
+}
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(
+ TreeType* referenceTree,
+ TreeType* queryTree,
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool singleMode,
+ const MetricType metric) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ referenceTree(referenceTree),
+ queryTree(queryTree),
+ ownReferenceTree(false),
+ ownQueryTree(false),
+ naive(false),
+ singleMode(singleMode),
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // Nothing else to initialize.
+}
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::RangeSearch(
+ TreeType* referenceTree,
+ const typename TreeType::Mat& referenceSet,
+ const bool singleMode,
+ const MetricType metric) :
+ referenceSet(referenceSet),
+ querySet(referenceSet),
+ referenceTree(referenceTree),
+ queryTree(NULL),
+ ownReferenceTree(false),
+ ownQueryTree(false),
+ naive(false),
+ singleMode(singleMode),
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // Nothing else to initialize.
+}
+
+template<typename MetricType, typename TreeType>
+RangeSearch<MetricType, TreeType>::~RangeSearch()
+{
+ if (ownReferenceTree)
+ delete referenceTree;
+ if (ownQueryTree)
+ delete queryTree;
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::Search(
+ const math::Range& range,
+ std::vector<std::vector<size_t> >& neighbors,
+ std::vector<std::vector<double> >& distances)
+{
+ Timer::Start("range_search/computing_neighbors");
+
+ // Set size of prunes to 0.
+ numberOfPrunes = 0;
+
+ // If we have built the trees ourselves, then we will have to map all the
+ // indices back to their original indices when this computation is finished.
+ // To avoid extra copies, we will store the unmapped neighbors and distances
+ // in a separate object.
+ std::vector<std::vector<size_t> >* neighborPtr = &neighbors;
+ std::vector<std::vector<double> >* distancePtr = &distances;
+
+ if (ownQueryTree || (ownReferenceTree && !queryTree))
+ distancePtr = new std::vector<std::vector<double> >;
+ if (ownReferenceTree || ownQueryTree)
+ neighborPtr = new std::vector<std::vector<size_t> >;
+
+ // Resize each vector.
+ neighborPtr->clear(); // Just in case there was anything in it.
+ neighborPtr->resize(querySet.n_cols);
+ distancePtr->clear();
+ distancePtr->resize(querySet.n_cols);
+
+ if (naive)
+ {
+ // Run the base case.
+ if (!queryTree)
+ ComputeBaseCase(referenceTree, referenceTree, range, *neighborPtr,
+ *distancePtr);
+ else
+ ComputeBaseCase(referenceTree, queryTree, range, *neighborPtr,
+ *distancePtr);
+ }
+ else if (singleMode)
+ {
+ // Loop over each of the query points.
+ for (size_t i = 0; i < querySet.n_cols; i++)
+ {
+ SingleTreeRecursion(referenceTree, querySet.col(i), i, range,
+ (*neighborPtr)[i], (*distancePtr)[i]);
+ }
+ }
+ else
+ {
+ if (!queryTree) // References are the same as queries.
+ DualTreeRecursion(referenceTree, referenceTree, range, *neighborPtr,
+ *distancePtr);
+ else
+ DualTreeRecursion(referenceTree, queryTree, range, *neighborPtr,
+ *distancePtr);
+ }
+
+ Timer::Stop("range_search/computing_neighbors");
+
+ // Output number of prunes.
+ Log::Info << "Number of pruned nodes during computation: " << numberOfPrunes
+ << "." << std::endl;
+
+ // Map points back to original indices, if necessary.
+ if (!ownReferenceTree && !ownQueryTree)
+ {
+ // No mapping needed. We are done.
+ return;
+ }
+ else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+ {
+ neighbors.clear();
+ neighbors.resize(querySet.n_cols);
+ distances.clear();
+ distances.resize(querySet.n_cols);
+
+ for (size_t i = 0; i < distances.size(); i++)
+ {
+ // Map distances (copy a column).
+ size_t queryMapping = oldFromNewQueries[i];
+ distances[queryMapping] = (*distancePtr)[i];
+
+ // Copy each neighbor individually, because we need to map it.
+ neighbors[queryMapping].resize(distances[queryMapping].size());
+ for (size_t j = 0; j < distances[queryMapping].size(); j++)
+ {
+ neighbors[queryMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
+ }
+ }
+
+ // Finished with temporary objects.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+ else if (ownReferenceTree)
+ {
+ if (!queryTree) // No query tree -- map both references and queries.
+ {
+ neighbors.clear();
+ neighbors.resize(querySet.n_cols);
+ distances.clear();
+ distances.resize(querySet.n_cols);
+
+ for (size_t i = 0; i < distances.size(); i++)
+ {
+ // Map distances (copy a column).
+ size_t refMapping = oldFromNewReferences[i];
+ distances[refMapping] = (*distancePtr)[i];
+
+ // Copy each neighbor individually, because we need to map it.
+ neighbors[refMapping].resize(distances[refMapping].size());
+ for (size_t j = 0; j < distances[refMapping].size(); j++)
+ {
+ neighbors[refMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
+ }
+ }
+
+ // Finished with temporary objects.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+ else // Map only references.
+ {
+ neighbors.clear();
+ neighbors.resize(querySet.n_cols);
+
+ // Map indices of neighbors.
+ for (size_t i = 0; i < neighbors.size(); i++)
+ {
+ neighbors[i].resize((*neighborPtr)[i].size());
+ for (size_t j = 0; j < neighbors[i].size(); j++)
+ {
+ neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
+ }
+ }
+
+ // Finished with temporary object.
+ delete neighborPtr;
+ }
+ }
+ else if (ownQueryTree)
+ {
+ neighbors.clear();
+ neighbors.resize(querySet.n_cols);
+ distances.clear();
+ distances.resize(querySet.n_cols);
+
+ for (size_t i = 0; i < distances.size(); i++)
+ {
+ // Map distances (copy a column).
+ distances[oldFromNewQueries[i]] = (*distancePtr)[i];
+
+ // Map neighbors.
+ neighbors[oldFromNewQueries[i]] = (*neighborPtr)[i];
+ }
+
+ // Finished with temporary objects.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::ComputeBaseCase(
+ const TreeType* referenceNode,
+ const TreeType* queryNode,
+ const math::Range& range,
+ std::vector<std::vector<size_t> >& neighbors,
+ std::vector<std::vector<double> >& distances) const
+{
+ // node->Begin() is the index of the first point in the node,
+ // node->End() is one past the last index.
+ for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
+ queryIndex++)
+ {
+ double minDistance =
+ referenceNode->Bound().MinDistance(querySet.col(queryIndex));
+ double maxDistance =
+ referenceNode->Bound().MaxDistance(querySet.col(queryIndex));
+
+ // Now see if any points could fall into the range.
+ if (range.Contains(math::Range(minDistance, maxDistance)))
+ {
+ // Loop through the reference points and see which fall into the range.
+ for (size_t referenceIndex = referenceNode->Begin();
+ referenceIndex < referenceNode->End(); referenceIndex++)
+ {
+ // We can't add points that are ourselves.
+ if (referenceNode != queryNode || referenceIndex != queryIndex)
+ {
+ double distance = metric.Evaluate(querySet.col(queryIndex),
+ referenceSet.col(referenceIndex));
+
+ // If this lies in the range, add it.
+ if (range.Contains(distance))
+ {
+ neighbors[queryIndex].push_back(referenceIndex);
+ distances[queryIndex].push_back(distance);
+ }
+ }
+ }
+ }
+ }
+}
+
+template<typename MetricType, typename TreeType>
+void RangeSearch<MetricType, TreeType>::DualTreeRecursion(
+ const TreeType* referenceNode,
+ const TreeType* queryNode,
+ const math::Range& range,
+ std::vector<std::vector<size_t> >& neighbors,
+ std::vector<std::vector<double> >& distances)
+{
+ // See if we can prune this node.
+ math::Range distance =
+ referenceNode->Bound().RangeDistance(queryNode->Bound());
+
+ if (!range.Contains(distance))
+ {
+ numberOfPrunes++; // Don't recurse. These nodes can't contain anything.
+ return;
+ }
+
+ // If both nodes are leaves, then we compute the base case.
+ if (referenceNode->IsLeaf() && queryNode->IsLeaf())
+ {
+ ComputeBaseCase(referenceNode, queryNode, range, neighbors, distances);
+ }
+ else if (referenceNode->IsLeaf())
+ {
+ // We must descend down the query node to get a leaf.
+ DualTreeRecursion(referenceNode, queryNode->Left(), range, neighbors,
+ distances);
+ DualTreeRecursion(referenceNode, queryNode->Right(), range, neighbors,
+ distances);
+ }
+ else if (queryNode->IsLeaf())
+ {
+ // We must descend down the reference node to get a leaf.
+ DualTreeRecursion(referenceNode->Left(), queryNode, range, neighbors,
+ distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode, range, neighbors,
+ distances);
+ }
+ else
+ {
+ // First descend the left reference node.
+ DualTreeRecursion(referenceNode->Left(), queryNode->Left(), range,
+ neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode->Right(), range,
+ neighbors, distances);
+
+ // Now descend the right reference node.
+ DualTreeRecursion(referenceNode->Right(), queryNode->Left(), range,
+ neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode->Right(), range,
+ neighbors, distances);
+ }
+}
+
+template<typename MetricType, typename TreeType>
+template<typename VecType>
+void RangeSearch<MetricType, TreeType>::SingleTreeRecursion(
+ const TreeType* referenceNode,
+ const VecType& queryPoint,
+ const size_t queryIndex,
+ const math::Range& range,
+ std::vector<size_t>& neighbors,
+ std::vector<double>& distances)
+{
+ // See if we need to recurse or if we can perform base-case computations.
+ if (referenceNode->IsLeaf())
+ {
+ // Base case: reference node is a leaf.
+ for (size_t referenceIndex = referenceNode->Begin(); referenceIndex !=
+ referenceNode->End(); referenceIndex++)
+ {
+ // Don't add this point if it is the same as the query point.
+ if (!queryTree && !(referenceIndex == queryIndex))
+ {
+ double distance = metric.Evaluate(queryPoint,
+ referenceSet.col(referenceIndex));
+
+ // See if the point is in the range we are looking for.
+ if (range.Contains(distance))
+ {
+ neighbors.push_back(referenceIndex);
+ distances.push_back(distance);
+ }
+ }
+ }
+ }
+ else
+ {
+ // Recurse down the tree.
+ math::Range distanceLeft =
+ referenceNode->Left()->Bound().RangeDistance(queryPoint);
+ math::Range distanceRight =
+ referenceNode->Right()->Bound().RangeDistance(queryPoint);
+
+ if (range.Contains(distanceLeft))
+ {
+ // The left may have points we want to recurse to.
+ SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
+ range, neighbors, distances);
+ }
+ else
+ {
+ numberOfPrunes++;
+ }
+
+ if (range.Contains(distanceRight))
+ {
+ // The right may have points we want to recurse to.
+ SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
+ range, neighbors, distances);
+ }
+ else
+ {
+ numberOfPrunes++;
+ }
+ }
+}
+
+}; // namespace range
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/range_search/range_search_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,301 +0,0 @@
-/**
- * @file range_search_main.cpp
- * @author Ryan Curtin
- * @author Matthew Amidon
- *
- * Implementation of the RangeSearch executable. Allows some number of standard
- * options.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include "range_search.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::range;
-using namespace mlpack::tree;
-
-// Information about the program itself.
-PROGRAM_INFO("Range Search",
- "This program implements range search with a Euclidean distance metric. "
- "For a given query point, a given range, and a given set of reference "
- "points, the program will return all of the reference points with distance "
- "to the query point in the given range. This is performed for an entire "
- "set of query points. You may specify a separate set of reference and query"
- " points, or only a reference set -- which is then used as both the "
- "reference and query set. The given range is taken to be inclusive (that "
- "is, points with a distance exactly equal to the minimum and maximum of the"
- " range are included in the results)."
- "\n\n"
- "For example, the following will calculate the points within the range [2, "
- "5] of each point in 'input.csv' and store the distances in 'distances.csv'"
- " and the neighbors in 'neighbors.csv':"
- "\n\n"
- "$ range_search --min=2 --max=5 --reference_file=input.csv\n"
- " --distances_file=distances.csv --neighbors_file=neighbors.csv"
- "\n\n"
- "The output files are organized such that line i corresponds to the points "
- "found for query point i. Because sometimes 0 points may be found in the "
- "given range, lines of the output files may be empty. The points are not "
- "ordered in any specific manner."
- "\n\n"
- "Because the number of points returned for each query point may differ, the"
- " resultant CSV-like files may not be loadable by many programs. However, "
- "at this time a better way to store this non-square result is not known. "
- "As a result, any output files will be written as CSVs in this manner, "
- "regardless of the given extension.");
-
-// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
-PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
-PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
-
-PARAM_DOUBLE_REQ("max", "Upper bound in range.", "M");
-PARAM_DOUBLE("min", "Lower bound in range.", "m", 0.0);
-
-PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
-
-PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
-PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
-PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
- "dual-tree search.", "s");
-
-typedef RangeSearch<metric::SquaredEuclideanDistance,
- BinarySpaceTree<bound::HRectBound<2>, EmptyStatistic> > RSType;
-
-int main(int argc, char *argv[])
-{
- // Give CLI the command line parameters the user passed in.
- CLI::ParseCommandLine(argc, argv);
-
- // Get all the parameters.
- string referenceFile = CLI::GetParam<string>("reference_file");
-
- string distancesFile = CLI::GetParam<string>("distances_file");
- string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
- int lsInt = CLI::GetParam<int>("leaf_size");
-
- double max = CLI::GetParam<double>("max");
- double min = CLI::GetParam<double>("min");
-
- bool naive = CLI::HasParam("naive");
- bool singleMode = CLI::HasParam("single_mode");
-
- arma::mat referenceData;
- arma::mat queryData; // So it doesn't go out of scope.
- if (!data::Load(referenceFile.c_str(), referenceData))
- Log::Fatal << "Reference file " << referenceFile << "not found." << endl;
-
- Log::Info << "Loaded reference data from '" << referenceFile << "'." << endl;
-
- // Sanity check on range value: max must be greater than min.
- if (max <= min)
- {
- Log::Fatal << "Invalid range: maximum (" << max << ") must be greater than "
- << "minimum (" << min << ")." << endl;
- }
-
- // Sanity check on leaf size.
- if (lsInt < 0)
- {
- Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
- "than or equal to 0." << endl;
- }
- size_t leafSize = lsInt;
-
- // Naive mode overrides single mode.
- if (singleMode && naive)
- {
- Log::Warn << "--single_mode ignored because --naive is present." << endl;
- }
-
- if (naive)
- leafSize = referenceData.n_cols;
-
- vector<vector<size_t> > neighbors;
- vector<vector<double> > distances;
-
- // Because we may construct it differently, we need a pointer.
- RSType* rangeSearch = NULL;
-
- // Mappings for when we build the tree.
- vector<size_t> oldFromNewRefs;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("tree_building");
-
- BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>
- refTree(referenceData, oldFromNewRefs, leafSize);
- BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>*
- queryTree = NULL; // Empty for now.
-
- Timer::Stop("tree_building");
-
- std::vector<size_t> oldFromNewQueries;
-
- if (CLI::GetParam<string>("query_file") != "")
- {
- string queryFile = CLI::GetParam<string>("query_file");
-
- if (!data::Load(queryFile.c_str(), queryData))
- Log::Fatal << "Query file " << queryFile << " not found" << endl;
-
- if (naive && leafSize < queryData.n_cols)
- leafSize = queryData.n_cols;
-
- Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
-
- Log::Info << "Building query tree..." << endl;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Timer::Start("tree_building");
-
- queryTree = new BinarySpaceTree<bound::HRectBound<2>,
- tree::EmptyStatistic >(queryData, oldFromNewQueries,
- leafSize);
-
- Timer::Stop("tree_building");
-
- rangeSearch = new RSType(&refTree, queryTree, referenceData, queryData,
- singleMode);
-
- Log::Info << "Tree built." << endl;
- }
- else
- {
- rangeSearch = new RSType(&refTree, referenceData, singleMode);
-
- Log::Info << "Trees built." << endl;
- }
-
- Log::Info << "Computing neighbors within range [" << min << ", " << max
- << "]." << endl;
-
- math::Range r = math::Range(min, max);
- rangeSearch->Search(r, neighbors, distances);
-
- Log::Info << "Neighbors computed." << endl;
-
- // We have to map back to the original indices from before the tree
- // construction.
- Log::Info << "Re-mapping indices..." << endl;
-
- vector<vector<double> > distancesOut;
- distancesOut.resize(distances.size());
- vector<vector<size_t> > neighborsOut;
- neighborsOut.resize(neighbors.size());
-
- // Do the actual remapping.
- if (CLI::GetParam<string>("query_file") != "")
- {
- for (size_t i = 0; i < distances.size(); ++i)
- {
- // Map distances (copy a column).
- distancesOut[oldFromNewQueries[i]] = distances[i];
-
- // Map indices of neighbors.
- neighborsOut[oldFromNewQueries[i]].resize(neighbors[i].size());
- for (size_t j = 0; j < distances[i].size(); ++j)
- {
- neighborsOut[oldFromNewQueries[i]][j] = oldFromNewRefs[neighbors[i][j]];
- }
- }
- }
- else
- {
- for (size_t i = 0; i < distances.size(); ++i)
- {
- // Map distances (copy a column).
- distancesOut[oldFromNewRefs[i]] = distances[i];
-
- // Map indices of neighbors.
- neighborsOut[oldFromNewRefs[i]].resize(neighbors[i].size());
- for (size_t j = 0; j < distances[i].size(); ++j)
- {
- neighborsOut[oldFromNewRefs[i]][j] = oldFromNewRefs[neighbors[i][j]];
- }
- }
- }
-
- // Clean up.
- if (queryTree)
- delete queryTree;
-
- // Save output. We have to do this by hand.
- fstream distancesStr(distancesFile.c_str(), fstream::out);
- if (!distancesStr.is_open())
- {
- Log::Warn << "Cannot open file '" << distancesFile << "' to save output "
- << "distances to!" << endl;
- }
- else
- {
- // Loop over each point.
- for (size_t i = 0; i < distancesOut.size(); ++i)
- {
- // Store the distances of each point. We may have 0 points to store, so
- // we must account for that possibility.
- for (size_t j = 0; j + 1 < distancesOut[i].size(); ++j)
- {
- distancesStr << distancesOut[i][j] << ", ";
- }
-
- if (distancesOut[i].size() > 0)
- distancesStr << distancesOut[i][distancesOut[i].size() - 1];
-
- distancesStr << endl;
- }
-
- distancesStr.close();
- }
-
- fstream neighborsStr(neighborsFile.c_str(), fstream::out);
- if (!neighborsStr.is_open())
- {
- Log::Warn << "Cannot open file '" << neighborsFile << "' to save output "
- << "neighbor indices to!" << endl;
- }
- else
- {
- // Loop over each point.
- for (size_t i = 0; i < neighborsOut.size(); ++i)
- {
- // Store the neighbors of each point. We may have 0 points to store, so
- // we must account for that possibility.
- for (size_t j = 0; j + 1 < neighborsOut[i].size(); ++j)
- {
- neighborsStr << neighborsOut[i][j] << ", ";
- }
-
- if (neighborsOut[i].size() > 0)
- neighborsStr << neighborsOut[i][neighborsOut[i].size() - 1];
-
- neighborsStr << endl;
- }
-
- neighborsStr.close();
- }
-
- delete rangeSearch;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/range_search/range_search_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/range_search/range_search_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,301 @@
+/**
+ * @file range_search_main.cpp
+ * @author Ryan Curtin
+ * @author Matthew Amidon
+ *
+ * Implementation of the RangeSearch executable. Allows some number of standard
+ * options.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include "range_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::range;
+using namespace mlpack::tree;
+
+// Information about the program itself.
+PROGRAM_INFO("Range Search",
+ "This program implements range search with a Euclidean distance metric. "
+ "For a given query point, a given range, and a given set of reference "
+ "points, the program will return all of the reference points with distance "
+ "to the query point in the given range. This is performed for an entire "
+ "set of query points. You may specify a separate set of reference and query"
+ " points, or only a reference set -- which is then used as both the "
+ "reference and query set. The given range is taken to be inclusive (that "
+ "is, points with a distance exactly equal to the minimum and maximum of the"
+ " range are included in the results)."
+ "\n\n"
+ "For example, the following will calculate the points within the range [2, "
+ "5] of each point in 'input.csv' and store the distances in 'distances.csv'"
+ " and the neighbors in 'neighbors.csv':"
+ "\n\n"
+ "$ range_search --min=2 --max=5 --reference_file=input.csv\n"
+ " --distances_file=distances.csv --neighbors_file=neighbors.csv"
+ "\n\n"
+ "The output files are organized such that line i corresponds to the points "
+ "found for query point i. Because sometimes 0 points may be found in the "
+ "given range, lines of the output files may be empty. The points are not "
+ "ordered in any specific manner."
+ "\n\n"
+ "Because the number of points returned for each query point may differ, the"
+ " resultant CSV-like files may not be loadable by many programs. However, "
+ "at this time a better way to store this non-square result is not known. "
+ "As a result, any output files will be written as CSVs in this manner, "
+ "regardless of the given extension.");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ "r");
+PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
+PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
+
+PARAM_DOUBLE_REQ("max", "Upper bound in range.", "M");
+PARAM_DOUBLE("min", "Lower bound in range.", "m", 0.0);
+
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+
+PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+ "dual-tree search.", "s");
+
+typedef RangeSearch<metric::SquaredEuclideanDistance,
+ BinarySpaceTree<bound::HRectBound<2>, EmptyStatistic> > RSType;
+
+int main(int argc, char *argv[])
+{
+ // Give CLI the command line parameters the user passed in.
+ CLI::ParseCommandLine(argc, argv);
+
+ // Get all the parameters.
+ string referenceFile = CLI::GetParam<string>("reference_file");
+
+ string distancesFile = CLI::GetParam<string>("distances_file");
+ string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+ int lsInt = CLI::GetParam<int>("leaf_size");
+
+ double max = CLI::GetParam<double>("max");
+ double min = CLI::GetParam<double>("min");
+
+ bool naive = CLI::HasParam("naive");
+ bool singleMode = CLI::HasParam("single_mode");
+
+ arma::mat referenceData;
+ arma::mat queryData; // So it doesn't go out of scope.
+ if (!data::Load(referenceFile.c_str(), referenceData))
+ Log::Fatal << "Reference file " << referenceFile << "not found." << endl;
+
+ Log::Info << "Loaded reference data from '" << referenceFile << "'." << endl;
+
+ // Sanity check on range value: max must be greater than min.
+ if (max <= min)
+ {
+ Log::Fatal << "Invalid range: maximum (" << max << ") must be greater than "
+ << "minimum (" << min << ")." << endl;
+ }
+
+ // Sanity check on leaf size.
+ if (lsInt < 0)
+ {
+ Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
+ "than or equal to 0." << endl;
+ }
+ size_t leafSize = lsInt;
+
+ // Naive mode overrides single mode.
+ if (singleMode && naive)
+ {
+ Log::Warn << "--single_mode ignored because --naive is present." << endl;
+ }
+
+ if (naive)
+ leafSize = referenceData.n_cols;
+
+ vector<vector<size_t> > neighbors;
+ vector<vector<double> > distances;
+
+ // Because we may construct it differently, we need a pointer.
+ RSType* rangeSearch = NULL;
+
+ // Mappings for when we build the tree.
+ vector<size_t> oldFromNewRefs;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+
+ BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>
+ refTree(referenceData, oldFromNewRefs, leafSize);
+ BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>*
+ queryTree = NULL; // Empty for now.
+
+ Timer::Stop("tree_building");
+
+ std::vector<size_t> oldFromNewQueries;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ if (!data::Load(queryFile.c_str(), queryData))
+ Log::Fatal << "Query file " << queryFile << " not found" << endl;
+
+ if (naive && leafSize < queryData.n_cols)
+ leafSize = queryData.n_cols;
+
+ Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
+
+ Log::Info << "Building query tree..." << endl;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Timer::Start("tree_building");
+
+ queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+ tree::EmptyStatistic >(queryData, oldFromNewQueries,
+ leafSize);
+
+ Timer::Stop("tree_building");
+
+ rangeSearch = new RSType(&refTree, queryTree, referenceData, queryData,
+ singleMode);
+
+ Log::Info << "Tree built." << endl;
+ }
+ else
+ {
+ rangeSearch = new RSType(&refTree, referenceData, singleMode);
+
+ Log::Info << "Trees built." << endl;
+ }
+
+ Log::Info << "Computing neighbors within range [" << min << ", " << max
+ << "]." << endl;
+
+ math::Range r = math::Range(min, max);
+ rangeSearch->Search(r, neighbors, distances);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
+
+ vector<vector<double> > distancesOut;
+ distancesOut.resize(distances.size());
+ vector<vector<size_t> > neighborsOut;
+ neighborsOut.resize(neighbors.size());
+
+ // Do the actual remapping.
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ for (size_t i = 0; i < distances.size(); ++i)
+ {
+ // Map distances (copy a column).
+ distancesOut[oldFromNewQueries[i]] = distances[i];
+
+ // Map indices of neighbors.
+ neighborsOut[oldFromNewQueries[i]].resize(neighbors[i].size());
+ for (size_t j = 0; j < distances[i].size(); ++j)
+ {
+ neighborsOut[oldFromNewQueries[i]][j] = oldFromNewRefs[neighbors[i][j]];
+ }
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < distances.size(); ++i)
+ {
+ // Map distances (copy a column).
+ distancesOut[oldFromNewRefs[i]] = distances[i];
+
+ // Map indices of neighbors.
+ neighborsOut[oldFromNewRefs[i]].resize(neighbors[i].size());
+ for (size_t j = 0; j < distances[i].size(); ++j)
+ {
+ neighborsOut[oldFromNewRefs[i]][j] = oldFromNewRefs[neighbors[i][j]];
+ }
+ }
+ }
+
+ // Clean up.
+ if (queryTree)
+ delete queryTree;
+
+ // Save output. We have to do this by hand.
+ fstream distancesStr(distancesFile.c_str(), fstream::out);
+ if (!distancesStr.is_open())
+ {
+ Log::Warn << "Cannot open file '" << distancesFile << "' to save output "
+ << "distances to!" << endl;
+ }
+ else
+ {
+ // Loop over each point.
+ for (size_t i = 0; i < distancesOut.size(); ++i)
+ {
+ // Store the distances of each point. We may have 0 points to store, so
+ // we must account for that possibility.
+ for (size_t j = 0; j + 1 < distancesOut[i].size(); ++j)
+ {
+ distancesStr << distancesOut[i][j] << ", ";
+ }
+
+ if (distancesOut[i].size() > 0)
+ distancesStr << distancesOut[i][distancesOut[i].size() - 1];
+
+ distancesStr << endl;
+ }
+
+ distancesStr.close();
+ }
+
+ fstream neighborsStr(neighborsFile.c_str(), fstream::out);
+ if (!neighborsStr.is_open())
+ {
+ Log::Warn << "Cannot open file '" << neighborsFile << "' to save output "
+ << "neighbor indices to!" << endl;
+ }
+ else
+ {
+ // Loop over each point.
+ for (size_t i = 0; i < neighborsOut.size(); ++i)
+ {
+ // Store the neighbors of each point. We may have 0 points to store, so
+ // we must account for that possibility.
+ for (size_t j = 0; j + 1 < neighborsOut[i].size(); ++j)
+ {
+ neighborsStr << neighborsOut[i][j] << ", ";
+ }
+
+ if (neighborsOut[i].size() > 0)
+ neighborsStr << neighborsOut[i][neighborsOut[i].size() - 1];
+
+ neighborsStr << endl;
+ }
+
+ neighborsStr.close();
+ }
+
+ delete rangeSearch;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/allkrann_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/allkrann_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/allkrann_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,310 +0,0 @@
-/**
- * @file allkrann_main.cpp
- * @author Parikshit Ram
- *
- * Implementation of the AllkRANN executable. Allows some number of standard
- * options.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <time.h>
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/tree/cover_tree.hpp>
-
-#include <string>
-#include <fstream>
-#include <iostream>
-
-#include "ra_search.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-using namespace mlpack::tree;
-
-// Information about the program itself.
-PROGRAM_INFO("All K-Rank-Approximate-Nearest-Neighbors",
- "This program will calculate the k rank-approximate-nearest-neighbors "
- "of a set of points. You may specify a separate set of reference "
- "points and query points, or just a reference set which will be "
- "used as both the reference and query set. You must specify the "
- "rank approximation (in \%) (and maybe the success probability)."
- "\n\n"
- "For example, the following will return 5 neighbors from the top 0.1\% "
- "of the data (with probability 0.95) for each point in 'input.csv' "
- "and store the distances in 'distances.csv' and the neighbors in the "
- "file 'neighbors.csv':"
- "\n\n"
- "$ allkrann -k 5 -r input.csv -d distances.csv -n neighbors.csv --tau=0.1"
- "\n\n"
- "The output files are organized such that row i and column j in the "
- "neighbors output file corresponds to the index of the point in the "
- "reference set which is the i'th nearest neighbor from the point in the "
- "query set with index j. Row i and column j in the distances output file "
- "corresponds to the distance between those two points.");
-
-// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
-PARAM_STRING("distances_file", "File to output distances into.", "d", "");
-PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
-
-PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "k");
-
-PARAM_STRING("query_file", "File containing query points (optional).",
- "q", "");
-
-PARAM_DOUBLE("tau", "The allowed rank-error in terms of the percentile of "
- "the data.", "t", 0.1);
-PARAM_DOUBLE("alpha", "The desired success probability.", "a", 0.95);
-
-PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
-PARAM_FLAG("naive", "If true, sampling will be done without using a tree.",
- "N");
-PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
- "dual-tree search.", "s");
-PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search.",
- "c");
-
-PARAM_FLAG("sample_at_leaves", "The flag to trigger sampling at leaves.", "L");
-PARAM_FLAG("first_leaf_exact", "The flag to trigger sampling only after "
- "exactly exploring the first leaf.", "X");
-PARAM_INT("single_sample_limit", "The limit on the maximum number of "
- "samples (and hence the largest node you can approximate).", "S", 20);
-
-int main(int argc, char *argv[])
-{
- // Give CLI the command line parameters the user passed in.
- CLI::ParseCommandLine(argc, argv);
- math::RandomSeed(time(NULL));
-
- // Get all the parameters.
- string referenceFile = CLI::GetParam<string>("reference_file");
- string distancesFile = CLI::GetParam<string>("distances_file");
- string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
- int lsInt = CLI::GetParam<int>("leaf_size");
- size_t singleSampleLimit = CLI::GetParam<int>("single_sample_limit");
-
- size_t k = CLI::GetParam<int>("k");
-
- double tau = CLI::GetParam<double>("tau");
- double alpha = CLI::GetParam<double>("alpha");
-
- bool naive = CLI::HasParam("naive");
- bool singleMode = CLI::HasParam("single_mode");
- bool sampleAtLeaves = CLI::HasParam("sample_at_leaves");
- bool firstLeafExact = CLI::HasParam("first_leaf_exact");
-
- arma::mat referenceData;
- arma::mat queryData; // So it doesn't go out of scope.
- data::Load(referenceFile.c_str(), referenceData, true);
-
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
- // Sanity check on k value: must be greater than 0, must be less than the
- // number of reference points.
- if (k > referenceData.n_cols)
- {
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than or equal to the number of reference points (";
- Log::Fatal << referenceData.n_cols << ")." << endl;
- }
-
- // Sanity check on the value of 'tau' with respect to 'k' so that
- // 'k' neighbors are not requested from the top-'rank_error' neighbors
- // where 'rank_error' <= 'k'.
- size_t rank_error
- = (size_t) ceil(tau * (double) referenceData.n_cols / 100.0);
- if (rank_error <= k)
- Log::Fatal << "Invalid 'tau' (" << tau << ") - k (" << k << ") " <<
- "combination. Increase 'tau' or decrease 'k'." << endl;
-
- // Sanity check on leaf size.
- if (lsInt < 0)
- Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
- "than or equal to 0." << endl;
- size_t leafSize = lsInt;
-
- // Naive mode overrides single mode.
- if (singleMode && naive)
- Log::Warn << "--single_mode ignored because --naive is present." << endl;
-
- // The actual output after the remapping
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- if (naive)
- {
- AllkRANN* allkrann;
- if (CLI::GetParam<string>("query_file") != "")
- {
- string queryFile = CLI::GetParam<string>("query_file");
-
- data::Load(queryFile.c_str(), queryData, true);
-
- Log::Info << "Loaded query data from '" << queryFile << "' (" <<
- queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
- allkrann = new AllkRANN(referenceData, queryData, naive);
- }
- else
- allkrann = new AllkRANN(referenceData, naive);
-
- Log::Info << "Computing " << k << " nearest neighbors " << "with " <<
- tau << "% rank approximation..." << endl;
-
- allkrann->Search(k, neighbors, distances, tau, alpha);
-
- Log::Info << "Neighbors computed." << endl;
-
- delete allkrann;
- }
- else
- {
- // The results output by the AllkRANN class
- // shuffled because the tree construction shuffles the point sets.
- arma::Mat<size_t> neighborsOut;
- arma::mat distancesOut;
-
- if (!CLI::HasParam("cover_tree"))
- {
- // Because we may construct it differently, we need a pointer.
- AllkRANN* allkrann = NULL;
-
- // Mappings for when we build the tree.
- std::vector<size_t> oldFromNewRefs;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("tree_building");
-
- BinarySpaceTree<bound::HRectBound<2, false>,
- RAQueryStat<NearestNeighborSort> >
- refTree(referenceData, oldFromNewRefs, leafSize);
- BinarySpaceTree<bound::HRectBound<2, false>,
- RAQueryStat<NearestNeighborSort> >*
- queryTree = NULL; // Empty for now.
-
- Timer::Stop("tree_building");
-
- std::vector<size_t> oldFromNewQueries;
-
- if (CLI::GetParam<string>("query_file") != "")
- {
- string queryFile = CLI::GetParam<string>("query_file");
-
- data::Load(queryFile.c_str(), queryData, true);
-
- if (naive && leafSize < queryData.n_cols)
- leafSize = queryData.n_cols;
-
- Log::Info << "Loaded query data from '" << queryFile << "' (" <<
- queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
- Log::Info << "Building query tree..." << endl;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Timer::Start("tree_building");
-
- queryTree = new BinarySpaceTree<bound::HRectBound<2, false>,
- RAQueryStat<NearestNeighborSort> >
- (queryData, oldFromNewQueries, leafSize);
- Timer::Stop("tree_building");
-
- allkrann = new AllkRANN(&refTree, queryTree, referenceData, queryData,
- singleMode);
-
- Log::Info << "Tree built." << endl;
- }
- else
- {
- allkrann = new AllkRANN(&refTree, referenceData, singleMode);
- Log::Info << "Trees built." << endl;
- }
-
- Log::Info << "Computing " << k << " nearest neighbors " << "with " <<
- tau << "% rank approximation..." << endl;
- allkrann->Search(k, neighborsOut, distancesOut,
- tau, alpha, sampleAtLeaves,
- firstLeafExact, singleSampleLimit);
-
- Log::Info << "Neighbors computed." << endl;
-
- // We have to map back to the original indices from before the tree
- // construction.
- Log::Info << "Re-mapping indices..." << endl;
-
- neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
- distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
-
- // Do the actual remapping.
- if (CLI::GetParam<string>("query_file") != "")
- {
- for (size_t i = 0; i < distancesOut.n_cols; ++i)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distancesOut.n_rows; ++j)
- {
- neighbors(j, oldFromNewQueries[i])
- = oldFromNewRefs[neighborsOut(j, i)];
- }
- }
- }
- else
- {
- for (size_t i = 0; i < distancesOut.n_cols; ++i)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewRefs[i]) = distancesOut.col(i);
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distancesOut.n_rows; ++j)
- {
- neighbors(j, oldFromNewRefs[i])
- = oldFromNewRefs[neighborsOut(j, i)];
- }
- }
- }
-
- // Clean up.
- if (queryTree)
- delete queryTree;
-
- delete allkrann;
- }
- else // Cover trees.
- {
- Log::Fatal << "Cover tree case not implemented yet..." << endl;
- }
- }
-
- // Save output.
- if (distancesFile != "")
- data::Save(distancesFile, distances);
-
- if (neighborsFile != "")
- data::Save(neighborsFile, neighbors);
-
- return 0;
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/allkrann_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/allkrann_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/allkrann_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/allkrann_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,310 @@
+/**
+ * @file allkrann_main.cpp
+ * @author Parikshit Ram
+ *
+ * Implementation of the AllkRANN executable. Allows some number of standard
+ * options.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <time.h>
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "ra_search.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+using namespace mlpack::tree;
+
+// Information about the program itself.
+PROGRAM_INFO("All K-Rank-Approximate-Nearest-Neighbors",
+ "This program will calculate the k rank-approximate-nearest-neighbors "
+ "of a set of points. You may specify a separate set of reference "
+ "points and query points, or just a reference set which will be "
+ "used as both the reference and query set. You must specify the "
+ "rank approximation (in \%) (and maybe the success probability)."
+ "\n\n"
+ "For example, the following will return 5 neighbors from the top 0.1\% "
+ "of the data (with probability 0.95) for each point in 'input.csv' "
+ "and store the distances in 'distances.csv' and the neighbors in the "
+ "file 'neighbors.csv':"
+ "\n\n"
+ "$ allkrann -k 5 -r input.csv -d distances.csv -n neighbors.csv --tau=0.1"
+ "\n\n"
+ "The output files are organized such that row i and column j in the "
+ "neighbors output file corresponds to the index of the point in the "
+ "reference set which is the i'th nearest neighbor from the point in the "
+ "query set with index j. Row i and column j in the distances output file "
+ "corresponds to the distance between those two points.");
+
+// Define our input parameters that this program will take.
+PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
+ "r");
+PARAM_STRING("distances_file", "File to output distances into.", "d", "");
+PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
+
+PARAM_INT_REQ("k", "Number of furthest neighbors to find.", "k");
+
+PARAM_STRING("query_file", "File containing query points (optional).",
+ "q", "");
+
+PARAM_DOUBLE("tau", "The allowed rank-error in terms of the percentile of "
+ "the data.", "t", 0.1);
+PARAM_DOUBLE("alpha", "The desired success probability.", "a", 0.95);
+
+PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+PARAM_FLAG("naive", "If true, sampling will be done without using a tree.",
+ "N");
+PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
+ "dual-tree search.", "s");
+PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search.",
+ "c");
+
+PARAM_FLAG("sample_at_leaves", "The flag to trigger sampling at leaves.", "L");
+PARAM_FLAG("first_leaf_exact", "The flag to trigger sampling only after "
+ "exactly exploring the first leaf.", "X");
+PARAM_INT("single_sample_limit", "The limit on the maximum number of "
+ "samples (and hence the largest node you can approximate).", "S", 20);
+
+int main(int argc, char *argv[])
+{
+ // Give CLI the command line parameters the user passed in.
+ CLI::ParseCommandLine(argc, argv);
+ math::RandomSeed(time(NULL));
+
+ // Get all the parameters.
+ string referenceFile = CLI::GetParam<string>("reference_file");
+ string distancesFile = CLI::GetParam<string>("distances_file");
+ string neighborsFile = CLI::GetParam<string>("neighbors_file");
+
+ int lsInt = CLI::GetParam<int>("leaf_size");
+ size_t singleSampleLimit = CLI::GetParam<int>("single_sample_limit");
+
+ size_t k = CLI::GetParam<int>("k");
+
+ double tau = CLI::GetParam<double>("tau");
+ double alpha = CLI::GetParam<double>("alpha");
+
+ bool naive = CLI::HasParam("naive");
+ bool singleMode = CLI::HasParam("single_mode");
+ bool sampleAtLeaves = CLI::HasParam("sample_at_leaves");
+ bool firstLeafExact = CLI::HasParam("first_leaf_exact");
+
+ arma::mat referenceData;
+ arma::mat queryData; // So it doesn't go out of scope.
+ data::Load(referenceFile.c_str(), referenceData, true);
+
+ Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+ << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+
+ // Sanity check on k value: must be greater than 0, must be less than the
+ // number of reference points.
+ if (k > referenceData.n_cols)
+ {
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than or equal to the number of reference points (";
+ Log::Fatal << referenceData.n_cols << ")." << endl;
+ }
+
+ // Sanity check on the value of 'tau' with respect to 'k' so that
+ // 'k' neighbors are not requested from the top-'rank_error' neighbors
+ // where 'rank_error' <= 'k'.
+ size_t rank_error
+ = (size_t) ceil(tau * (double) referenceData.n_cols / 100.0);
+ if (rank_error <= k)
+ Log::Fatal << "Invalid 'tau' (" << tau << ") - k (" << k << ") " <<
+ "combination. Increase 'tau' or decrease 'k'." << endl;
+
+ // Sanity check on leaf size.
+ if (lsInt < 0)
+ Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
+ "than or equal to 0." << endl;
+ size_t leafSize = lsInt;
+
+ // Naive mode overrides single mode.
+ if (singleMode && naive)
+ Log::Warn << "--single_mode ignored because --naive is present." << endl;
+
+ // The actual output after the remapping
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ if (naive)
+ {
+ AllkRANN* allkrann;
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ data::Load(queryFile.c_str(), queryData, true);
+
+ Log::Info << "Loaded query data from '" << queryFile << "' (" <<
+ queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ allkrann = new AllkRANN(referenceData, queryData, naive);
+ }
+ else
+ allkrann = new AllkRANN(referenceData, naive);
+
+ Log::Info << "Computing " << k << " nearest neighbors " << "with " <<
+ tau << "% rank approximation..." << endl;
+
+ allkrann->Search(k, neighbors, distances, tau, alpha);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ delete allkrann;
+ }
+ else
+ {
+ // The results output by the AllkRANN class are shuffled, because the
+ // tree construction shuffles the point sets.
+ arma::Mat<size_t> neighborsOut;
+ arma::mat distancesOut;
+
+ if (!CLI::HasParam("cover_tree"))
+ {
+ // Because we may construct it differently, we need a pointer.
+ AllkRANN* allkrann = NULL;
+
+ // Mappings for when we build the tree.
+ std::vector<size_t> oldFromNewRefs;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+
+ BinarySpaceTree<bound::HRectBound<2, false>,
+ RAQueryStat<NearestNeighborSort> >
+ refTree(referenceData, oldFromNewRefs, leafSize);
+ BinarySpaceTree<bound::HRectBound<2, false>,
+ RAQueryStat<NearestNeighborSort> >*
+ queryTree = NULL; // Empty for now.
+
+ Timer::Stop("tree_building");
+
+ std::vector<size_t> oldFromNewQueries;
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ data::Load(queryFile.c_str(), queryData, true);
+
+ if (naive && leafSize < queryData.n_cols)
+ leafSize = queryData.n_cols;
+
+ Log::Info << "Loaded query data from '" << queryFile << "' (" <<
+ queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ Log::Info << "Building query tree..." << endl;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Timer::Start("tree_building");
+
+ queryTree = new BinarySpaceTree<bound::HRectBound<2, false>,
+ RAQueryStat<NearestNeighborSort> >
+ (queryData, oldFromNewQueries, leafSize);
+ Timer::Stop("tree_building");
+
+ allkrann = new AllkRANN(&refTree, queryTree, referenceData, queryData,
+ singleMode);
+
+ Log::Info << "Tree built." << endl;
+ }
+ else
+ {
+ allkrann = new AllkRANN(&refTree, referenceData, singleMode);
+ Log::Info << "Trees built." << endl;
+ }
+
+ Log::Info << "Computing " << k << " nearest neighbors " << "with " <<
+ tau << "% rank approximation..." << endl;
+ allkrann->Search(k, neighborsOut, distancesOut,
+ tau, alpha, sampleAtLeaves,
+ firstLeafExact, singleSampleLimit);
+
+ Log::Info << "Neighbors computed." << endl;
+
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
+
+ neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+ distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+
+ // Do the actual remapping.
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ for (size_t i = 0; i < distancesOut.n_cols; ++i)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distancesOut.n_rows; ++j)
+ {
+ neighbors(j, oldFromNewQueries[i])
+ = oldFromNewRefs[neighborsOut(j, i)];
+ }
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < distancesOut.n_cols; ++i)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewRefs[i]) = distancesOut.col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distancesOut.n_rows; ++j)
+ {
+ neighbors(j, oldFromNewRefs[i])
+ = oldFromNewRefs[neighborsOut(j, i)];
+ }
+ }
+ }
+
+ // Clean up.
+ if (queryTree)
+ delete queryTree;
+
+ delete allkrann;
+ }
+ else // Cover trees.
+ {
+ Log::Fatal << "Cover tree case not implemented yet..." << endl;
+ }
+ }
+
+ // Save output.
+ if (distancesFile != "")
+ data::Save(distancesFile, distances);
+
+ if (neighborsFile != "")
+ data::Save(neighborsFile, neighbors);
+
+ return 0;
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_search.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,344 +0,0 @@
-/**
- * @file ra_search.hpp
- * @author Parikshit Ram
- *
- * Defines the RASearch class, which performs an abstract
- * rank-approximate nearest/farthest neighbor query on two datasets.
- *
- * The details of this method can be found in the following paper:
- *
- * @inproceedings{ram2009rank,
- * title={{Rank-Approximate Nearest Neighbor Search: Retaining Meaning and
- * Speed in High Dimensions}},
- * author={{Ram, P. and Lee, D. and Ouyang, H. and Gray, A. G.}},
- * booktitle={{Advances of Neural Information Processing Systems}},
- * year={2009}
- * }
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_HPP
-
-#include <mlpack/core.hpp>
-
-#include <mlpack/core/tree/binary_space_tree.hpp>
-
-#include <mlpack/core/metrics/lmetric.hpp>
-#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
-
-namespace mlpack {
-namespace neighbor /** Neighbor-search routines. These include
- * all-nearest-neighbors and all-furthest-neighbors
- * searches. */ {
-
-/**
- * Extra data for each node in the tree. For neighbor searches, each node only
- * needs to store a bound on neighbor distances.
- *
- * Every query is required to make a minimum number of samples to guarantee the
- * desired approximation error. The 'numSamplesMade' keeps track of the minimum
- * number of samples made by all queries in the node in question.
- */
-template<typename SortPolicy>
-class RAQueryStat
-{
- private:
- //! The bound on the node's neighbor distances.
- double bound;
-
- //! The minimum number of samples made by any query in this node.
- size_t numSamplesMade;
-
- public:
- /**
- * Initialize the statistic with the worst possible distance according to our
- * sorting policy.
- */
- RAQueryStat() : bound(SortPolicy::WorstDistance()), numSamplesMade(0) { }
-
- /**
- * Initialization for a node.
- */
- template<typename TreeType>
- RAQueryStat(const TreeType& /* node */) :
- bound(SortPolicy::WorstDistance()),
- numSamplesMade(0)
- { }
-
- //! Get the bound.
- double Bound() const { return bound; }
- //! Modify the bound.
- double& Bound() { return bound; }
-
- //! Get the number of samples made.
- size_t NumSamplesMade() const { return numSamplesMade; }
- //! Modify the number of samples made.
- size_t& NumSamplesMade() { return numSamplesMade; }
-};
-
-/**
- * The RASearch class: This class provides a generic manner to perform
- * rank-approximate search via random-sampling. If the 'naive' option is chosen,
- * this rank-approximate search will be done by randomly sampled from the whole
- * set. If the 'naive' option is not chosen, the sampling is done in a
- * stratified manner in the tree as mentioned in the algorithms in Figure 2 of
- * the following paper:
- *
- * @inproceedings{ram2009rank,
- * title={{Rank-Approximate Nearest Neighbor Search: Retaining Meaning and
- * Speed in High Dimensions}},
- * author={{Ram, P. and Lee, D. and Ouyang, H. and Gray, A. G.}},
- * booktitle={{Advances of Neural Information Processing Systems}},
- * year={2009}
- * }
- *
- * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
- * @tparam MetricType The metric to use for computation.
- * @tparam TreeType The tree type to use.
- */
-template<typename SortPolicy = NearestNeighborSort,
- typename MetricType = mlpack::metric::SquaredEuclideanDistance,
- typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2, false>,
- RAQueryStat<SortPolicy> > >
-class RASearch
-{
- public:
- /**
- * Initialize the RASearch object, passing both a query and reference dataset.
- * Optionally, perform the computation in naive mode or single-tree mode, and
- * set the leaf size used for tree-building. An initialized distance metric
- * can be given, for cases where the metric has internal data (i.e. the
- * distance::MahalanobisDistance class).
- *
- * This method will copy the matrices to internal copies, which are rearranged
- * during tree-building. You can avoid this extra copy by pre-constructing
- * the trees and passing them using a diferent constructor.
- *
- * @param referenceSet Set of reference points.
- * @param querySet Set of query points.
- * @param naive If true, the rank-approximate search will be performed by
- * directly sampling the whole set instead of using the stratified
- * sampling on the tree.
- * @param singleMode If true, single-tree search will be used (as opposed to
- * dual-tree search).
- * @param leafSize Leaf size for tree construction (ignored if tree is given).
- * @param metric An optional instance of the MetricType class.
- */
- RASearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool naive = false,
- const bool singleMode = false,
- const size_t leafSize = 20,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the RASearch object, passing only one dataset, which is
- * used as both the query and the reference dataset. Optionally, perform the
- * computation in naive mode or single-tree mode, and set the leaf size used
- * for tree-building. An initialized distance metric can be given, for cases
- * where the metric has internal data (i.e. the distance::MahalanobisDistance
- * class).
- *
- * If naive mode is being used and a pre-built tree is given, it may not work:
- * naive mode operates by building a one-node tree (the root node holds all
- * the points). If that condition is not satisfied with the pre-built tree,
- * then naive mode will not work.
- *
- * @param referenceSet Set of reference points.
- * @param naive If true, the rank-approximate search will be performed
- * by directly sampling the whole set instead of using the stratified
- * sampling on the tree.
- * @param singleMode If true, single-tree search will be used (as opposed to
- * dual-tree search).
- * @param leafSize Leaf size for tree construction (ignored if tree is given).
- * @param metric An optional instance of the MetricType class.
- */
- RASearch(const typename TreeType::Mat& referenceSet,
- const bool naive = false,
- const bool singleMode = false,
- const size_t leafSize = 20,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the RASearch object with the given datasets and
- * pre-constructed trees. It is assumed that the points in referenceSet and
- * querySet correspond to the points in referenceTree and queryTree,
- * respectively. Optionally, choose to use single-tree mode. Naive mode is
- * not available as an option for this constructor; instead, to run naive
- * computation, construct a tree with all of the points in one leaf (i.e.
- * leafSize = number of points). Additionally, an instantiated distance
- * metric can be given, for cases where the distance metric holds data.
- *
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
- *
- * @note
- * Because tree-building (at least with BinarySpaceTree) modifies the ordering
- * of a matrix, be sure you pass the modified matrix to this object! In
- * addition, mapping the points of the matrix back to their original indices
- * is not done when this constructor is used.
- * @endnote
- *
- * @param referenceTree Pre-built tree for reference points.
- * @param queryTree Pre-built tree for query points.
- * @param referenceSet Set of reference points corresponding to referenceTree.
- * @param querySet Set of query points corresponding to queryTree.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param metric Instantiated distance metric.
- */
- RASearch(TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool singleMode = false,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the RASearch object with the given reference dataset and
- * pre-constructed tree. It is assumed that the points in referenceSet
- * correspond to the points in referenceTree. Optionally, choose to use
- * single-tree mode. Naive mode is not available as an option for this
- * constructor; instead, to run naive computation, construct a tree with all
- * the points in one leaf (i.e. leafSize = number of points). Additionally,
- * an instantiated distance metric can be given, for the case where the
- * distance metric holds data.
- *
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
- *
- * @note
- * Because tree-building (at least with BinarySpaceTree) modifies the ordering
- * of a matrix, be sure you pass the modified matrix to this object! In
- * addition, mapping the points of the matrix back to their original indices
- * is not done when this constructor is used.
- * @endnote
- *
- * @param referenceTree Pre-built tree for reference points.
- * @param referenceSet Set of reference points corresponding to referenceTree.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param metric Instantiated distance metric.
- */
- RASearch(TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode = false,
- const MetricType metric = MetricType());
-
- /**
- * Delete the RASearch object. The tree is the only member we are
- * responsible for deleting. The others will take care of themselves.
- */
- ~RASearch();
-
- /**
- * Compute the rank approximate nearest neighbors and store the output in the
- * given matrices. The matrices will be set to the size of n columns by k
- * rows, where n is the number of points in the query dataset and k is the
- * number of neighbors being searched for.
- *
- * @param k Number of neighbors to search for.
- * @param resultingNeighbors Matrix storing lists of neighbors for each query
- * point.
- * @param distances Matrix storing distances of neighbors for each query
- * point.
- * @param tau The rank-approximation in percentile of the data. The default
- * value is 0.1%.
- * @param alpha The desired success probability. The default value is 0.95.
- * @param sampleAtLeaves Sample at leaves for faster but less accurate
- * computation. This defaults to 'false'.
- * @param firstLeafExact Traverse to the first leaf without approximation.
- * This can ensure that the query definitely finds its (near) duplicate
- * if there exists one. This defaults to 'false' for now.
- * @param singleSampleLimit The limit on the largest node that can be
- * approximated by sampling. This defaults to 20.
- */
- void Search(const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
- arma::mat& distances,
- const double tau = 0.1,
- const double alpha = 0.95,
- const bool sampleAtLeaves = false,
- const bool firstLeafExact = false,
- const size_t singleSampleLimit = 20);
-
- /**
- * This function recursively resets the RAQueryStat of the queryTree to set
- * 'bound' to WorstDistance and the 'numSamplesMade' to 0. This allows a user
- * to perform multiple searches on the same pair of trees, possibly with
- * different levels of approximation without requiring to build a new pair of
- * trees for every new (approximate) search.
- */
- void ResetQueryTree();
-
- private:
- //! Copy of reference dataset (if we need it, because tree building modifies
- //! it).
- arma::mat referenceCopy;
- //! Copy of query dataset (if we need it, because tree building modifies it).
- arma::mat queryCopy;
-
- //! Reference dataset.
- const arma::mat& referenceSet;
- //! Query dataset (may not be given).
- const arma::mat& querySet;
-
- //! Pointer to the root of the reference tree.
- TreeType* referenceTree;
- //! Pointer to the root of the query tree (might not exist).
- TreeType* queryTree;
-
- //! Indicates if we should free the reference tree at deletion time.
- bool ownReferenceTree;
- //! Indicates if we should free the query tree at deletion time.
- bool ownQueryTree;
-
- //! Indicates if naive random sampling on the set is being used.
- bool naive;
- //! Indicates if single-tree search is being used (opposed to dual-tree).
- bool singleMode;
-
- //! Instantiation of kernel.
- MetricType metric;
-
- //! Permutations of reference points during tree building.
- std::vector<size_t> oldFromNewReferences;
- //! Permutations of query points during tree building.
- std::vector<size_t> oldFromNewQueries;
-
- //! Total number of pruned nodes during the neighbor search.
- size_t numberOfPrunes;
-
- /**
- * @param treeNode The node of the tree whose RAQueryStat is reset
- * and whose children are to be explored recursively.
- */
- void ResetRAQueryStat(TreeType* treeNode);
-}; // class RASearch
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-// Include implementation.
-#include "ra_search_impl.hpp"
-
-// Include convenient typedefs.
-#include "ra_typedef.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_search.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,344 @@
+/**
+ * @file ra_search.hpp
+ * @author Parikshit Ram
+ *
+ * Defines the RASearch class, which performs an abstract
+ * rank-approximate nearest/farthest neighbor query on two datasets.
+ *
+ * The details of this method can be found in the following paper:
+ *
+ * @inproceedings{ram2009rank,
+ * title={{Rank-Approximate Nearest Neighbor Search: Retaining Meaning and
+ * Speed in High Dimensions}},
+ * author={{Ram, P. and Lee, D. and Ouyang, H. and Gray, A. G.}},
+ * booktitle={{Advances of Neural Information Processing Systems}},
+ * year={2009}
+ * }
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_HPP
+
+#include <mlpack/core.hpp>
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
+
+namespace mlpack {
+namespace neighbor /** Neighbor-search routines. These include
+ * all-nearest-neighbors and all-furthest-neighbors
+ * searches. */ {
+
+/**
+ * Extra data for each node in the tree. For neighbor searches, each node only
+ * needs to store a bound on neighbor distances.
+ *
+ * Every query is required to make a minimum number of samples to guarantee the
+ * desired approximation error. The 'numSamplesMade' keeps track of the minimum
+ * number of samples made by all queries in the node in question.
+ */
+template<typename SortPolicy>
+class RAQueryStat
+{
+ private:
+ //! The bound on the node's neighbor distances.
+ double bound;
+
+ //! The minimum number of samples made by any query in this node.
+ size_t numSamplesMade;
+
+ public:
+ /**
+ * Initialize the statistic with the worst possible distance according to our
+ * sorting policy.
+ */
+ RAQueryStat() : bound(SortPolicy::WorstDistance()), numSamplesMade(0) { }
+
+ /**
+ * Initialization for a node.
+ */
+ template<typename TreeType>
+ RAQueryStat(const TreeType& /* node */) :
+ bound(SortPolicy::WorstDistance()),
+ numSamplesMade(0)
+ { }
+
+ //! Get the bound.
+ double Bound() const { return bound; }
+ //! Modify the bound.
+ double& Bound() { return bound; }
+
+ //! Get the number of samples made.
+ size_t NumSamplesMade() const { return numSamplesMade; }
+ //! Modify the number of samples made.
+ size_t& NumSamplesMade() { return numSamplesMade; }
+};
+
+/**
+ * The RASearch class: This class provides a generic manner to perform
+ * rank-approximate search via random-sampling. If the 'naive' option is chosen,
+ * this rank-approximate search will be done by randomly sampling from the whole
+ * set. If the 'naive' option is not chosen, the sampling is done in a
+ * stratified manner in the tree as mentioned in the algorithms in Figure 2 of
+ * the following paper:
+ *
+ * @inproceedings{ram2009rank,
+ * title={{Rank-Approximate Nearest Neighbor Search: Retaining Meaning and
+ * Speed in High Dimensions}},
+ * author={{Ram, P. and Lee, D. and Ouyang, H. and Gray, A. G.}},
+ * booktitle={{Advances of Neural Information Processing Systems}},
+ * year={2009}
+ * }
+ *
+ * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
+ * @tparam MetricType The metric to use for computation.
+ * @tparam TreeType The tree type to use.
+ */
+template<typename SortPolicy = NearestNeighborSort,
+ typename MetricType = mlpack::metric::SquaredEuclideanDistance,
+ typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2, false>,
+ RAQueryStat<SortPolicy> > >
+class RASearch
+{
+ public:
+ /**
+ * Initialize the RASearch object, passing both a query and reference dataset.
+ * Optionally, perform the computation in naive mode or single-tree mode, and
+ * set the leaf size used for tree-building. An initialized distance metric
+ * can be given, for cases where the metric has internal data (i.e. the
+ * distance::MahalanobisDistance class).
+ *
+ * This method will copy the matrices to internal copies, which are rearranged
+ * during tree-building. You can avoid this extra copy by pre-constructing
+ * the trees and passing them using a different constructor.
+ *
+ * @param referenceSet Set of reference points.
+ * @param querySet Set of query points.
+ * @param naive If true, the rank-approximate search will be performed by
+ * directly sampling the whole set instead of using the stratified
+ * sampling on the tree.
+ * @param singleMode If true, single-tree search will be used (as opposed to
+ * dual-tree search).
+ * @param leafSize Leaf size for tree construction (ignored if tree is given).
+ * @param metric An optional instance of the MetricType class.
+ */
+ RASearch(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const size_t leafSize = 20,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the RASearch object, passing only one dataset, which is
+ * used as both the query and the reference dataset. Optionally, perform the
+ * computation in naive mode or single-tree mode, and set the leaf size used
+ * for tree-building. An initialized distance metric can be given, for cases
+ * where the metric has internal data (i.e. the distance::MahalanobisDistance
+ * class).
+ *
+ * If naive mode is being used and a pre-built tree is given, it may not work:
+ * naive mode operates by building a one-node tree (the root node holds all
+ * the points). If that condition is not satisfied with the pre-built tree,
+ * then naive mode will not work.
+ *
+ * @param referenceSet Set of reference points.
+ * @param naive If true, the rank-approximate search will be performed
+ * by directly sampling the whole set instead of using the stratified
+ * sampling on the tree.
+ * @param singleMode If true, single-tree search will be used (as opposed to
+ * dual-tree search).
+ * @param leafSize Leaf size for tree construction (ignored if tree is given).
+ * @param metric An optional instance of the MetricType class.
+ */
+ RASearch(const typename TreeType::Mat& referenceSet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const size_t leafSize = 20,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the RASearch object with the given datasets and
+ * pre-constructed trees. It is assumed that the points in referenceSet and
+ * querySet correspond to the points in referenceTree and queryTree,
+ * respectively. Optionally, choose to use single-tree mode. Naive mode is
+ * not available as an option for this constructor; instead, to run naive
+ * computation, construct a tree with all of the points in one leaf (i.e.
+ * leafSize = number of points). Additionally, an instantiated distance
+ * metric can be given, for cases where the distance metric holds data.
+ *
+ * There is no copying of the data matrices in this constructor (because
+ * tree-building is not necessary), so this is the constructor to use when
+ * copies absolutely must be avoided.
+ *
+ * @note
+ * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+ * of a matrix, be sure you pass the modified matrix to this object! In
+ * addition, mapping the points of the matrix back to their original indices
+ * is not done when this constructor is used.
+ * @endnote
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ * @param queryTree Pre-built tree for query points.
+ * @param referenceSet Set of reference points corresponding to referenceTree.
+ * @param querySet Set of query points corresponding to queryTree.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param metric Instantiated distance metric.
+ */
+ RASearch(TreeType* referenceTree,
+ TreeType* queryTree,
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+ /**
+ * Initialize the RASearch object with the given reference dataset and
+ * pre-constructed tree. It is assumed that the points in referenceSet
+ * correspond to the points in referenceTree. Optionally, choose to use
+ * single-tree mode. Naive mode is not available as an option for this
+ * constructor; instead, to run naive computation, construct a tree with all
+ * the points in one leaf (i.e. leafSize = number of points). Additionally,
+ * an instantiated distance metric can be given, for the case where the
+ * distance metric holds data.
+ *
+ * There is no copying of the data matrices in this constructor (because
+ * tree-building is not necessary), so this is the constructor to use when
+ * copies absolutely must be avoided.
+ *
+ * @note
+ * Because tree-building (at least with BinarySpaceTree) modifies the ordering
+ * of a matrix, be sure you pass the modified matrix to this object! In
+ * addition, mapping the points of the matrix back to their original indices
+ * is not done when this constructor is used.
+ * @endnote
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ * @param referenceSet Set of reference points corresponding to referenceTree.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param metric Instantiated distance metric.
+ */
+ RASearch(TreeType* referenceTree,
+ const typename TreeType::Mat& referenceSet,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+ /**
+ * Delete the RASearch object. The tree is the only member we are
+ * responsible for deleting. The others will take care of themselves.
+ */
+ ~RASearch();
+
+ /**
+ * Compute the rank approximate nearest neighbors and store the output in the
+ * given matrices. The matrices will be set to the size of n columns by k
+ * rows, where n is the number of points in the query dataset and k is the
+ * number of neighbors being searched for.
+ *
+ * @param k Number of neighbors to search for.
+ * @param resultingNeighbors Matrix storing lists of neighbors for each query
+ * point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
+ * @param tau The rank-approximation in percentile of the data. The default
+ * value is 0.1%.
+ * @param alpha The desired success probability. The default value is 0.95.
+ * @param sampleAtLeaves Sample at leaves for faster but less accurate
+ * computation. This defaults to 'false'.
+ * @param firstLeafExact Traverse to the first leaf without approximation.
+ * This can ensure that the query definitely finds its (near) duplicate
+ * if there exists one. This defaults to 'false' for now.
+ * @param singleSampleLimit The limit on the largest node that can be
+ * approximated by sampling. This defaults to 20.
+ */
+ void Search(const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances,
+ const double tau = 0.1,
+ const double alpha = 0.95,
+ const bool sampleAtLeaves = false,
+ const bool firstLeafExact = false,
+ const size_t singleSampleLimit = 20);
+
+ /**
+ * This function recursively resets the RAQueryStat of the queryTree to set
+ * 'bound' to WorstDistance and the 'numSamplesMade' to 0. This allows a user
+ * to perform multiple searches on the same pair of trees, possibly with
+ * different levels of approximation without requiring to build a new pair of
+ * trees for every new (approximate) search.
+ */
+ void ResetQueryTree();
+
+ private:
+ //! Copy of reference dataset (if we need it, because tree building modifies
+ //! it).
+ arma::mat referenceCopy;
+ //! Copy of query dataset (if we need it, because tree building modifies it).
+ arma::mat queryCopy;
+
+ //! Reference dataset.
+ const arma::mat& referenceSet;
+ //! Query dataset (may not be given).
+ const arma::mat& querySet;
+
+ //! Pointer to the root of the reference tree.
+ TreeType* referenceTree;
+ //! Pointer to the root of the query tree (might not exist).
+ TreeType* queryTree;
+
+ //! Indicates if we should free the reference tree at deletion time.
+ bool ownReferenceTree;
+ //! Indicates if we should free the query tree at deletion time.
+ bool ownQueryTree;
+
+ //! Indicates if naive random sampling on the set is being used.
+ bool naive;
+ //! Indicates if single-tree search is being used (opposed to dual-tree).
+ bool singleMode;
+
+ //! Instantiation of kernel.
+ MetricType metric;
+
+ //! Permutations of reference points during tree building.
+ std::vector<size_t> oldFromNewReferences;
+ //! Permutations of query points during tree building.
+ std::vector<size_t> oldFromNewQueries;
+
+ //! Total number of pruned nodes during the neighbor search.
+ size_t numberOfPrunes;
+
+ /**
+ * @param treeNode The node of the tree whose RAQueryStat is reset
+ * and whose children are to be explored recursively.
+ */
+ void ResetRAQueryStat(TreeType* treeNode);
+}; // class RASearch
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "ra_search_impl.hpp"
+
+// Include convenient typedefs.
+#include "ra_typedef.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_search_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,370 +0,0 @@
-/**
- * @file ra_search_impl.hpp
- * @author Parikshit Ram
- *
- * Implementation of RASearch class to perform rank-approximate
- * all-nearest-neighbors on two specified data sets.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_IMPL_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_IMPL_HPP
-
-#include <mlpack/core.hpp>
-
-#include "ra_search_rules.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool naive,
- const bool singleMode,
- const size_t leafSize,
- const MetricType metric) :
- referenceCopy(referenceSet),
- queryCopy(querySet),
- referenceSet(referenceCopy),
- querySet(queryCopy),
- referenceTree(NULL),
- queryTree(NULL),
- ownReferenceTree(true), // False if a tree was passed.
- ownQueryTree(true), // False if a tree was passed.
- naive(naive),
- singleMode(!naive && singleMode), // No single mode if naive.
- metric(metric),
- numberOfPrunes(0)
-{
- // We'll time tree building.
- Timer::Start("tree_building");
-
- // Construct as a naive object if we need to.
- referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
- (naive ? referenceCopy.n_cols : leafSize));
-
- queryTree = new TreeType(queryCopy, oldFromNewQueries,
- (naive ? querySet.n_cols : leafSize));
-
- // Stop the timer we started above.
- Timer::Stop("tree_building");
-}
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(const typename TreeType::Mat& referenceSet,
- const bool naive,
- const bool singleMode,
- const size_t leafSize,
- const MetricType metric) :
- referenceCopy(referenceSet),
- referenceSet(referenceCopy),
- querySet(referenceCopy),
- referenceTree(NULL),
- queryTree(NULL),
- ownReferenceTree(true),
- ownQueryTree(false), // Since it will be the same as referenceTree.
- naive(naive),
- singleMode(!naive && singleMode), // No single mode if naive.
- metric(metric),
- numberOfPrunes(0)
-{
- // We'll time tree building.
- Timer::Start("tree_building");
-
- // Construct as a naive object if we need to.
- referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
- (naive ? referenceSet.n_cols : leafSize));
-
- // Stop the timer we started above.
- Timer::Stop("tree_building");
-}
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(TreeType* referenceTree,
- TreeType* queryTree,
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(referenceSet),
- querySet(querySet),
- referenceTree(referenceTree),
- queryTree(queryTree),
- ownReferenceTree(false),
- ownQueryTree(false),
- naive(false),
- singleMode(singleMode),
- metric(metric),
- numberOfPrunes(0)
-// Nothing else to initialize.
-{ }
-
-// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(TreeType* referenceTree,
- const typename TreeType::Mat& referenceSet,
- const bool singleMode,
- const MetricType metric) :
- referenceSet(referenceSet),
- querySet(referenceSet),
- referenceTree(referenceTree),
- queryTree(NULL),
- ownReferenceTree(false),
- ownQueryTree(false),
- naive(false),
- singleMode(singleMode),
- metric(metric),
- numberOfPrunes(0)
-// Nothing else to initialize.
-{ }
-
-/**
- * The tree is the only member we may be responsible for deleting. The others
- * will take care of themselves.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
-~RASearch()
-{
- if (ownReferenceTree)
- delete referenceTree;
- if (ownQueryTree)
- delete queryTree;
-}
-
-/**
- * Computes the best neighbors and stores them in resultingNeighbors and
- * distances.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::
-Search(const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
- arma::mat& distances,
- const double tau,
- const double alpha,
- const bool sampleAtLeaves,
- const bool firstLeafExact,
- const size_t singleSampleLimit)
-{
- Timer::Start("computing_neighbors");
-
- // If we have built the trees ourselves, then we will have to map all the
- // indices back to their original indices when this computation is finished.
- // To avoid an extra copy, we will store the neighbors and distances in a
- // separate matrix.
- arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
- arma::mat* distancePtr = &distances;
-
- if (!naive) // If naive, no re-mapping required since points are not mapped.
- {
- if (ownQueryTree || (ownReferenceTree && !queryTree))
- distancePtr = new arma::mat; // Query indices need to be mapped.
- if (ownReferenceTree || ownQueryTree)
- neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
- }
-
- // Set the size of the neighbor and distance matrices.
- neighborPtr->set_size(k, querySet.n_cols);
- distancePtr->set_size(k, querySet.n_cols);
- distancePtr->fill(SortPolicy::WorstDistance());
-
- size_t numPrunes = 0;
-
- if (singleMode || naive)
- {
- // Create the helper object for the tree traversal.
- typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
- RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr,
- metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
- singleSampleLimit);
-
- if (!referenceTree->IsLeaf())
- {
- Log::Info << "Performing single-tree traversal..." << std::endl;
-
- // Create the traverser.
- typename TreeType::template SingleTreeTraverser<RuleType>
- traverser(rules);
-
- // Now have it traverse for each point.
- for (size_t i = 0; i < querySet.n_cols; ++i)
- traverser.Traverse(i, *referenceTree);
-
- numPrunes = traverser.NumPrunes();
- }
- else
- {
- assert(naive);
- Log::Info << "Naive sampling already done!" << std::endl;
- }
-
- Log::Info << "Single-tree traversal done; number of distance calculations: "
- << (rules.NumDistComputations() / querySet.n_cols) << std::endl;
- }
- else // Dual-tree recursion.
- {
- Log::Info << "Performing dual-tree traversal..." << std::endl;
-
- typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
- RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr,
- metric, tau, alpha, sampleAtLeaves, firstLeafExact,
- singleSampleLimit);
-
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
- Log::Info << "Dual-tree traversal; query statistic pre-search: "
- << queryTree->Stat().NumSamplesMade() << std::endl;
-
- if (queryTree)
- traverser.Traverse(*queryTree, *referenceTree);
- else
- traverser.Traverse(*referenceTree, *referenceTree);
-
- numPrunes = traverser.NumPrunes();
-
- Log::Info << "Dual-tree traversal done; number of distance calculations: "
- << (rules.NumDistComputations() / querySet.n_cols) << std::endl;
- }
-
- Timer::Stop("computing_neighbors");
- Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
-
- // Now, do we need to do mapping of indices?
- if ((!ownReferenceTree && !ownQueryTree) || naive)
- {
- // No mapping needed if we do not own the trees or if we are doing naive
- // sampling. We are done.
- return;
- }
- else if (ownReferenceTree && ownQueryTree) // Map references and queries.
- {
- // Set size of output matrices correctly.
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
-
- for (size_t i = 0; i < distances.n_cols; i++)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; j++)
- {
- resultingNeighbors(j, oldFromNewQueries[i]) =
- oldFromNewReferences[(*neighborPtr)(j, i)];
- }
- }
-
- // Finished with temporary matrices.
- delete neighborPtr;
- delete distancePtr;
- }
- else if (ownReferenceTree)
- {
- if (!queryTree) // No query tree -- map both references and queries.
- {
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
-
- for (size_t i = 0; i < distances.n_cols; i++)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
-
- // Map indices of neighbors.
- for (size_t j = 0; j < distances.n_rows; j++)
- {
- resultingNeighbors(j, oldFromNewReferences[i]) =
- oldFromNewReferences[(*neighborPtr)(j, i)];
- }
- }
- }
- else // Map only references.
- {
- // Set size of neighbor indices matrix correctly.
- resultingNeighbors.set_size(k, querySet.n_cols);
-
- // Map indices of neighbors.
- for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
- {
- for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
- {
- resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
- }
- }
- }
-
- // Finished with temporary matrix.
- delete neighborPtr;
- }
- else if (ownQueryTree)
- {
- // Set size of matrices correctly.
- resultingNeighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
-
- for (size_t i = 0; i < distances.n_cols; i++)
- {
- // Map distances (copy a column).
- distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
-
- // Map indices of neighbors.
- resultingNeighbors.col(oldFromNewQueries[i]) = neighborPtr->col(i);
- }
-
- // Finished with temporary matrices.
- delete neighborPtr;
- delete distancePtr;
- }
-} // Search
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::
-ResetQueryTree()
-{
- if (!singleMode)
- {
- if (queryTree)
- ResetRAQueryStat(queryTree);
- else
- ResetRAQueryStat(referenceTree);
- }
-}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::
-ResetRAQueryStat(TreeType* treeNode)
-{
- treeNode->Stat().Bound() = SortPolicy::WorstDistance();
- treeNode->Stat().NumSamplesMade() = 0;
-
- for (size_t i = 0; i < treeNode->NumChildren(); i++)
- ResetRAQueryStat(&treeNode->Child(i));
-} // ResetRAQueryStat
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_search_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,370 @@
+/**
+ * @file ra_search_impl.hpp
+ * @author Parikshit Ram
+ *
+ * Implementation of RASearch class to perform rank-approximate
+ * all-nearest-neighbors on two specified data sets.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_IMPL_HPP
+
+#include <mlpack/core.hpp>
+
+#include "ra_search_rules.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+RASearch<SortPolicy, MetricType, TreeType>::
+RASearch(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ const MetricType metric) :
+ referenceCopy(referenceSet),
+ queryCopy(querySet),
+ referenceSet(referenceCopy),
+ querySet(queryCopy),
+ referenceTree(NULL),
+ queryTree(NULL),
+ ownReferenceTree(true), // False if a tree was passed.
+ ownQueryTree(true), // False if a tree was passed.
+ naive(naive),
+ singleMode(!naive && singleMode), // No single mode if naive.
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // We'll time tree building.
+ Timer::Start("tree_building");
+
+ // Construct as a naive object if we need to.
+ referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ (naive ? referenceCopy.n_cols : leafSize));
+
+ queryTree = new TreeType(queryCopy, oldFromNewQueries,
+ (naive ? querySet.n_cols : leafSize));
+
+ // Stop the timer we started above.
+ Timer::Stop("tree_building");
+}
+
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+RASearch<SortPolicy, MetricType, TreeType>::
+RASearch(const typename TreeType::Mat& referenceSet,
+ const bool naive,
+ const bool singleMode,
+ const size_t leafSize,
+ const MetricType metric) :
+ referenceCopy(referenceSet),
+ referenceSet(referenceCopy),
+ querySet(referenceCopy),
+ referenceTree(NULL),
+ queryTree(NULL),
+ ownReferenceTree(true),
+ ownQueryTree(false), // Since it will be the same as referenceTree.
+ naive(naive),
+ singleMode(!naive && singleMode), // No single mode if naive.
+ metric(metric),
+ numberOfPrunes(0)
+{
+ // We'll time tree building.
+ Timer::Start("tree_building");
+
+ // Construct as a naive object if we need to.
+ referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
+ (naive ? referenceSet.n_cols : leafSize));
+
+ // Stop the timer we started above.
+ Timer::Stop("tree_building");
+}
+
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+RASearch<SortPolicy, MetricType, TreeType>::
+RASearch(TreeType* referenceTree,
+ TreeType* queryTree,
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ const bool singleMode,
+ const MetricType metric) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ referenceTree(referenceTree),
+ queryTree(queryTree),
+ ownReferenceTree(false),
+ ownQueryTree(false),
+ naive(false),
+ singleMode(singleMode),
+ metric(metric),
+ numberOfPrunes(0)
+// Nothing else to initialize.
+{ }
+
+// Construct the object.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+RASearch<SortPolicy, MetricType, TreeType>::
+RASearch(TreeType* referenceTree,
+ const typename TreeType::Mat& referenceSet,
+ const bool singleMode,
+ const MetricType metric) :
+ referenceSet(referenceSet),
+ querySet(referenceSet),
+ referenceTree(referenceTree),
+ queryTree(NULL),
+ ownReferenceTree(false),
+ ownQueryTree(false),
+ naive(false),
+ singleMode(singleMode),
+ metric(metric),
+ numberOfPrunes(0)
+// Nothing else to initialize.
+{ }
+
+/**
+ * The tree is the only member we may be responsible for deleting. The others
+ * will take care of themselves.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+RASearch<SortPolicy, MetricType, TreeType>::
+~RASearch()
+{
+ if (ownReferenceTree)
+ delete referenceTree;
+ if (ownQueryTree)
+ delete queryTree;
+}
+
+/**
+ * Computes the best neighbors and stores them in resultingNeighbors and
+ * distances.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void RASearch<SortPolicy, MetricType, TreeType>::
+Search(const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances,
+ const double tau,
+ const double alpha,
+ const bool sampleAtLeaves,
+ const bool firstLeafExact,
+ const size_t singleSampleLimit)
+{
+ Timer::Start("computing_neighbors");
+
+ // If we have built the trees ourselves, then we will have to map all the
+ // indices back to their original indices when this computation is finished.
+ // To avoid an extra copy, we will store the neighbors and distances in a
+ // separate matrix.
+ arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
+ arma::mat* distancePtr = &distances;
+
+ if (!naive) // If naive, no re-mapping required since points are not mapped.
+ {
+ if (ownQueryTree || (ownReferenceTree && !queryTree))
+ distancePtr = new arma::mat; // Query indices need to be mapped.
+ if (ownReferenceTree || ownQueryTree)
+ neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
+ }
+
+ // Set the size of the neighbor and distance matrices.
+ neighborPtr->set_size(k, querySet.n_cols);
+ distancePtr->set_size(k, querySet.n_cols);
+ distancePtr->fill(SortPolicy::WorstDistance());
+
+ size_t numPrunes = 0;
+
+ if (singleMode || naive)
+ {
+ // Create the helper object for the tree traversal.
+ typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr,
+ metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
+ singleSampleLimit);
+
+ if (!referenceTree->IsLeaf())
+ {
+ Log::Info << "Performing single-tree traversal..." << std::endl;
+
+ // Create the traverser.
+ typename TreeType::template SingleTreeTraverser<RuleType>
+ traverser(rules);
+
+ // Now have it traverse for each point.
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
+
+ numPrunes = traverser.NumPrunes();
+ }
+ else
+ {
+ assert(naive);
+ Log::Info << "Naive sampling already done!" << std::endl;
+ }
+
+ Log::Info << "Single-tree traversal done; number of distance calculations: "
+ << (rules.NumDistComputations() / querySet.n_cols) << std::endl;
+ }
+ else // Dual-tree recursion.
+ {
+ Log::Info << "Performing dual-tree traversal..." << std::endl;
+
+ typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr,
+ metric, tau, alpha, sampleAtLeaves, firstLeafExact,
+ singleSampleLimit);
+
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ Log::Info << "Dual-tree traversal; query statistic pre-search: "
+ << queryTree->Stat().NumSamplesMade() << std::endl;
+
+ if (queryTree)
+ traverser.Traverse(*queryTree, *referenceTree);
+ else
+ traverser.Traverse(*referenceTree, *referenceTree);
+
+ numPrunes = traverser.NumPrunes();
+
+ Log::Info << "Dual-tree traversal done; number of distance calculations: "
+ << (rules.NumDistComputations() / querySet.n_cols) << std::endl;
+ }
+
+ Timer::Stop("computing_neighbors");
+ Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+
+ // Now, do we need to do mapping of indices?
+ if ((!ownReferenceTree && !ownQueryTree) || naive)
+ {
+ // No mapping needed if we do not own the trees or if we are doing naive
+ // sampling. We are done.
+ return;
+ }
+ else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+ {
+ // Set size of output matrices correctly.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; i++)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; j++)
+ {
+ resultingNeighbors(j, oldFromNewQueries[i]) =
+ oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
+ }
+
+ // Finished with temporary matrices.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+ else if (ownReferenceTree)
+ {
+ if (!queryTree) // No query tree -- map both references and queries.
+ {
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; i++)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ for (size_t j = 0; j < distances.n_rows; j++)
+ {
+ resultingNeighbors(j, oldFromNewReferences[i]) =
+ oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
+ }
+ }
+ else // Map only references.
+ {
+ // Set size of neighbor indices matrix correctly.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+
+ // Map indices of neighbors.
+ for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
+ {
+ for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
+ {
+ resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+ }
+ }
+ }
+
+ // Finished with temporary matrix.
+ delete neighborPtr;
+ }
+ else if (ownQueryTree)
+ {
+ // Set size of matrices correctly.
+ resultingNeighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < distances.n_cols; i++)
+ {
+ // Map distances (copy a column).
+ distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
+
+ // Map indices of neighbors.
+ resultingNeighbors.col(oldFromNewQueries[i]) = neighborPtr->col(i);
+ }
+
+ // Finished with temporary matrices.
+ delete neighborPtr;
+ delete distancePtr;
+ }
+} // Search
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void RASearch<SortPolicy, MetricType, TreeType>::
+ResetQueryTree()
+{
+ if (!singleMode)
+ {
+ if (queryTree)
+ ResetRAQueryStat(queryTree);
+ else
+ ResetRAQueryStat(referenceTree);
+ }
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void RASearch<SortPolicy, MetricType, TreeType>::
+ResetRAQueryStat(TreeType* treeNode)
+{
+ treeNode->Stat().Bound() = SortPolicy::WorstDistance();
+ treeNode->Stat().NumSamplesMade() = 0;
+
+ for (size_t i = 0; i < treeNode->NumChildren(); i++)
+ ResetRAQueryStat(&treeNode->Child(i));
+} // ResetRAQueryStat
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_search_rules.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,325 +0,0 @@
-/**
- * @file ra_search_rules.hpp
- * @author Parikshit Ram
- *
- * Defines the pruning rules and base case rules necessary to perform a
- * tree-based rank-approximate search (with an arbitrary tree)
- * for the RASearch class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_HPP
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-class RASearchRules
-{
- public:
- RASearchRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- MetricType& metric,
- const double tau = 0.1,
- const double alpha = 0.95,
- const bool naive = false,
- const bool sampleAtLeaves = false,
- const bool firstLeafExact = false,
- const size_t singleSampleLimit = 20);
-
-
-
- double BaseCase(const size_t queryIndex, const size_t referenceIndex);
-
- /**
- * TOFIX: This function is specified for the cover tree (usually) so
- * I need to think about it more algorithmically and keep its
- * implementation mostly empty.
- * Also, since the access to the points in the subtree of a cover tree
- * is non-trivial, we might have to re-work this.
- * FOR NOW: I am just using as for a BSP-tree, I will fix it when
- * we figure out cover trees.
- *
- */
-
- double Prescore(TreeType& queryNode,
- TreeType& referenceNode,
- TreeType& referenceChildNode,
- const double baseCaseResult) const;
- double PrescoreQ(TreeType& queryNode,
- TreeType& queryChildNode,
- TreeType& referenceNode,
- const double baseCaseResult) const;
-
-
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * For rank-approximation, the scoring function first checks if pruning
- * by distance is possible.
- * If yes, then the node is given the score of
- * 'DBL_MAX' and the expected number of samples from that node are
- * added to the number of samples made for the query.
- *
- * If no, then the function tries to see if the node can be pruned by
- * approximation. If number of samples required from this node is small
- * enough, then that number of samples are acquired from this node
- * and the score is set to be 'DBL_MAX'.
- *
- * If the pruning by approximation is not possible either, the algorithm
- * continues with the usual tree-traversal.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- */
- double Score(const size_t queryIndex, TreeType& referenceNode);
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * For rank-approximation, the scoring function first checks if pruning
- * by distance is possible.
- * If yes, then the node is given the score of
- * 'DBL_MAX' and the expected number of samples from that node are
- * added to the number of samples made for the query.
- *
- * If no, then the function tries to see if the node can be pruned by
- * approximation. If number of samples required from this node is small
- * enough, then that number of samples are acquired from this node
- * and the score is set to be 'DBL_MAX'.
- *
- * If the pruning by approximation is not possible either, the algorithm
- * continues with the usual tree-traversal.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult);
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be
- * recursed into at all (it should be pruned). This is used when the score
- * has already been calculated, but another recursion may have modified the
- * bounds for pruning. So the old score is checked against the new pruning
- * bound.
- *
- * For rank-approximation, it also checks if the number of samples left
- * for a query to satisfy the rank constraint is small enough at this
- * point of the algorithm, then this node is approximated by sampling
- * and given a new score of 'DBL_MAX'.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param oldScore Old score produced by Score() (or Rescore()).
- */
- double Rescore(const size_t queryIndex,
- TreeType& referenceNode,
- const double oldScore);
-
- /**
- * Get the score for recursion order. A low score indicates priority for
- * recursionm while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned).
- *
- * For the rank-approximation, we check if the referenceNode can be
- * approximated by sampling. If it can be, enough samples are made for
- * every query in the queryNode. No further query-tree traversal is
- * performed.
- *
- * The 'NumSamplesMade' query stat is propagated up the tree. And then
- * if pruning occurs (by distance or by sampling), the 'NumSamplesMade'
- * stat is not propagated down the tree. If no pruning occurs, the
- * stat is propagated down the tree.
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- */
- double Score(TreeType& queryNode, TreeType& referenceNode);
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * For the rank-approximation, we check if the referenceNode can be
- * approximated by sampling. If it can be, enough samples are made for
- * every query in the queryNode. No further query-tree traversal is
- * performed.
- *
- * The 'NumSamplesMade' query stat is propagated up the tree. And then
- * if pruning occurs (by distance or by sampling), the 'NumSamplesMade'
- * stat is not propagated down the tree. If no pruning occurs, the
- * stat is propagated down the tree.
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult);
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be
- * recursed into at all (it should be pruned). This is used when the score
- * has already been calculated, but another recursion may have modified the
- * bounds for pruning. So the old score is checked against the new pruning
- * bound.
- *
- * For the rank-approximation, we check if the referenceNode can be
- * approximated by sampling. If it can be, enough samples are made for
- * every query in the queryNode. No further query-tree traversal is
- * performed.
- *
- * The 'NumSamplesMade' query stat is propagated up the tree. And then
- * if pruning occurs (by distance or by sampling), the 'NumSamplesMade'
- * stat is not propagated down the tree. If no pruning occurs, the
- * stat is propagated down the tree.
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- * @param oldScore Old score produced by Socre() (or Rescore()).
- */
- double Rescore(TreeType& queryNode,
- TreeType& referenceNode,
- const double oldScore);
-
-
- size_t NumDistComputations() { return numDistComputations; }
- size_t NumEffectiveSamples()
- {
- if (numSamplesMade.n_elem == 0)
- return 0;
- else
- return arma::sum(numSamplesMade);
- }
-
- private:
- //! The reference set.
- const arma::mat& referenceSet;
-
- //! The query set.
- const arma::mat& querySet;
-
- //! The matrix the resultant neighbor indices should be stored in.
- arma::Mat<size_t>& neighbors;
-
- //! The matrix the resultant neighbor distances should be stored in.
- arma::mat& distances;
-
- //! The instantiated metric.
- MetricType& metric;
-
- //! Whether to sample at leaves or just use all of it
- bool sampleAtLeaves;
-
- //! Whether to do exact computation on the first leaf before any sampling
- bool firstLeafExact;
-
- //! The limit on the largest node that can be approximated by sampling
- size_t singleSampleLimit;
-
- //! The minimum number of samples required per query
- size_t numSamplesReqd;
-
- //! The number of samples made for every query
- arma::Col<size_t> numSamplesMade;
-
- //! The sampling ratio
- double samplingRatio;
-
- // TO REMOVE: just for testing
- size_t numDistComputations;
-
-
- /**
- * Insert a point into the neighbors and distances matrices; this is a helper
- * function.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
- void InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance);
-
- /**
- * Compute the minimum number of samples required to guarantee
- * the given rank-approximation and success probability.
- *
- * @param n Size of the set to be sampled from.
- * @param k The number of neighbors required within the rank-approximation.
- * @param tau The rank-approximation in percentile of the data.
- * @param alpha The success probability desired.
- */
- size_t MinimumSamplesReqd(const size_t n,
- const size_t k,
- const double tau,
- const double alpha) const;
-
- /**
- * Compute the success probability of obtaining 'k'-neighbors from a
- * set of size 'n' within the top 't' neighbors if 'm' samples are made.
- *
- * @param n Size of the set being sampled from.
- * @param k The number of neighbors required within the rank-approximation.
- * @param m The number of random samples.
- * @param t The desired rank-approximation.
- */
- double SuccessProbability(const size_t n,
- const size_t k,
- const size_t m,
- const size_t t) const;
-
- /**
- * Pick up desired number of samples (with replacement) from a given range
- * of integers so that only the distinct samples are returned from
- * the range [0 - specified upper bound)
- *
- * @param numSamples Number of random samples.
- * @param rangeUpperBound The upper bound on the range of integers.
- * @param distinctSamples The list of the distinct samples.
- */
- void ObtainDistinctSamples(const size_t numSamples,
- const size_t rangeUpperBound,
- arma::uvec& distinctSamples) const;
-
-}; // class RASearchRules
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-// Include implementation.
-#include "ra_search_rules_impl.hpp"
-
-#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_search_rules.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,325 @@
+/**
+ * @file ra_search_rules.hpp
+ * @author Parikshit Ram
+ *
+ * Defines the pruning rules and base case rules necessary to perform a
+ * tree-based rank-approximate search (with an arbitrary tree)
+ * for the RASearch class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_HPP
+
+namespace mlpack {
+namespace neighbor {
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+class RASearchRules
+{
+ public:
+ RASearchRules(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ MetricType& metric,
+ const double tau = 0.1,
+ const double alpha = 0.95,
+ const bool naive = false,
+ const bool sampleAtLeaves = false,
+ const bool firstLeafExact = false,
+ const size_t singleSampleLimit = 20);
+
+
+
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+ /**
+ * TOFIX: This function is specified for the cover tree (usually) so
+ * I need to think about it more algorithmically and keep its
+ * implementation mostly empty.
+ * Also, since the access to the points in the subtree of a cover tree
+ * is non-trivial, we might have to re-work this.
+ * FOR NOW: I am just using as for a BSP-tree, I will fix it when
+ * we figure out cover trees.
+ *
+ */
+
+ double Prescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ TreeType& referenceChildNode,
+ const double baseCaseResult) const;
+ double PrescoreQ(TreeType& queryNode,
+ TreeType& queryChildNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * For rank-approximation, the scoring function first checks if pruning
+ * by distance is possible.
+ * If yes, then the node is given the score of
+ * 'DBL_MAX' and the expected number of samples from that node are
+ * added to the number of samples made for the query.
+ *
+ * If no, then the function tries to see if the node can be pruned by
+ * approximation. If number of samples required from this node is small
+ * enough, then that number of samples are acquired from this node
+ * and the score is set to be 'DBL_MAX'.
+ *
+ * If the pruning by approximation is not possible either, the algorithm
+ * continues with the usual tree-traversal.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ */
+ double Score(const size_t queryIndex, TreeType& referenceNode);
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * For rank-approximation, the scoring function first checks if pruning
+ * by distance is possible.
+ * If yes, then the node is given the score of
+ * 'DBL_MAX' and the expected number of samples from that node are
+ * added to the number of samples made for the query.
+ *
+ * If no, then the function tries to see if the node can be pruned by
+ * approximation. If number of samples required from this node is small
+ * enough, then that number of samples are acquired from this node
+ * and the score is set to be 'DBL_MAX'.
+ *
+ * If the pruning by approximation is not possible either, the algorithm
+ * continues with the usual tree-traversal.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult);
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be
+ * recursed into at all (it should be pruned). This is used when the score
+ * has already been calculated, but another recursion may have modified the
+ * bounds for pruning. So the old score is checked against the new pruning
+ * bound.
+ *
+ * For rank-approximation, it also checks if the number of samples left
+ * for a query to satisfy the rank constraint is small enough at this
+ * point of the algorithm, then this node is approximated by sampling
+ * and given a new score of 'DBL_MAX'.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore);
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
 + * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * For the rank-approximation, we check if the referenceNode can be
+ * approximated by sampling. If it can be, enough samples are made for
+ * every query in the queryNode. No further query-tree traversal is
+ * performed.
+ *
+ * The 'NumSamplesMade' query stat is propagated up the tree. And then
+ * if pruning occurs (by distance or by sampling), the 'NumSamplesMade'
+ * stat is not propagated down the tree. If no pruning occurs, the
+ * stat is propagated down the tree.
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ */
+ double Score(TreeType& queryNode, TreeType& referenceNode);
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * For the rank-approximation, we check if the referenceNode can be
+ * approximated by sampling. If it can be, enough samples are made for
+ * every query in the queryNode. No further query-tree traversal is
+ * performed.
+ *
+ * The 'NumSamplesMade' query stat is propagated up the tree. And then
+ * if pruning occurs (by distance or by sampling), the 'NumSamplesMade'
+ * stat is not propagated down the tree. If no pruning occurs, the
+ * stat is propagated down the tree.
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult);
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that the node should not be
+ * recursed into at all (it should be pruned). This is used when the score
+ * has already been calculated, but another recursion may have modified the
+ * bounds for pruning. So the old score is checked against the new pruning
+ * bound.
+ *
+ * For the rank-approximation, we check if the referenceNode can be
+ * approximated by sampling. If it can be, enough samples are made for
+ * every query in the queryNode. No further query-tree traversal is
+ * performed.
+ *
+ * The 'NumSamplesMade' query stat is propagated up the tree. And then
+ * if pruning occurs (by distance or by sampling), the 'NumSamplesMade'
+ * stat is not propagated down the tree. If no pruning occurs, the
+ * stat is propagated down the tree.
+ *
+ * @param queryNode Candidate query node to recurse into.
+ * @param referenceNode Candidate reference node to recurse into.
 + * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore);
+
+
+ size_t NumDistComputations() { return numDistComputations; }
+ size_t NumEffectiveSamples()
+ {
+ if (numSamplesMade.n_elem == 0)
+ return 0;
+ else
+ return arma::sum(numSamplesMade);
+ }
+
+ private:
+ //! The reference set.
+ const arma::mat& referenceSet;
+
+ //! The query set.
+ const arma::mat& querySet;
+
+ //! The matrix the resultant neighbor indices should be stored in.
+ arma::Mat<size_t>& neighbors;
+
+ //! The matrix the resultant neighbor distances should be stored in.
+ arma::mat& distances;
+
+ //! The instantiated metric.
+ MetricType& metric;
+
+ //! Whether to sample at leaves or just use all of it
+ bool sampleAtLeaves;
+
+ //! Whether to do exact computation on the first leaf before any sampling
+ bool firstLeafExact;
+
+ //! The limit on the largest node that can be approximated by sampling
+ size_t singleSampleLimit;
+
+ //! The minimum number of samples required per query
+ size_t numSamplesReqd;
+
+ //! The number of samples made for every query
+ arma::Col<size_t> numSamplesMade;
+
+ //! The sampling ratio
+ double samplingRatio;
+
+ // TO REMOVE: just for testing
+ size_t numDistComputations;
+
+
+ /**
+ * Insert a point into the neighbors and distances matrices; this is a helper
+ * function.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+ void InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance);
+
+ /**
+ * Compute the minimum number of samples required to guarantee
+ * the given rank-approximation and success probability.
+ *
+ * @param n Size of the set to be sampled from.
+ * @param k The number of neighbors required within the rank-approximation.
+ * @param tau The rank-approximation in percentile of the data.
+ * @param alpha The success probability desired.
+ */
+ size_t MinimumSamplesReqd(const size_t n,
+ const size_t k,
+ const double tau,
+ const double alpha) const;
+
+ /**
+ * Compute the success probability of obtaining 'k'-neighbors from a
+ * set of size 'n' within the top 't' neighbors if 'm' samples are made.
+ *
+ * @param n Size of the set being sampled from.
+ * @param k The number of neighbors required within the rank-approximation.
+ * @param m The number of random samples.
+ * @param t The desired rank-approximation.
+ */
+ double SuccessProbability(const size_t n,
+ const size_t k,
+ const size_t m,
+ const size_t t) const;
+
+ /**
+ * Pick up desired number of samples (with replacement) from a given range
+ * of integers so that only the distinct samples are returned from
+ * the range [0 - specified upper bound)
+ *
+ * @param numSamples Number of random samples.
+ * @param rangeUpperBound The upper bound on the range of integers.
+ * @param distinctSamples The list of the distinct samples.
+ */
+ void ObtainDistinctSamples(const size_t numSamples,
+ const size_t rangeUpperBound,
+ arma::uvec& distinctSamples) const;
+
+}; // class RASearchRules
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+// Include implementation.
+#include "ra_search_rules_impl.hpp"
+
+#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_search_rules_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,1314 +0,0 @@
-/**
- * @file ra_search_rules_impl.hpp
- * @author Parikshit Ram
- *
- * Implementation of RASearchRules.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_IMPL_HPP
-#define __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "ra_search_rules.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearchRules<SortPolicy, MetricType, TreeType>::
-RASearchRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- MetricType& metric,
- const double tau,
- const double alpha,
- const bool naive,
- const bool sampleAtLeaves,
- const bool firstLeafExact,
- const size_t singleSampleLimit) :
- referenceSet(referenceSet),
- querySet(querySet),
- neighbors(neighbors),
- distances(distances),
- metric(metric),
- sampleAtLeaves(sampleAtLeaves),
- firstLeafExact(firstLeafExact),
- singleSampleLimit(singleSampleLimit)
-{
- Timer::Start("computing_number_of_samples_reqd");
- numSamplesReqd = MinimumSamplesReqd(referenceSet.n_cols, neighbors.n_rows,
- tau, alpha);
- Timer::Stop("computing_number_of_samples_reqd");
-
- // initializing some stats to be collected during the search
- numSamplesMade = arma::zeros<arma::Col<size_t> >(querySet.n_cols);
- numDistComputations = 0;
- samplingRatio = (double) numSamplesReqd / (double) referenceSet.n_cols;
-
- Log::Info << "Minimum Samples Required per-query: " << numSamplesReqd <<
- ", Sampling Ratio: " << samplingRatio << std::endl;
-
- if (naive) // no tree traversal going to happen, just do naive sampling here.
- {
- // sample enough number of points
- for (size_t i = 0; i < querySet.n_cols; ++i)
- {
- arma::uvec distinctSamples;
- ObtainDistinctSamples(numSamplesReqd, referenceSet.n_cols,
- distinctSamples);
- for (size_t j = 0; j < distinctSamples.n_elem; j++)
- BaseCase(i, (size_t) distinctSamples[j]);
- }
- }
-}
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline force_inline
-void RASearchRules<SortPolicy, MetricType, TreeType>::
-ObtainDistinctSamples(const size_t numSamples,
- const size_t rangeUpperBound,
- arma::uvec& distinctSamples) const
-{
- // keep track of the points that are sampled
- arma::Col<size_t> sampledPoints;
- sampledPoints.zeros(rangeUpperBound);
-
- for (size_t i = 0; i < numSamples; i++)
- sampledPoints[(size_t) math::RandInt(rangeUpperBound)]++;
-
- distinctSamples = arma::find(sampledPoints > 0);
- return;
-}
-
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-size_t RASearchRules<SortPolicy, MetricType, TreeType>::
-MinimumSamplesReqd(const size_t n,
- const size_t k,
- const double tau,
- const double alpha) const
-{
- size_t ub = n; // the upper bound on the binary search
- size_t lb = k; // the lower bound on the binary search
- size_t m = lb; // the minimum number of random samples
-
- // The rank approximation
- size_t t = (size_t) std::ceil(tau * (double) n / 100.0);
-
- double prob;
- assert(alpha <= 1.0);
-
- // going through all values of sample sizes
- // to find the minimum samples required to satisfy the
- // desired bound
- bool done = false;
-
- // This performs a binary search on the integer values between 'lb = k'
- // and 'ub = n' to find the minimum number of samples 'm' required to obtain
- // the desired success probability 'alpha'.
- do
- {
- prob = SuccessProbability(n, k, m, t);
-
- if (prob > alpha)
- {
- if (prob - alpha < 0.001 || ub < lb + 2) {
- done = true;
- break;
- }
- else
- ub = m;
- }
- else
- {
- if (prob < alpha)
- {
- if (m == lb)
- {
- m++;
- continue;
- }
- else
- lb = m;
- }
- else
- {
- done = true;
- break;
- }
- }
- m = (ub + lb) / 2;
-
- } while (!done);
-
- return (m + 1);
-}
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-double RASearchRules<SortPolicy, MetricType, TreeType>::
-SuccessProbability(const size_t n,
- const size_t k,
- const size_t m,
- const size_t t) const
-{
- if (k == 1)
- {
- if (m > n - t)
- return 1.0;
-
- double eps = (double) t / (double) n;
-
- return 1.0 - std::pow(1.0 - eps, (double) m);
-
- } // faster implementation for topK = 1
- else
- {
- if (m < k)
- return 0.0;
-
- if (m > n - t + k)
- return 1.0;
-
- double eps = (double) t / (double) n;
- double sum = 0.0;
-
- // The probability that 'k' of the 'm' samples lie within the top 't'
- // of the neighbors is given by:
- // sum_{j = k}^m Choose(m, j) (t/n)^j (1 - t/n)^{m - j}
- // which is also equal to
- // 1 - sum_{j = 0}^{k - 1} Choose(m, j) (t/n)^j (1 - t/n)^{m - j}
- //
- // So this is a m - k term summation or a k term summation. So if
- // m > 2k, do the k term summation, otherwise do the m term summation.
-
- size_t lb;
- size_t ub;
- bool topHalf;
-
- if (2 * k < m)
- {
- // compute 1 - sum_{j = 0}^{k - 1} Choose(m, j) eps^j (1 - eps)^{m - j}
- // eps = t/n
- //
- // Choosing 'lb' as 1 and 'ub' as k so as to sum from 1 to (k - 1),
- // and add the term (1 - eps)^m term separately.
- lb = 1;
- ub = k;
- topHalf = true;
- sum = std::pow(1 - eps, (double) m);
- }
- else
- {
- // compute sum_{j = k}^m Choose(m, j) eps^j (1 - eps)^{m - j}
- // eps = t/n
- //
- // Choosing 'lb' as k and 'ub' as m so as to sum from k to (m - 1),
- // and add the term eps^m term separately.
- lb = k;
- ub = m;
- topHalf = false;
- sum = std::pow(eps, (double) m);
- }
-
- for (size_t j = lb; j < ub; j++)
- {
- // compute Choose(m, j)
- double mCj = (double) m;
- size_t jTrans;
-
- // if j < m - j, compute Choose(m, j)
- // if j > m - j, compute Choose(m, m - j)
- if (topHalf)
- jTrans = j;
- else
- jTrans = m - j;
-
- for(size_t i = 2; i <= jTrans; i++)
- {
- mCj *= (double) (m - (i - 1));
- mCj /= (double) i;
- }
-
- sum += (mCj * std::pow(eps, (double) j)
- * std::pow(1.0 - eps, (double) (m - j)));
- }
-
- if (topHalf)
- sum = 1.0 - sum;
-
- return sum;
- } // for k > 1
-} // FastComputeProb
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline force_inline
-double RASearchRules<SortPolicy, MetricType, TreeType>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
-{
- // If the datasets are the same, then this search is only using one dataset
- // and we should not return identical points.
- if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
- return 0.0;
-
- double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceIndex));
-
- // If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the position to insert it into.
- arma::vec queryDist = distances.unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
-
- // SortDistance() returns (size_t() - 1) if we shouldn't add it.
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
-
- numSamplesMade[queryIndex]++;
-
- // TO REMOVE
- numDistComputations++;
-
- return distance;
-}
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Prescore(TreeType& queryNode,
- TreeType& referenceNode,
- TreeType& referenceChildNode,
- const double baseCaseResult) const
-{
- return 0.0;
-}
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-PrescoreQ(TreeType& queryNode,
- TreeType& queryChildNode,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
- return 0.0;
-}
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(const size_t queryIndex, TreeType& referenceNode)
-{
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
- &referenceNode);
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
-
- // If this is better than the best distance we've seen so far,
- // maybe there will be something down this node.
- // Also check if enough samples are already made for this query.
- if (SortPolicy::IsBetter(distance, bestDistance)
- && numSamplesMade[queryIndex] < numSamplesReqd)
- {
- // We cannot prune this node
- // Try approximating this node by sampling
-
- // If you are required to visit the first leaf (to find possible
- // duplicates), make sure you do not approximate.
- if (numSamplesMade[queryIndex] > 0 || !firstLeafExact)
- {
- // check if this node can be approximated by sampling
- size_t samplesReqd =
- (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
- samplesReqd = std::min(samplesReqd,
- numSamplesReqd - numSamplesMade[queryIndex]);
-
- if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
- {
- // if too many samples required and not at a leaf, then can't prune
- return distance;
- }
- else
- {
- if (!referenceNode.IsLeaf()) // if not a leaf
- {
- // Then samplesReqd <= singleSampleLimit.
- // Hence approximate node by sampling enough number of points
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase' function
- // so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
-
- // Node approximated so we can prune it
- return DBL_MAX;
- }
- else // we are at a leaf.
- {
- if (sampleAtLeaves) // if allowed to sample at leaves.
- {
- // Approximate node by sampling enough number of points
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
-
- // (Leaf) Node approximated so can prune it
- return DBL_MAX;
- }
- else
- {
- // not allowed to sample from leaves, so cannot prune.
- return distance;
- } // sample at leaves
- } // if not-leaf
- } // if cannot-approximate by sampling
- }
- else
- {
- // try visiting your first leaf first to boost your accuracy
- // and find your (near) duplicates if they exist
- return distance;
- } // if first-leaf exact visit required
- }
- else
- {
- // Either there cannot be anything better in this node.
- // Or enough number of samples are already made.
- // So prune it.
-
- // add 'fake' samples from this node; fake because the distances to
- // these samples need not be computed.
-
- // If enough samples already made, this step does not change the
- // result of the search.
- numSamplesMade[queryIndex] +=
- (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
-
- return DBL_MAX;
- } // if can-prune
-} // Score(point, node)
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult)
-{
-
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
- &referenceNode, baseCaseResult);
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
- // Hereon, this 'Score' function is exactly same as the previous
- // 'Score' function.
-
- // If this is better than the best distance we've seen so far,
- // maybe there will be something down this node.
- // Also check if enough samples are already made for this query.
- if (SortPolicy::IsBetter(distance, bestDistance)
- && numSamplesMade[queryIndex] < numSamplesReqd)
- {
- // We cannot prune this node
- // Try approximating this node by sampling
-
- // If you are required to visit the first leaf (to find possible
- // duplicates), make sure you do not approximate.
- if (numSamplesMade[queryIndex] > 0 || !firstLeafExact)
- {
- // Check if this node can be approximated by sampling:
- // 'referenceNode.Count() should correspond to the number of points
- // present in this subtree.
- size_t samplesReqd =
- (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
- samplesReqd = std::min(samplesReqd,
- numSamplesReqd - numSamplesMade[queryIndex]);
-
- if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
- {
- // if too many samples required and not at a leaf, then can't prune
- return distance;
- }
- else
- {
- if (!referenceNode.IsLeaf()) // if not a leaf
- {
- // Then samplesReqd <= singleSampleLimit.
- // Hence approximate node by sampling enough number of points
- // from this subtree.
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
-
- // Node approximated so we can prune it
- return DBL_MAX;
- }
- else // we are at a leaf.
- {
- if (sampleAtLeaves) // if allowed to sample at leaves.
- {
- // Approximate node by sampling enough number of points
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
-
- // (Leaf) Node approximated so can prune it
- return DBL_MAX;
- }
- else
- {
- // not allowed to sample from leaves, so cannot prune.
- return distance;
- } // sample at leaves
- } // if not-leaf
- } // if cannot-approximate by sampling
- }
- else
- {
- // try visiting your first leaf first to boost your accuracy
- return distance;
- } // if first-leaf exact visit required
- }
- else
- {
- // Either there cannot be anything better in this node.
- // Or enough number of samples are already made.
- // So prune it.
-
- // add 'fake' samples from this node; fake because the distances to
- // these samples need not be computed.
-
- // If enough samples already made, this step does not change the
- // result of the search.
- if (numSamplesMade[queryIndex] < numSamplesReqd)
- // add 'fake' samples from this node; fake because the distances to
- // these samples need not be computed.
- numSamplesMade[queryIndex] +=
- (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
-
- return DBL_MAX;
- } // if can-prune
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-
-} // Score(point, node, point-node-point-distance)
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Rescore(const size_t queryIndex,
- TreeType& referenceNode,
- const double oldScore)
-{
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
-
- // Just check the score again against the distances.
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
-
- // If this is better than the best distance we've seen so far,
- // maybe there will be something down this node.
- // Also check if enough samples are already made for this query.
- if (SortPolicy::IsBetter(oldScore, bestDistance)
- && numSamplesMade[queryIndex] < numSamplesReqd)
- {
- // We cannot prune this node
- // Try approximating this node by sampling
-
- // Here we assume that since we are re-scoring, the algorithm
- // has already sampled some candidates, and if specified, also
- // traversed to the first leaf.
- // So no checks regarding that are made any more.
- //
- // check if this node can be approximated by sampling
- size_t samplesReqd =
- (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
- samplesReqd = std::min(samplesReqd,
- numSamplesReqd - numSamplesMade[queryIndex]);
-
- if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
- {
- // if too many samples required and not at a leaf, then can't prune
- return oldScore;
- }
- else
- {
- if (!referenceNode.IsLeaf()) // if not a leaf
- {
- // Then samplesReqd <= singleSampleLimit.
- // Hence approximate node by sampling enough number of points
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
-
- // Node approximated so we can prune it
- return DBL_MAX;
- }
- else // we are at a leaf.
- {
- if (sampleAtLeaves) // if allowed to sample at leaves.
- {
- // Approximate node by sampling enough number of points
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
-
- // (Leaf) Node approximated so can prune it
- return DBL_MAX;
- }
- else
- {
- // not allowed to sample from leaves, so cannot prune.
- return oldScore;
- } // sample at leaves
- } // if not-leaf
- } // if cannot-approximate by sampling
- }
- else
- {
- // Either there cannot be anything better in this node.
- // Or enough number of samples are already made.
- // So prune it.
-
- // add 'fake' samples from this node; fake because the distances to
- // these samples need not be computed.
-
- // If enough samples already made, this step does not change the
- // result of the search.
- numSamplesMade[queryIndex] +=
- (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
-
- return DBL_MAX;
- } // if can-prune
-} // Rescore(point, node, oldScore)
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(TreeType& queryNode, TreeType& referenceNode)
-{
- // First try to find the distance bound to check if we can prune
- // by distance.
-
- // finding the best node-to-node distance
- const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
- &referenceNode);
-
- double pointBound = DBL_MAX;
- double childBound = DBL_MAX;
- const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
- for (size_t i = 0; i < queryNode.NumPoints(); i++)
- {
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
- + maxDescendantDistance;
- if (bound < pointBound)
- pointBound = bound;
- }
-
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- {
- const double bound = queryNode.Child(i).Stat().Bound();
- if (bound < childBound)
- childBound = bound;
- }
-
- // update the bound
- queryNode.Stat().Bound() = std::min(pointBound, childBound);
- const double bestDistance = queryNode.Stat().Bound();
-
- // update the number of samples made for that node
- // -- propagate up from child nodes if child nodes have made samples
- // that the parent node is not aware of.
- // REMEMBER to propagate down samples made to the child nodes
- // if 'queryNode' descend is deemed necessary.
-
- // only update from children if a non-leaf node obviously
- if (!queryNode.IsLeaf())
- {
- size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
-
- // Find the minimum number of samples made among all children
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- {
- const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
- if (numSamples < numSamplesMadeInChildNodes)
- numSamplesMadeInChildNodes = numSamples;
- }
-
- // The number of samples made for a node is propagated up from the
- // child nodes if the child nodes have made samples that the parent
- // (which is the current 'queryNode') is not aware of.
- queryNode.Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- numSamplesMadeInChildNodes);
- }
-
- // Now check if the node-pair interaction can be pruned
-
- // If this is better than the best distance we've seen so far,
- // maybe there will be something down this node.
- // Also check if enough samples are already made for this 'queryNode'.
- if (SortPolicy::IsBetter(distance, bestDistance)
- && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
- {
- // We cannot prune this node
- // Try approximating this node by sampling
-
- // If you are required to visit the first leaf (to find possible
- // duplicates), make sure you do not approximate.
- if (queryNode.Stat().NumSamplesMade() > 0 || !firstLeafExact)
- {
- // check if this node can be approximated by sampling
- size_t samplesReqd =
- (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
- samplesReqd
- = std::min(samplesReqd,
- numSamplesReqd - queryNode.Stat().NumSamplesMade());
-
- if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
- {
- // if too many samples required and not at a leaf, then can't prune
-
- // Since query tree descend is necessary now,
- // propagate the number of samples made down to the children
-
- // Iterate through all children and propagate the number of
- // samples made to the children.
- // Only update if the parent node has made samples the children
- // have not seen
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- queryNode.Child(i).Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- queryNode.Child(i).Stat().NumSamplesMade());
-
- return distance;
- }
- else
- {
- if (!referenceNode.IsLeaf()) // if not a leaf
- {
- // Then samplesReqd <= singleSampleLimit.
- // Hence approximate node by sampling enough number of points
- // for every query in the 'queryNode'
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
- {
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
- }
-
- // update the number of samples made for the queryNode and
- // also update the number of samples made for the child nodes.
- queryNode.Stat().NumSamplesMade() += samplesReqd;
-
- // since you are not going to descend down the query tree for this
- // referenceNode, there is no point updating the number of
- // samples made for the child nodes of this queryNode.
-
- // Node approximated so we can prune it
- return DBL_MAX;
- }
- else // we are at a leaf.
- {
- if (sampleAtLeaves) // if allowed to sample at leaves.
- {
- // Approximate node by sampling enough number of points
- // for every query in the 'queryNode'.
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
- {
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
- }
-
- // update the number of samples made for the queryNode and
- // also update the number of samples made for the child nodes.
- queryNode.Stat().NumSamplesMade() += samplesReqd;
-
- // since you are not going to descend down the query tree for this
- // referenceNode, there is no point updating the number of
- // samples made for the child nodes of this queryNode.
-
- // (Leaf) Node approximated so can prune it
- return DBL_MAX;
- }
- else
- {
- // Not allowed to sample from leaves, so cannot prune.
- // Propagate the number of samples made down to the children
-
- // Go through all children and propagate the number of
- // samples made to the children.
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- queryNode.Child(i).Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- queryNode.Child(i).Stat().NumSamplesMade());
-
- return distance;
- } // sample at leaves
- } // if not-leaf
- } // if cannot-approximate by sampling
- }
- else
- {
- // Have to visit your first leaf first to boost your accuracy
-
- // Propagate the number of samples made down to the children
-
- // Go through all children and propagate the number of
- // samples made to the children.
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- queryNode.Child(i).Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- queryNode.Child(i).Stat().NumSamplesMade());
-
- return distance;
- } // if first-leaf exact visit required
- }
- else
- {
- // Either there cannot be anything better in this node.
- // Or enough number of samples are already made.
- // So prune it.
-
- // add 'fake' samples from this node; fake because the distances to
- // these samples need not be computed.
-
- // If enough samples already made, this step does not change the
- // result of the search since this queryNode will never be
- // descended anymore.
- queryNode.Stat().NumSamplesMade() +=
- (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
-
- // since you are not going to descend down the query tree for this
- // reference node, there is no point updating the number of samples
- // made for the child nodes of this queryNode.
-
- return DBL_MAX;
- } // if can-prune
-} // Score(node, node)
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult)
-{
- // First try to find the distance bound to check if we can prune
- // by distance.
-
- // find the best node-to-node distance
- const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
- &referenceNode,
- baseCaseResult);
- double pointBound = DBL_MAX;
- double childBound = DBL_MAX;
- const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
- for (size_t i = 0; i < queryNode.NumPoints(); i++)
- {
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
- + maxDescendantDistance;
- if (bound < pointBound)
- pointBound = bound;
- }
-
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- {
- const double bound = queryNode.Child(i).Stat().Bound();
- if (bound < childBound)
- childBound = bound;
- }
-
- // update the bound
- queryNode.Stat().Bound() = std::min(pointBound, childBound);
- const double bestDistance = queryNode.Stat().Bound();
-
- // update the number of samples made for that node
- // -- propagate up from child nodes if child nodes have made samples
- // that the parent node is not aware of.
- // REMEMBER to propagate down samples made to the child nodes
- // if 'queryNode' descend is deemed necessary.
-
- // only update from children if a non-leaf node obviously
- if (!queryNode.IsLeaf())
- {
- size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
-
- // Find the minimum number of samples made among all children
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- {
- const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
- if (numSamples < numSamplesMadeInChildNodes)
- numSamplesMadeInChildNodes = numSamples;
- }
-
- // The number of samples made for a node is propagated up from the
- // child nodes if the child nodes have made samples that the parent
- // (which is the current 'queryNode') is not aware of.
- queryNode.Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- numSamplesMadeInChildNodes);
- }
-
- // Now check if the node-pair interaction can be pruned
-
- // If this is better than the best distance we've seen so far,
- // maybe there will be something down this node.
- // Also check if enough samples are already made for this 'queryNode'.
- if (SortPolicy::IsBetter(distance, bestDistance)
- && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
- {
- // We cannot prune this node
- // Try approximating this node by sampling
-
- // If you are required to visit the first leaf (to find possible
- // duplicates), make sure you do not approximate.
- if (queryNode.Stat().NumSamplesMade() > 0 || !firstLeafExact)
- {
- // check if this node can be approximated by sampling
- size_t samplesReqd =
- (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
- samplesReqd
- = std::min(samplesReqd,
- numSamplesReqd - queryNode.Stat().NumSamplesMade());
-
- if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
- {
- // if too many samples required and not at a leaf, then can't prune
-
- // Since query tree descend is necessary now,
- // propagate the number of samples made down to the children
-
- // Iterate through all children and propagate the number of
- // samples made to the children.
- // Only update if the parent node has made samples the children
- // have not seen
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- queryNode.Child(i).Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- queryNode.Child(i).Stat().NumSamplesMade());
-
- return distance;
- }
- else
- {
- if (!referenceNode.IsLeaf()) // if not a leaf
- {
- // Then samplesReqd <= singleSampleLimit.
- // Hence approximate node by sampling enough number of points
- // for every query in the 'queryNode'
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
- {
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
- }
-
- // update the number of samples made for the queryNode and
- // also update the number of samples made for the child nodes.
- queryNode.Stat().NumSamplesMade() += samplesReqd;
-
- // since you are not going to descend down the query tree for this
- // referenceNode, there is no point updating the number of
- // samples made for the child nodes of this queryNode.
-
- // Node approximated so we can prune it
- return DBL_MAX;
- }
- else // we are at a leaf.
- {
- if (sampleAtLeaves) // if allowed to sample at leaves.
- {
- // Approximate node by sampling enough number of points
- // for every query in the 'queryNode'.
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
- {
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
- }
-
- // update the number of samples made for the queryNode and
- // also update the number of samples made for the child nodes.
- queryNode.Stat().NumSamplesMade() += samplesReqd;
-
- // since you are not going to descend down the query tree for this
- // referenceNode, there is no point updating the number of
- // samples made for the child nodes of this queryNode.
-
- // (Leaf) Node approximated so can prune it
- return DBL_MAX;
- }
- else
- {
- // Not allowed to sample from leaves, so cannot prune.
- // Propagate the number of samples made down to the children
-
- // Go through all children and propagate the number of
- // samples made to the children.
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- queryNode.Child(i).Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- queryNode.Child(i).Stat().NumSamplesMade());
-
- return distance;
- } // sample at leaves
- } // if not-leaf
- } // if cannot-approximate by sampling
- }
- else
- {
- // Have to visit your first leaf first to boost your accuracy
-
- // Propagate the number of samples made down to the children
-
- // Go through all children and propagate the number of
- // samples made to the children.
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- queryNode.Child(i).Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- queryNode.Child(i).Stat().NumSamplesMade());
-
- return distance;
- } // if first-leaf exact visit required
- }
- else
- {
- // Either there cannot be anything better in this node.
- // Or enough number of samples are already made.
- // So prune it.
-
- // add 'fake' samples from this node; fake because the distances to
- // these samples need not be computed.
-
- // If enough samples already made, this step does not change the
- // result of the search since this queryNode will never be
- // descended anymore.
- queryNode.Stat().NumSamplesMade() +=
- (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
-
- // since you are not going to descend down the query tree for this
- // reference node, there is no point updating the number of samples
- // made for the child nodes of this queryNode.
-
- return DBL_MAX;
- } // if can-prune
-} // Score(node, node, baseCaseResult)
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Rescore(TreeType& queryNode,
- TreeType& referenceNode,
- const double oldScore)
-{
- if (oldScore == DBL_MAX)
- return oldScore;
-
- // First try to find the distance bound to check if we can prune
- // by distance.
- double pointBound = DBL_MAX;
- double childBound = DBL_MAX;
- const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
-
- for (size_t i = 0; i < queryNode.NumPoints(); i++)
- {
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
- + maxDescendantDistance;
- if (bound < pointBound)
- pointBound = bound;
- }
-
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- {
- const double bound = queryNode.Child(i).Stat().Bound();
- if (bound < childBound)
- childBound = bound;
- }
-
- // update the bound
- queryNode.Stat().Bound() = std::min(pointBound, childBound);
- const double bestDistance = queryNode.Stat().Bound();
-
- // Now check if the node-pair interaction can be pruned by sampling
- // update the number of samples made for that node
- // -- propagate up from child nodes if child nodes have made samples
- // that the parent node is not aware of.
- // REMEMBER to propagate down samples made to the child nodes
- // if the parent samples.
-
- // only update from children if a non-leaf node obviously
- if (!queryNode.IsLeaf())
- {
- size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
-
- // Find the minimum number of samples made among all children
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- {
- const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
- if (numSamples < numSamplesMadeInChildNodes)
- numSamplesMadeInChildNodes = numSamples;
- }
-
- // The number of samples made for a node is propagated up from the
- // child nodes if the child nodes have made samples that the parent
- // (which is the current 'queryNode') is not aware of.
- queryNode.Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- numSamplesMadeInChildNodes);
- }
-
- // Now check if the node-pair interaction can be pruned by sampling
-
- // If this is better than the best distance we've seen so far,
- // maybe there will be something down this node.
- // Also check if enough samples are already made for this query.
- if (SortPolicy::IsBetter(oldScore, bestDistance)
- && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
- {
- // We cannot prune this node
- // Try approximating this node by sampling
-
- // Here we assume that since we are re-scoring, the algorithm
- // has already sampled some candidates, and if specified, also
- // traversed to the first leaf.
- // So no checks regarding that are made any more.
- //
- // check if this node can be approximated by sampling
- size_t samplesReqd =
- (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
- samplesReqd
- = std::min(samplesReqd,
- numSamplesReqd - queryNode.Stat().NumSamplesMade());
-
- if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
- {
- // if too many samples required and not at a leaf, then can't prune
-
- // Since query tree descend is necessary now,
- // propagate the number of samples made down to the children
-
- // Go through all children and propagate the number of
- // samples made to the children.
- // Only update if the parent node has made samples the children
- // have not seen
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- queryNode.Child(i).Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- queryNode.Child(i).Stat().NumSamplesMade());
-
- return oldScore;
- }
- else
- {
- if (!referenceNode.IsLeaf()) // if not a leaf
- {
- // Then samplesReqd <= singleSampleLimit.
- // Hence approximate node by sampling enough number of points
- // for every query in the 'queryNode'
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
- {
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
- }
-
- // update the number of samples made for the queryNode and
- // also update the number of samples made for the child nodes.
- queryNode.Stat().NumSamplesMade() += samplesReqd;
-
- // since you are not going to descend down the query tree for this
- // referenceNode, there is no point updating the number of
- // samples made for the child nodes of this queryNode.
-
- // Node approximated so we can prune it
- return DBL_MAX;
- }
- else // we are at a leaf.
- {
- if (sampleAtLeaves) // if allowed to sample at leaves.
- {
- // Approximate node by sampling enough number of points
- // for every query in the 'queryNode'.
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
- {
- arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
- // The counting of the samples is done in the 'BaseCase'
- // function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
- }
-
- // update the number of samples made for the queryNode and
- // also update the number of samples made for the child nodes.
- queryNode.Stat().NumSamplesMade() += samplesReqd;
-
- // since you are not going to descend down the query tree for this
- // referenceNode, there is no point updating the number of
- // samples made for the child nodes of this queryNode.
-
- // (Leaf) Node approximated so can prune it
- return DBL_MAX;
- }
- else
- {
- // not allowed to sample from leaves, so cannot prune.
- // propagate the number of samples made down to the children
-
- // going through all children and propagate the number of
- // samples made to the children.
- for (size_t i = 0; i < queryNode.NumChildren(); i++)
- queryNode.Child(i).Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- queryNode.Child(i).Stat().NumSamplesMade());
-
- return oldScore;
- } // sample at leaves
- } // if not-leaf
- } // if cannot-approximate by sampling
- }
- else
- {
- // Either there cannot be anything better in this node.
- // Or enough number of samples are already made.
- // So prune it.
-
- // add 'fake' samples from this node; fake because the distances to
- // these samples need not be computed.
-
- // If enough samples already made, this step does not change the
- // result of the search since this queryNode will never be
- // descended anymore.
- queryNode.Stat().NumSamplesMade() +=
- (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
-
- // since you are not going to descend down the query tree for this
- // reference node, there is no point updating the number of samples
- // made for the child nodes of this queryNode.
-
- return DBL_MAX;
- } // if can-prune
-} // Rescore(node, node, oldScore)
-
-
-/**
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
-{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (distances.n_rows - 1))
- {
- int len = (distances.n_rows - 1) - pos;
- memmove(distances.colptr(queryIndex) + (pos + 1),
- distances.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(neighbors.colptr(queryIndex) + (pos + 1),
- neighbors.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- distances(pos, queryIndex) = distance;
- neighbors(pos, queryIndex) = neighbor;
-}
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_search_rules_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_search_rules_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,1314 @@
+/**
+ * @file ra_search_rules_impl.hpp
+ * @author Parikshit Ram
+ *
+ * Implementation of RASearchRules.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_RA_SEARCH_RULES_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "ra_search_rules.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+RASearchRules<SortPolicy, MetricType, TreeType>::
+RASearchRules(const arma::mat& referenceSet,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ MetricType& metric,
+ const double tau,
+ const double alpha,
+ const bool naive,
+ const bool sampleAtLeaves,
+ const bool firstLeafExact,
+ const size_t singleSampleLimit) :
+ referenceSet(referenceSet),
+ querySet(querySet),
+ neighbors(neighbors),
+ distances(distances),
+ metric(metric),
+ sampleAtLeaves(sampleAtLeaves),
+ firstLeafExact(firstLeafExact),
+ singleSampleLimit(singleSampleLimit)
+{
+ Timer::Start("computing_number_of_samples_reqd");
+ numSamplesReqd = MinimumSamplesReqd(referenceSet.n_cols, neighbors.n_rows,
+ tau, alpha);
+ Timer::Stop("computing_number_of_samples_reqd");
+
+ // initializing some stats to be collected during the search
+ numSamplesMade = arma::zeros<arma::Col<size_t> >(querySet.n_cols);
+ numDistComputations = 0;
+ samplingRatio = (double) numSamplesReqd / (double) referenceSet.n_cols;
+
+ Log::Info << "Minimum Samples Required per-query: " << numSamplesReqd <<
+ ", Sampling Ratio: " << samplingRatio << std::endl;
+
+ if (naive) // no tree traversal going to happen, just do naive sampling here.
+ {
+ // sample enough number of points
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ {
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(numSamplesReqd, referenceSet.n_cols,
+ distinctSamples);
+ for (size_t j = 0; j < distinctSamples.n_elem; j++)
+ BaseCase(i, (size_t) distinctSamples[j]);
+ }
+ }
+}
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline force_inline
+void RASearchRules<SortPolicy, MetricType, TreeType>::
+ObtainDistinctSamples(const size_t numSamples,
+ const size_t rangeUpperBound,
+ arma::uvec& distinctSamples) const
+{
+ // keep track of the points that are sampled
+ arma::Col<size_t> sampledPoints;
+ sampledPoints.zeros(rangeUpperBound);
+
+ for (size_t i = 0; i < numSamples; i++)
+ sampledPoints[(size_t) math::RandInt(rangeUpperBound)]++;
+
+ distinctSamples = arma::find(sampledPoints > 0);
+ return;
+}
+
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+size_t RASearchRules<SortPolicy, MetricType, TreeType>::
+MinimumSamplesReqd(const size_t n,
+ const size_t k,
+ const double tau,
+ const double alpha) const
+{
+ size_t ub = n; // the upper bound on the binary search
+ size_t lb = k; // the lower bound on the binary search
+ size_t m = lb; // the minimum number of random samples
+
+ // The rank approximation
+ size_t t = (size_t) std::ceil(tau * (double) n / 100.0);
+
+ double prob;
+ assert(alpha <= 1.0);
+
+ // going through all values of sample sizes
+ // to find the minimum samples required to satisfy the
+ // desired bound
+ bool done = false;
+
+ // This performs a binary search on the integer values between 'lb = k'
+ // and 'ub = n' to find the minimum number of samples 'm' required to obtain
+ // the desired success probability 'alpha'.
+ do
+ {
+ prob = SuccessProbability(n, k, m, t);
+
+ if (prob > alpha)
+ {
+ if (prob - alpha < 0.001 || ub < lb + 2) {
+ done = true;
+ break;
+ }
+ else
+ ub = m;
+ }
+ else
+ {
+ if (prob < alpha)
+ {
+ if (m == lb)
+ {
+ m++;
+ continue;
+ }
+ else
+ lb = m;
+ }
+ else
+ {
+ done = true;
+ break;
+ }
+ }
+ m = (ub + lb) / 2;
+
+ } while (!done);
+
+ return (m + 1);
+}
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+double RASearchRules<SortPolicy, MetricType, TreeType>::
+SuccessProbability(const size_t n,
+ const size_t k,
+ const size_t m,
+ const size_t t) const
+{
+ if (k == 1)
+ {
+ if (m > n - t)
+ return 1.0;
+
+ double eps = (double) t / (double) n;
+
+ return 1.0 - std::pow(1.0 - eps, (double) m);
+
+ } // faster implementation for topK = 1
+ else
+ {
+ if (m < k)
+ return 0.0;
+
+ if (m > n - t + k)
+ return 1.0;
+
+ double eps = (double) t / (double) n;
+ double sum = 0.0;
+
+ // The probability that 'k' of the 'm' samples lie within the top 't'
+ // of the neighbors is given by:
+ // sum_{j = k}^m Choose(m, j) (t/n)^j (1 - t/n)^{m - j}
+ // which is also equal to
+ // 1 - sum_{j = 0}^{k - 1} Choose(m, j) (t/n)^j (1 - t/n)^{m - j}
+ //
+ // So this is a m - k term summation or a k term summation. So if
+ // m > 2k, do the k term summation, otherwise do the m term summation.
+
+ size_t lb;
+ size_t ub;
+ bool topHalf;
+
+ if (2 * k < m)
+ {
+ // compute 1 - sum_{j = 0}^{k - 1} Choose(m, j) eps^j (1 - eps)^{m - j}
+ // eps = t/n
+ //
+ // Choosing 'lb' as 1 and 'ub' as k so as to sum from 1 to (k - 1),
+ // and add the term (1 - eps)^m term separately.
+ lb = 1;
+ ub = k;
+ topHalf = true;
+ sum = std::pow(1 - eps, (double) m);
+ }
+ else
+ {
+ // compute sum_{j = k}^m Choose(m, j) eps^j (1 - eps)^{m - j}
+ // eps = t/n
+ //
+ // Choosing 'lb' as k and 'ub' as m so as to sum from k to (m - 1),
+ // and add the term eps^m term separately.
+ lb = k;
+ ub = m;
+ topHalf = false;
+ sum = std::pow(eps, (double) m);
+ }
+
+ for (size_t j = lb; j < ub; j++)
+ {
+ // compute Choose(m, j)
+ double mCj = (double) m;
+ size_t jTrans;
+
+ // if j < m - j, compute Choose(m, j)
+ // if j > m - j, compute Choose(m, m - j)
+ if (topHalf)
+ jTrans = j;
+ else
+ jTrans = m - j;
+
+ for(size_t i = 2; i <= jTrans; i++)
+ {
+ mCj *= (double) (m - (i - 1));
+ mCj /= (double) i;
+ }
+
+ sum += (mCj * std::pow(eps, (double) j)
+ * std::pow(1.0 - eps, (double) (m - j)));
+ }
+
+ if (topHalf)
+ sum = 1.0 - sum;
+
+ return sum;
+ } // for k > 1
+} // FastComputeProb
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline force_inline
+double RASearchRules<SortPolicy, MetricType, TreeType>::
+BaseCase(const size_t queryIndex, const size_t referenceIndex)
+{
+ // If the datasets are the same, then this search is only using one dataset
+ // and we should not return identical points.
+ if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
+ return 0.0;
+
+ double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceIndex));
+
+ // If this distance is better than any of the current candidates, the
+ // SortDistance() function will give us the position to insert it into.
+ arma::vec queryDist = distances.unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+
+ // SortDistance() returns (size_t() - 1) if we shouldn't add it.
+ if (insertPosition != (size_t() - 1))
+ InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+
+ numSamplesMade[queryIndex]++;
+
+ // TO REMOVE
+ numDistComputations++;
+
+ return distance;
+}
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::
+Prescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ TreeType& referenceChildNode,
+ const double baseCaseResult) const
+{
+ return 0.0;
+}
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::
+PrescoreQ(TreeType& queryNode,
+ TreeType& queryChildNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const
+{
+ return 0.0;
+}
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::
+Score(const size_t queryIndex, TreeType& referenceNode)
+{
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ &referenceNode);
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+
+ // If this is better than the best distance we've seen so far,
+ // maybe there will be something down this node.
+ // Also check if enough samples are already made for this query.
+ if (SortPolicy::IsBetter(distance, bestDistance)
+ && numSamplesMade[queryIndex] < numSamplesReqd)
+ {
+ // We cannot prune this node
+ // Try approximating this node by sampling
+
+ // If you are required to visit the first leaf (to find possible
+ // duplicates), make sure you do not approximate.
+ if (numSamplesMade[queryIndex] > 0 || !firstLeafExact)
+ {
+ // check if this node can be approximated by sampling
+ size_t samplesReqd =
+ (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
+ samplesReqd = std::min(samplesReqd,
+ numSamplesReqd - numSamplesMade[queryIndex]);
+
+ if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
+ {
+ // if too many samples required and not at a leaf, then can't prune
+ return distance;
+ }
+ else
+ {
+ if (!referenceNode.IsLeaf()) // if not a leaf
+ {
+ // Then samplesReqd <= singleSampleLimit.
+ // Hence approximate node by sampling enough number of points
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase' function
+ // so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+
+ // Node approximated so we can prune it
+ return DBL_MAX;
+ }
+ else // we are at a leaf.
+ {
+ if (sampleAtLeaves) // if allowed to sample at leaves.
+ {
+ // Approximate node by sampling enough number of points
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase'
+ // function so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+
+ // (Leaf) Node approximated so can prune it
+ return DBL_MAX;
+ }
+ else
+ {
+ // not allowed to sample from leaves, so cannot prune.
+ return distance;
+ } // sample at leaves
+ } // if not-leaf
+ } // if cannot-approximate by sampling
+ }
+ else
+ {
+ // try first to visit your first leaf to boost your accuracy
+ // and find your (near) duplicates if they exist
+ return distance;
+ } // if first-leaf exact visit required
+ }
+ else
+ {
+ // Either there cannot be anything better in this node.
+ // Or enough number of samples are already made.
+ // So prune it.
+
+ // add 'fake' samples from this node; fake because the distances to
+ // these samples need not be computed.
+
+ // If enough samples already made, this step does not change the
+ // result of the search.
+ numSamplesMade[queryIndex] +=
+ (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
+
+ return DBL_MAX;
+ } // if can-prune
+} // Score(point, node)
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::
+Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult)
+{
+
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ &referenceNode, baseCaseResult);
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+ // Hereon, this 'Score' function is exactly same as the previous
+ // 'Score' function.
+
+ // If this is better than the best distance we've seen so far,
+ // maybe there will be something down this node.
+ // Also check if enough samples are already made for this query.
+ if (SortPolicy::IsBetter(distance, bestDistance)
+ && numSamplesMade[queryIndex] < numSamplesReqd)
+ {
+ // We cannot prune this node
+ // Try approximating this node by sampling
+
+ // If you are required to visit the first leaf (to find possible
+ // duplicates), make sure you do not approximate.
+ if (numSamplesMade[queryIndex] > 0 || !firstLeafExact)
+ {
+ // Check if this node can be approximated by sampling:
+ // 'referenceNode.Count() should correspond to the number of points
+ // present in this subtree.
+ size_t samplesReqd =
+ (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
+ samplesReqd = std::min(samplesReqd,
+ numSamplesReqd - numSamplesMade[queryIndex]);
+
+ if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
+ {
+ // if too many samples required and not at a leaf, then can't prune
+ return distance;
+ }
+ else
+ {
+ if (!referenceNode.IsLeaf()) // if not a leaf
+ {
+ // Then samplesReqd <= singleSampleLimit.
+ // Hence approximate node by sampling enough number of points
+ // from this subtree.
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase'
+ // function so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+
+ // Node approximated so we can prune it
+ return DBL_MAX;
+ }
+ else // we are at a leaf.
+ {
+ if (sampleAtLeaves) // if allowed to sample at leaves.
+ {
+ // Approximate node by sampling enough number of points
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase'
+ // function so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+
+ // (Leaf) Node approximated so can prune it
+ return DBL_MAX;
+ }
+ else
+ {
+ // not allowed to sample from leaves, so cannot prune.
+ return distance;
+ } // sample at leaves
+ } // if not-leaf
+ } // if cannot-approximate by sampling
+ }
+ else
+ {
+ // try first to visit your first leaf to boost your accuracy
+ return distance;
+ } // if first-leaf exact visit required
+ }
+ else
+ {
+ // Either there cannot be anything better in this node.
+ // Or enough number of samples are already made.
+ // So prune it.
+
+ // add 'fake' samples from this node; fake because the distances to
+ // these samples need not be computed.
+
+ // If enough samples already made, this step does not change the
+ // result of the search.
+ if (numSamplesMade[queryIndex] < numSamplesReqd)
+ // add 'fake' samples from this node; fake because the distances to
+ // these samples need not be computed.
+ numSamplesMade[queryIndex] +=
+ (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
+
+ return DBL_MAX;
+ } // if can-prune
+
+ return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+
+} // Score(point, node, point-node-point-distance)
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::
+Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore)
+{
+ // If we are already pruning, still prune.
+ if (oldScore == DBL_MAX)
+ return oldScore;
+
+ // Just check the score again against the distances.
+ const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+
+ // If this is better than the best distance we've seen so far,
+ // maybe there will be something down this node.
+ // Also check if enough samples are already made for this query.
+ if (SortPolicy::IsBetter(oldScore, bestDistance)
+ && numSamplesMade[queryIndex] < numSamplesReqd)
+ {
+ // We cannot prune this node
+ // Try approximating this node by sampling
+
+ // Here we assume that since we are re-scoring, the algorithm
+ // has already sampled some candidates, and if specified, also
+ // traversed to the first leaf.
+ // So no checks regarding that is made any more.
+ //
+ // check if this node can be approximated by sampling
+ size_t samplesReqd =
+ (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
+ samplesReqd = std::min(samplesReqd,
+ numSamplesReqd - numSamplesMade[queryIndex]);
+
+ if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
+ {
+ // if too many samples required and not at a leaf, then can't prune
+ return oldScore;
+ }
+ else
+ {
+ if (!referenceNode.IsLeaf()) // if not a leaf
+ {
+ // Then samplesReqd <= singleSampleLimit.
+ // Hence approximate node by sampling enough number of points
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase'
+ // function so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+
+ // Node approximated so we can prune it
+ return DBL_MAX;
+ }
+ else // we are at a leaf.
+ {
+ if (sampleAtLeaves) // if allowed to sample at leaves.
+ {
+ // Approximate node by sampling enough number of points
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase'
+ // function so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+
+ // (Leaf) Node approximated so can prune it
+ return DBL_MAX;
+ }
+ else
+ {
+ // not allowed to sample from leaves, so cannot prune.
+ return oldScore;
+ } // sample at leaves
+ } // if not-leaf
+ } // if cannot-approximate by sampling
+ }
+ else
+ {
+ // Either there cannot be anything better in this node.
+ // Or enough number of samples are already made.
+ // So prune it.
+
+ // add 'fake' samples from this node; fake because the distances to
+ // these samples need not be computed.
+
+ // If enough samples already made, this step does not change the
+ // result of the search.
+ numSamplesMade[queryIndex] +=
+ (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
+
+ return DBL_MAX;
+ } // if can-prune
+} // Rescore(point, node, oldScore)
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::
+Score(TreeType& queryNode, TreeType& referenceNode)
+{
+ // First try to find the distance bound to check if we can prune
+ // by distance.
+
+ // finding the best node-to-node distance
+ const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+ &referenceNode);
+
+ double pointBound = DBL_MAX;
+ double childBound = DBL_MAX;
+ const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+ for (size_t i = 0; i < queryNode.NumPoints(); i++)
+ {
+ const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
+ + maxDescendantDistance;
+ if (bound < pointBound)
+ pointBound = bound;
+ }
+
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ {
+ const double bound = queryNode.Child(i).Stat().Bound();
+ if (bound < childBound)
+ childBound = bound;
+ }
+
+ // update the bound
+ queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ const double bestDistance = queryNode.Stat().Bound();
+
+ // update the number of samples made for that node
+ // -- propagate up from child nodes if child nodes have made samples
+ // that the parent node is not aware of.
+ // REMEMBER to propagate down samples made to the child nodes
+ // if 'queryNode' descend is deemed necessary.
+
+ // only update from children if a non-leaf node obviously
+ if (!queryNode.IsLeaf())
+ {
+ size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
+
+ // Find the minimum number of samples made among all children
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ {
+ const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
+ if (numSamples < numSamplesMadeInChildNodes)
+ numSamplesMadeInChildNodes = numSamples;
+ }
+
+ // The number of samples made for a node is propagated up from the
+ // child nodes if the child nodes have made samples that the parent
+ // (which is the current 'queryNode') is not aware of.
+ queryNode.Stat().NumSamplesMade()
+ = std::max(queryNode.Stat().NumSamplesMade(),
+ numSamplesMadeInChildNodes);
+ }
+
+ // Now check if the node-pair interaction can be pruned
+
+ // If this is better than the best distance we've seen so far,
+ // maybe there will be something down this node.
+ // Also check if enough samples are already made for this 'queryNode'.
+ if (SortPolicy::IsBetter(distance, bestDistance)
+ && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
+ {
+ // We cannot prune this node
+ // Try approximating this node by sampling
+
+ // If you are required to visit the first leaf (to find possible
+ // duplicates), make sure you do not approximate.
+ if (queryNode.Stat().NumSamplesMade() > 0 || !firstLeafExact)
+ {
+ // check if this node can be approximated by sampling
+ size_t samplesReqd =
+ (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
+ samplesReqd
+ = std::min(samplesReqd,
+ numSamplesReqd - queryNode.Stat().NumSamplesMade());
+
+ if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
+ {
+ // if too many samples required and not at a leaf, then can't prune
+
+ // Since query tree descend is necessary now,
+ // propagate the number of samples made down to the children
+
+ // Iterate through all children and propagate the number of
+ // samples made to the children.
+ // Only update if the parent node has made samples the children
+ // have not seen
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ queryNode.Child(i).Stat().NumSamplesMade()
+ = std::max(queryNode.Stat().NumSamplesMade(),
+ queryNode.Child(i).Stat().NumSamplesMade());
+
+ return distance;
+ }
+ else
+ {
+ if (!referenceNode.IsLeaf()) // if not a leaf
+ {
+ // Then samplesReqd <= singleSampleLimit.
+ // Hence approximate node by sampling enough number of points
+ // for every query in the 'queryNode'
+ for (size_t queryIndex = queryNode.Begin();
+ queryIndex < queryNode.End(); queryIndex++)
+ {
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase'
+ // function so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+ }
+
+ // update the number of samples made for the queryNode and
+ // also update the number of sample made for the child nodes.
+ queryNode.Stat().NumSamplesMade() += samplesReqd;
+
+ // since you are not going to descend down the query tree for this
+ // referenceNode, there is no point updating the number of
+ // samples made for the child nodes of this queryNode.
+
+ // Node approximated so we can prune it
+ return DBL_MAX;
+ }
+ else // we are at a leaf.
+ {
+ if (sampleAtLeaves) // if allowed to sample at leaves.
+ {
+ // Approximate node by sampling enough number of points
+ // for every query in the 'queryNode'.
+ for (size_t queryIndex = queryNode.Begin();
+ queryIndex < queryNode.End(); queryIndex++)
+ {
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase'
+ // function so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+ }
+
+ // update the number of samples made for the queryNode and
+ // also update the number of sample made for the child nodes.
+ queryNode.Stat().NumSamplesMade() += samplesReqd;
+
+ // since you are not going to descend down the query tree for this
+ // referenceNode, there is no point updating the number of
+ // samples made for the child nodes of this queryNode.
+
+ // (Leaf) Node approximated so can prune it
+ return DBL_MAX;
+ }
+ else
+ {
+ // Not allowed to sample from leaves, so cannot prune.
+ // Propagate the number of samples made down to the children
+
+ // Go through all children and propagate the number of
+ // samples made to the children.
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ queryNode.Child(i).Stat().NumSamplesMade()
+ = std::max(queryNode.Stat().NumSamplesMade(),
+ queryNode.Child(i).Stat().NumSamplesMade());
+
+ return distance;
+ } // sample at leaves
+ } // if not-leaf
+ } // if cannot-approximate by sampling
+ }
+ else
+ {
+ // Have to first to visit your first leaf to boost your accuracy
+
+ // Propagate the number of samples made down to the children
+
+ // Go through all children and propagate the number of
+ // samples made to the children.
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ queryNode.Child(i).Stat().NumSamplesMade()
+ = std::max(queryNode.Stat().NumSamplesMade(),
+ queryNode.Child(i).Stat().NumSamplesMade());
+
+ return distance;
+ } // if first-leaf exact visit required
+ }
+ else
+ {
+ // Either there cannot be anything better in this node.
+ // Or enough number of samples are already made.
+ // So prune it.
+
+ // add 'fake' samples from this node; fake because the distances to
+ // these samples need not be computed.
+
+ // If enough samples already made, this step does not change the
+ // result of the search since this queryNode will never be
+ // descended anymore.
+ queryNode.Stat().NumSamplesMade() +=
+ (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
+
+ // since you are not going to descend down the query tree for this
+ // reference node, there is no point updating the number of samples
+ // made for the child nodes of this queryNode.
+
+ return DBL_MAX;
+ } // if can-prune
+} // Score(node, node)
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::
+Score(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult)
+{
+ // First try to find the distance bound to check if we can prune
+ // by distance.
+
+ // find the best node-to-node distance
+ const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+ &referenceNode,
+ baseCaseResult);
+ double pointBound = DBL_MAX;
+ double childBound = DBL_MAX;
+ const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+ for (size_t i = 0; i < queryNode.NumPoints(); i++)
+ {
+ const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
+ + maxDescendantDistance;
+ if (bound < pointBound)
+ pointBound = bound;
+ }
+
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ {
+ const double bound = queryNode.Child(i).Stat().Bound();
+ if (bound < childBound)
+ childBound = bound;
+ }
+
+ // update the bound
+ queryNode.Stat().Bound() = std::min(pointBound, childBound);
+ const double bestDistance = queryNode.Stat().Bound();
+
+ // update the number of samples made for that node
+ // -- propagate up from child nodes if child nodes have made samples
+ // that the parent node is not aware of.
+ // REMEMBER to propagate down samples made to the child nodes
+ // if 'queryNode' descend is deemed necessary.
+
+ // only update from children if a non-leaf node obviously
+ if (!queryNode.IsLeaf())
+ {
+ size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
+
+ // Find the minimum number of samples made among all children
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ {
+ const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
+ if (numSamples < numSamplesMadeInChildNodes)
+ numSamplesMadeInChildNodes = numSamples;
+ }
+
+ // The number of samples made for a node is propagated up from the
+ // child nodes if the child nodes have made samples that the parent
+ // (which is the current 'queryNode') is not aware of.
+ queryNode.Stat().NumSamplesMade()
+ = std::max(queryNode.Stat().NumSamplesMade(),
+ numSamplesMadeInChildNodes);
+ }
+
+ // Now check if the node-pair interaction can be pruned
+
+ // If this is better than the best distance we've seen so far,
+ // maybe there will be something down this node.
+ // Also check if enough samples are already made for this 'queryNode'.
+ if (SortPolicy::IsBetter(distance, bestDistance)
+ && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
+ {
+ // We cannot prune this node
+ // Try approximating this node by sampling
+
+ // If you are required to visit the first leaf (to find possible
+ // duplicates), make sure you do not approximate.
+ if (queryNode.Stat().NumSamplesMade() > 0 || !firstLeafExact)
+ {
+ // check if this node can be approximated by sampling
+ size_t samplesReqd =
+ (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
+ samplesReqd
+ = std::min(samplesReqd,
+ numSamplesReqd - queryNode.Stat().NumSamplesMade());
+
+ if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
+ {
+ // if too many samples required and not at a leaf, then can't prune
+
+ // Since query tree descend is necessary now,
+ // propagate the number of samples made down to the children
+
+ // Iterate through all children and propagate the number of
+ // samples made to the children.
+ // Only update if the parent node has made samples the children
+ // have not seen
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ queryNode.Child(i).Stat().NumSamplesMade()
+ = std::max(queryNode.Stat().NumSamplesMade(),
+ queryNode.Child(i).Stat().NumSamplesMade());
+
+ return distance;
+ }
+ else
+ {
+ if (!referenceNode.IsLeaf()) // if not a leaf
+ {
+ // Then samplesReqd <= singleSampleLimit.
+ // Hence approximate node by sampling enough number of points
+ // for every query in the 'queryNode'
+ for (size_t queryIndex = queryNode.Begin();
+ queryIndex < queryNode.End(); queryIndex++)
+ {
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase'
+ // function so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+ }
+
+ // update the number of samples made for the queryNode and
+ // also update the number of sample made for the child nodes.
+ queryNode.Stat().NumSamplesMade() += samplesReqd;
+
+ // since you are not going to descend down the query tree for this
+ // referenceNode, there is no point updating the number of
+ // samples made for the child nodes of this queryNode.
+
+ // Node approximated so we can prune it
+ return DBL_MAX;
+ }
+ else // we are at a leaf.
+ {
+ if (sampleAtLeaves) // if allowed to sample at leaves.
+ {
+ // Approximate node by sampling enough number of points
+ // for every query in the 'queryNode'.
+ for (size_t queryIndex = queryNode.Begin();
+ queryIndex < queryNode.End(); queryIndex++)
+ {
+ arma::uvec distinctSamples;
+ ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ distinctSamples);
+ for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ // The counting of the samples are done in the 'BaseCase'
+ // function so no book-keeping is required here.
+ BaseCase(queryIndex,
+ referenceNode.Begin() + (size_t) distinctSamples[i]);
+ }
+
+ // update the number of samples made for the queryNode and
+ // also update the number of sample made for the child nodes.
+ queryNode.Stat().NumSamplesMade() += samplesReqd;
+
+ // since you are not going to descend down the query tree for this
+ // referenceNode, there is no point updating the number of
+ // samples made for the child nodes of this queryNode.
+
+ // (Leaf) Node approximated so can prune it
+ return DBL_MAX;
+ }
+ else
+ {
+ // Not allowed to sample from leaves, so cannot prune.
+ // Propagate the number of samples made down to the children
+
+ // Go through all children and propagate the number of
+ // samples made to the children.
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ queryNode.Child(i).Stat().NumSamplesMade()
+ = std::max(queryNode.Stat().NumSamplesMade(),
+ queryNode.Child(i).Stat().NumSamplesMade());
+
+ return distance;
+ } // sample at leaves
+ } // if not-leaf
+ } // if cannot-approximate by sampling
+ }
+ else
+ {
+ // Have to first to visit your first leaf to boost your accuracy
+
+ // Propagate the number of samples made down to the children
+
+ // Go through all children and propagate the number of
+ // samples made to the children.
+ for (size_t i = 0; i < queryNode.NumChildren(); i++)
+ queryNode.Child(i).Stat().NumSamplesMade()
+ = std::max(queryNode.Stat().NumSamplesMade(),
+ queryNode.Child(i).Stat().NumSamplesMade());
+
+ return distance;
+ } // if first-leaf exact visit required
+ }
+ else
+ {
+ // Either there cannot be anything better in this node.
+ // Or enough number of samples are already made.
+ // So prune it.
+
+ // add 'fake' samples from this node; fake because the distances to
+ // these samples need not be computed.
+
+ // If enough samples already made, this step does not change the
+ // result of the search since this queryNode will never be
+ // descended anymore.
+ queryNode.Stat().NumSamplesMade() +=
+ (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
+
+ // since you are not going to descend down the query tree for this
+ // reference node, there is no point updating the number of samples
+ // made for the child nodes of this queryNode.
+
+ return DBL_MAX;
+ } // if can-prune
+} // Score(node, node, baseCaseResult)
+
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::
+Rescore(TreeType& queryNode,
+        TreeType& referenceNode,
+        const double oldScore)
+{
+  if (oldScore == DBL_MAX)
+    return oldScore;
+
+  // First, compute the distance bound for this query node so we can check
+  // whether the node pair can be pruned by distance alone.
+  double pointBound = DBL_MAX;
+  double childBound = DBL_MAX;
+  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
+
+  for (size_t i = 0; i < queryNode.NumPoints(); i++)
+  {
+    const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
+        + maxDescendantDistance;
+    if (bound < pointBound)
+      pointBound = bound;
+  }
+
+  for (size_t i = 0; i < queryNode.NumChildren(); i++)
+  {
+    const double bound = queryNode.Child(i).Stat().Bound();
+    if (bound < childBound)
+      childBound = bound;
+  }
+
+  // Update the query node's bound to the tighter of the two bounds.
+  queryNode.Stat().Bound() = std::min(pointBound, childBound);
+  const double bestDistance = queryNode.Stat().Bound();
+
+  // Now check whether the node-pair interaction can be pruned by sampling.
+  // Update the number of samples made for this node:
+  // -- propagate up from the child nodes if the children have made samples
+  //    that the parent node is not yet aware of.
+  // REMEMBER to propagate down the samples made to the child nodes
+  // whenever the parent samples.
+
+  // Only pull sample counts up from children for non-leaf nodes, obviously.
+  if (!queryNode.IsLeaf())
+  {
+    size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
+
+    // Find the minimum number of samples made among all the children.
+    for (size_t i = 0; i < queryNode.NumChildren(); i++)
+    {
+      const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
+      if (numSamples < numSamplesMadeInChildNodes)
+        numSamplesMadeInChildNodes = numSamples;
+    }
+
+    // The number of samples made for a node is propagated up from the
+    // child nodes if the child nodes have made samples that the parent
+    // (which is the current 'queryNode') is not aware of.
+    queryNode.Stat().NumSamplesMade()
+        = std::max(queryNode.Stat().NumSamplesMade(),
+                   numSamplesMadeInChildNodes);
+  }
+
+  // With the sample counts synchronized, decide whether to prune by sampling.
+
+  // If this is better than the best distance we've seen so far,
+  // maybe there will be something down this node.
+  // Also check whether enough samples have already been made for this query.
+  if (SortPolicy::IsBetter(oldScore, bestDistance)
+      && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
+  {
+    // We cannot prune this node yet;
+    // try approximating this node by sampling.
+
+    // Here we assume that since we are re-scoring, the algorithm
+    // has already sampled some candidates and, if specified, also
+    // traversed to the first leaf, so no further checks for those
+    // cases are made here.
+    //
+    // Check whether this node can be approximated by sampling.
+    size_t samplesReqd =
+        (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
+    samplesReqd
+        = std::min(samplesReqd,
+                   numSamplesReqd - queryNode.Stat().NumSamplesMade());
+
+    if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
+    {
+      // Too many samples are required and we are not at a leaf: can't prune.
+
+      // Since a descent of the query tree is necessary now,
+      // propagate the number of samples made down to the children.
+
+      // Go through all the children and propagate the number of
+      // samples made down to each child.
+      // Only update if the parent node has made samples the children
+      // have not yet seen.
+      for (size_t i = 0; i < queryNode.NumChildren(); i++)
+        queryNode.Child(i).Stat().NumSamplesMade()
+            = std::max(queryNode.Stat().NumSamplesMade(),
+                       queryNode.Child(i).Stat().NumSamplesMade());
+
+      return oldScore;
+    }
+    else
+    {
+      if (!referenceNode.IsLeaf()) // If not a leaf...
+      {
+        // ...then samplesReqd <= singleSampleLimit.
+        // Hence, approximate the node by sampling enough points
+        // for every query in the 'queryNode'.
+        for (size_t queryIndex = queryNode.Begin();
+             queryIndex < queryNode.End(); queryIndex++)
+        {
+          arma::uvec distinctSamples;
+          ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+                                distinctSamples);
+          for (size_t i = 0; i < distinctSamples.n_elem; i++)
+            // The counting of the samples is done in the 'BaseCase'
+            // function, so no book-keeping is required here.
+            BaseCase(queryIndex,
+                referenceNode.Begin() + (size_t) distinctSamples[i]);
+        }
+
+        // Update the number of samples made for the queryNode.
+        // (The child nodes' counts are deliberately not updated; see below.)
+        queryNode.Stat().NumSamplesMade() += samplesReqd;
+
+        // Since we are not going to descend down the query tree for this
+        // referenceNode, there is no point updating the number of
+        // samples made for the child nodes of this queryNode.
+
+        // The node is approximated, so we can prune it.
+        return DBL_MAX;
+      }
+      else // We are at a leaf.
+      {
+        if (sampleAtLeaves) // If allowed to sample at leaves...
+        {
+          // ...approximate the node by sampling enough points
+          // for every query in the 'queryNode'.
+          for (size_t queryIndex = queryNode.Begin();
+               queryIndex < queryNode.End(); queryIndex++)
+          {
+            arma::uvec distinctSamples;
+            ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+                                  distinctSamples);
+            for (size_t i = 0; i < distinctSamples.n_elem; i++)
+              // The counting of the samples is done in the 'BaseCase'
+              // function, so no book-keeping is required here.
+              BaseCase(queryIndex,
+                  referenceNode.Begin() + (size_t) distinctSamples[i]);
+          }
+
+          // Update the number of samples made for the queryNode.
+          // (The child nodes' counts are deliberately not updated; see below.)
+          queryNode.Stat().NumSamplesMade() += samplesReqd;
+
+          // Since we are not going to descend down the query tree for this
+          // referenceNode, there is no point updating the number of
+          // samples made for the child nodes of this queryNode.
+
+          // The (leaf) node is approximated, so we can prune it.
+          return DBL_MAX;
+        }
+        else
+        {
+          // Not allowed to sample from leaves, so we cannot prune;
+          // propagate the number of samples made down to the children.
+
+          // Go through all the children and propagate the number of
+          // samples made down to each child.
+          for (size_t i = 0; i < queryNode.NumChildren(); i++)
+            queryNode.Child(i).Stat().NumSamplesMade()
+                = std::max(queryNode.Stat().NumSamplesMade(),
+                           queryNode.Child(i).Stat().NumSamplesMade());
+
+          return oldScore;
+        } // Sampling at leaves allowed or not.
+      } // Leaf or non-leaf reference node.
+    } // Whether the node can be approximated by sampling.
+  }
+  else
+  {
+    // Either there cannot be anything better in this node,
+    // or enough samples have already been made;
+    // so prune it.
+
+    // Add 'fake' samples from this node; fake because the distances to
+    // these samples need not actually be computed.
+
+    // If enough samples were already made, this step does not change the
+    // result of the search, since this queryNode will never be
+    // descended again.
+    queryNode.Stat().NumSamplesMade() +=
+        (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
+
+    // Since we are not going to descend down the query tree for this
+    // reference node, there is no point updating the number of samples
+    // made for the child nodes of this queryNode.
+
+    return DBL_MAX;
+  } // Whether we could prune.
+} // Rescore(node, node, oldScore)
+
+
+/**
+ * Helper function to insert a point into the sorted neighbors and distances
+ * matrices for a query point; worse entries shift down and the worst drops.
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param pos Position in list to insert into.
+ * @param neighbor Index of reference point which is being inserted.
+ * @param distance Distance from query point to reference point.
+ */
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void RASearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+    const size_t queryIndex,
+    const size_t pos,
+    const size_t neighbor,
+    const double distance)
+{
+  // We only memmove() when inserting before the last slot of the list.
+  if (pos < (distances.n_rows - 1))
+  {
+    int len = (distances.n_rows - 1) - pos;
+    memmove(distances.colptr(queryIndex) + (pos + 1),
+        distances.colptr(queryIndex) + pos,
+        sizeof(double) * len);
+    memmove(neighbors.colptr(queryIndex) + (pos + 1),
+        neighbors.colptr(queryIndex) + pos,
+        sizeof(size_t) * len);
+  }
+
+  // Store the new neighbor index and its distance in the vacated slot.
+  distances(pos, queryIndex) = distance;
+  neighbors(pos, queryIndex) = neighbor;
+}
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_typedef.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_typedef.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_typedef.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,68 +0,0 @@
-/**
- * @file ra_typedef.hpp
- * @author Parikshit Ram
- *
- * Simple typedefs describing template instantiations of the RASearch
- * class which are commonly used.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_NEIGHBOR_SEARCH_RA_TYPEDEF_H
-#define __MLPACK_NEIGHBOR_SEARCH_RA_TYPEDEF_H
-
-// In case someone included this directly.
-#include "ra_search.hpp"
-
-#include <mlpack/core/metrics/lmetric.hpp>
-
-#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
-#include <mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp>
-
-namespace mlpack {
-namespace neighbor {
-
-/**
- * The AllkRANN class is the all-k-rank-approximate-nearest-neighbors method.
- * It returns squared L2 distances (squared Euclidean distances) for each
- * of the k rank-approximate nearest-neighbors. Squared distances are used
- * because they are slightly faster than non-squared distances
- * (they have one fewer call to sqrt()).
- *
- * The approximation is controlled with two parameters (see allkrann_main.cpp)
- * which can be specified at search time. So the tree building is done
- * only once while the search can be performed multiple times with
- * different approximation levels.
- */
-typedef RASearch<> AllkRANN;
-
-/**
- * The AllkRAFN class is the all-k-rank-approximate-farthest-neighbors method.
- * It returns squared L2 distances (squared Euclidean distances) for each
- * of the k rank-approximate farthest-neighbors. Squared distances are used
- * because they are slightly faster than non-squared distances
- * (they have one fewer call to sqrt()).
- *
- * The approximation is controlled with two parameters (see allkrann_main.cpp)
- * which can be specified at search time. So the tree building is done
- * only once while the search can be performed multiple times with
- * different approximation levels.
- */
-typedef RASearch<FurthestNeighborSort> AllkRAFN;
-
-}; // namespace neighbor
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_typedef.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/rann/ra_typedef.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_typedef.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/rann/ra_typedef.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,68 @@
+/**
+ * @file ra_typedef.hpp
+ * @author Parikshit Ram
+ *
+ * Simple typedefs describing template instantiations of the RASearch
+ * class which are commonly used.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_NEIGHBOR_SEARCH_RA_TYPEDEF_H
+#define __MLPACK_NEIGHBOR_SEARCH_RA_TYPEDEF_H
+
+// In case someone included this directly.
+#include "ra_search.hpp"
+
+#include <mlpack/core/metrics/lmetric.hpp>
+
+#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
+#include <mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * The AllkRANN class is the all-k-rank-approximate-nearest-neighbors method.
+ * It returns squared L2 (squared Euclidean) distances for each of the k
+ * rank-approximate nearest neighbors.  Squared distances are used because
+ * they are slightly faster to compute than non-squared distances (they
+ * save one call to sqrt()).
+ *
+ * The approximation is controlled with two parameters (see allkrann_main.cpp)
+ * which can be specified at search time.  Thus the tree building is done
+ * only once, while the search can be performed multiple times with
+ * different approximation levels.
+ */
+typedef RASearch<> AllkRANN;
+
+/**
+ * The AllkRAFN class is the all-k-rank-approximate-furthest-neighbors method
+ * (using the FurthestNeighborSort policy).  It returns squared L2 (squared
+ * Euclidean) distances for each of the k rank-approximate furthest
+ * neighbors; squared distances are used because they are slightly faster
+ * to compute (they save one call to sqrt()).
+ *
+ * The approximation is controlled with two parameters (see allkrann_main.cpp)
+ * which can be specified at search time.  Thus the tree building is done
+ * only once, while the search can be performed multiple times with
+ * different approximation levels.
+ */
+typedef RASearch<FurthestNeighborSort> AllkRAFN;
+
+}; // namespace neighbor
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,71 +0,0 @@
-/**
- * @file data_dependent_random_initializer.hpp
- * @author Nishant Mehta
- *
- * A sensible heuristic for initializing dictionaries for sparse coding.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_SPARSE_CODING_DATA_DEPENDENT_RANDOM_INITIALIZER_HPP
-#define __MLPACK_METHODS_SPARSE_CODING_DATA_DEPENDENT_RANDOM_INITIALIZER_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace sparse_coding {
-
-/**
- * A data-dependent random dictionary initializer for SparseCoding. This
- * creates random dictionary atoms by adding three random observations from the
- * data together, and then normalizing the atom.
- */
-class DataDependentRandomInitializer
-{
- public:
- /**
- * Initialize the dictionary by adding together three random observations from
- * the data, and then normalizing the atom. This implementation is simple
- * enough to be included with the definition.
- *
- * @param data Dataset to initialize the dictionary with.
- * @param atoms Number of atoms in dictionary.
- * @param dictionary Dictionary to initialize.
- */
- static void Initialize(const arma::mat& data,
- const size_t atoms,
- arma::mat& dictionary)
- {
- // Set the size of the dictionary.
- dictionary.set_size(data.n_rows, atoms);
-
- // Create each atom.
- for (size_t i = 0; i < atoms; ++i)
- {
- // Add three atoms together.
- dictionary.col(i) = (data.col(math::RandInt(data.n_cols)) +
- data.col(math::RandInt(data.n_cols)) +
- data.col(math::RandInt(data.n_cols)));
-
- // Now normalize the atom.
- dictionary.col(i) /= norm(dictionary.col(i), 2);
- }
- }
-};
-
-}; // namespace sparse_coding
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,71 @@
+/**
+ * @file data_dependent_random_initializer.hpp
+ * @author Nishant Mehta
+ *
+ * A sensible heuristic for initializing dictionaries for sparse coding.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_SPARSE_CODING_DATA_DEPENDENT_RANDOM_INITIALIZER_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_DATA_DEPENDENT_RANDOM_INITIALIZER_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace sparse_coding {
+
+/**
+ * A data-dependent random dictionary initializer for SparseCoding.  This
+ * creates random dictionary atoms by adding three random observations from
+ * the data together, and then normalizing the atom.
+ */
+class DataDependentRandomInitializer
+{
+ public:
+  /**
+   * Initialize the dictionary by adding together three random observations
+   * from the data, and then normalizing the atom.  This implementation is
+   * simple enough to be included with the definition.
+   *
+   * @param data Dataset to initialize the dictionary with.
+   * @param atoms Number of atoms in dictionary.
+   * @param dictionary Dictionary to initialize.
+   */
+  static void Initialize(const arma::mat& data,
+                         const size_t atoms,
+                         arma::mat& dictionary)
+  {
+    // Set the size of the dictionary: one column per atom.
+    dictionary.set_size(data.n_rows, atoms);
+
+    // Create each atom.
+    for (size_t i = 0; i < atoms; ++i)
+    {
+      // Sum three randomly chosen observations (columns) of the data.
+      dictionary.col(i) = (data.col(math::RandInt(data.n_cols)) +
+          data.col(math::RandInt(data.n_cols)) +
+          data.col(math::RandInt(data.n_cols)));
+
+      // Normalize the atom to unit L2 norm.
+      dictionary.col(i) /= norm(dictionary.col(i), 2);
+    }
+  }
+};
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/nothing_initializer.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/nothing_initializer.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/nothing_initializer.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,56 +0,0 @@
-/**
- * @file nothing_initializer.hpp
- * @author Ryan Curtin
- *
- * An initializer for SparseCoding which does precisely nothing. It is useful
- * for when you have an already defined dictionary and you plan on setting it
- * with SparseCoding::Dictionary().
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_SPARSE_CODING_NOTHING_INITIALIZER_HPP
-#define __MLPACK_METHODS_SPARSE_CODING_NOTHING_INITIALIZER_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace sparse_coding {
-
-/**
- * A DictionaryInitializer for SparseCoding which does not initialize anything;
- * it is useful for when the dictionary is already known and will be set with
- * SparseCoding::Dictionary().
- */
-class NothingInitializer
-{
- public:
- /**
- * This function does not initialize the dictionary. This will cause problems
- * for SparseCoding if the dictionary is not set manually before running the
- * method.
- */
- static void Initialize(const arma::mat& /* data */,
- const size_t /* atoms */,
- arma::mat& /* dictionary */)
- {
- // Do nothing!
- }
-};
-
-}; // namespace sparse_coding
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/nothing_initializer.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/nothing_initializer.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/nothing_initializer.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/nothing_initializer.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,56 @@
+/**
+ * @file nothing_initializer.hpp
+ * @author Ryan Curtin
+ *
+ * An initializer for SparseCoding which does precisely nothing. It is useful
+ * for when you have an already defined dictionary and you plan on setting it
+ * with SparseCoding::Dictionary().
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_SPARSE_CODING_NOTHING_INITIALIZER_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_NOTHING_INITIALIZER_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace sparse_coding {
+
+/**
+ * A DictionaryInitializer for SparseCoding which does not initialize
+ * anything; it is useful when the dictionary is already known and will be
+ * set manually with SparseCoding::Dictionary().
+ */
+class NothingInitializer
+{
+ public:
+  /**
+   * This function intentionally does not initialize the dictionary.  This
+   * will cause problems for SparseCoding if the dictionary is not set
+   * manually before running the method.
+   */
+  static void Initialize(const arma::mat& /* data */,
+                         const size_t /* atoms */,
+                         arma::mat& /* dictionary */)
+  {
+    // Deliberately a no-op.
+  }
+};
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/random_initializer.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/random_initializer.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/random_initializer.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,63 +0,0 @@
-/**
- * @file random_initializer.hpp
- * @author Nishant Mehta
- *
- * A very simple random dictionary initializer for SparseCoding; it is probably
- * not a very good choice.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_SPARSE_CODING_RANDOM_INITIALIZER_HPP
-#define __MLPACK_METHODS_SPARSE_CODING_RANDOM_INITIALIZER_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace sparse_coding {
-
-/**
- * A DictionaryInitializer for use with the SparseCoding class. This provides a
- * random, normally distributed dictionary, such that each atom has a norm of 1.
- */
-class RandomInitializer
-{
- public:
- /**
- * Initialize the dictionary randomly from a normal distribution, such that
- * each atom has a norm of 1. This is simple enough to be included with the
- * definition.
- *
- * @param data Dataset to use for initialization.
- * @param atoms Number of atoms (columns) in the dictionary.
- * @param dictionary Dictionary to initialize.
- */
- static void Initialize(const arma::mat& data,
- const size_t atoms,
- arma::mat& dictionary)
- {
- // Create random dictionary.
- dictionary.randn(data.n_rows, atoms);
-
- // Normalize each atom.
- for (size_t j = 0; j < atoms; ++j)
- dictionary.col(j) /= norm(dictionary.col(j), 2);
- }
-};
-
-}; // namespace sparse_coding
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/random_initializer.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/random_initializer.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/random_initializer.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/random_initializer.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,63 @@
+/**
+ * @file random_initializer.hpp
+ * @author Nishant Mehta
+ *
+ * A very simple random dictionary initializer for SparseCoding; it is probably
+ * not a very good choice.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_SPARSE_CODING_RANDOM_INITIALIZER_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_RANDOM_INITIALIZER_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace sparse_coding {
+
+/**
+ * A DictionaryInitializer for use with the SparseCoding class.  This provides
+ * a random, normally distributed dictionary in which each atom has norm 1.
+ */
+class RandomInitializer
+{
+ public:
+  /**
+   * Initialize the dictionary randomly from a normal distribution, such that
+   * each atom has a norm of 1.  This is simple enough to be included with
+   * the definition.
+   *
+   * @param data Dataset whose dimensionality (row count) sizes the atoms.
+   * @param atoms Number of atoms (columns) in the dictionary.
+   * @param dictionary Dictionary to initialize.
+   */
+  static void Initialize(const arma::mat& data,
+                         const size_t atoms,
+                         arma::mat& dictionary)
+  {
+    // Fill the dictionary with draws from a standard normal distribution.
+    dictionary.randn(data.n_rows, atoms);
+
+    // Normalize each atom to unit L2 norm.
+    for (size_t j = 0; j < atoms; ++j)
+      dictionary.col(j) /= norm(dictionary.col(j), 2);
+  }
+};
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/sparse_coding.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,217 +0,0 @@
-/**
- * @file sparse_coding.hpp
- * @author Nishant Mehta
- *
- * Definition of the SparseCoding class, which performs L1 (LASSO) or
- * L1+L2 (Elastic Net)-regularized sparse coding with dictionary learning
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_HPP
-#define __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/methods/lars/lars.hpp>
-
-// Include our three simple dictionary initializers.
-#include "nothing_initializer.hpp"
-#include "data_dependent_random_initializer.hpp"
-#include "random_initializer.hpp"
-
-namespace mlpack {
-namespace sparse_coding {
-
-/**
- * An implementation of Sparse Coding with Dictionary Learning that achieves
- * sparsity via an l1-norm regularizer on the codes (LASSO) or an (l1+l2)-norm
- * regularizer on the codes (the Elastic Net).
- *
- * Let d be the number of dimensions in the original space, m the number of
- * training points, and k the number of atoms in the dictionary (the dimension
- * of the learned feature space). The training data X is a d-by-m matrix where
- * each column is a point and each row is a dimension. The dictionary D is a
- * d-by-k matrix, and the sparse codes matrix Z is a k-by-m matrix.
- * This program seeks to minimize the objective:
- *
- * \f[
- * \min_{D,Z} 0.5 ||X - D Z||_{F}^2\ + \lambda_1 \sum_{i=1}^m ||Z_i||_1
- * + 0.5 \lambda_2 \sum_{i=1}^m ||Z_i||_2^2
- * \f]
- *
- * subject to \f$ ||D_j||_2 <= 1 \f$ for \f$ 1 <= j <= k \f$
- * where typically \f$ lambda_1 > 0 \f$ and \f$ lambda_2 = 0 \f$.
- *
- * This problem is solved by an algorithm that alternates between a dictionary
- * learning step and a sparse coding step. The dictionary learning step updates
- * the dictionary D using a Newton method based on the Lagrange dual (see the
- * paper below for details). The sparse coding step involves solving a large
- * number of sparse linear regression problems; this can be done efficiently
- * using LARS, an algorithm that can solve the LASSO or the Elastic Net (papers
- * below).
- *
- * Here are those papers:
- *
- * @code
- * @incollection{lee2007efficient,
- * title = {Efficient sparse coding algorithms},
- * author = {Honglak Lee and Alexis Battle and Rajat Raina and Andrew Y. Ng},
- * booktitle = {Advances in Neural Information Processing Systems 19},
- * editor = {B. Sch\"{o}lkopf and J. Platt and T. Hoffman},
- * publisher = {MIT Press},
- * address = {Cambridge, MA},
- * pages = {801--808},
- * year = {2007}
- * }
- * @endcode
- *
- * @code
- * @article{efron2004least,
- * title={Least angle regression},
- * author={Efron, B. and Hastie, T. and Johnstone, I. and Tibshirani, R.},
- * journal={The Annals of statistics},
- * volume={32},
- * number={2},
- * pages={407--499},
- * year={2004},
- * publisher={Institute of Mathematical Statistics}
- * }
- * @endcode
- *
- * @code
- * @article{zou2005regularization,
- * title={Regularization and variable selection via the elastic net},
- * author={Zou, H. and Hastie, T.},
- * journal={Journal of the Royal Statistical Society Series B},
- * volume={67},
- * number={2},
- * pages={301--320},
- * year={2005},
- * publisher={Royal Statistical Society}
- * }
- * @endcode
- *
- * Before the method is run, the dictionary is initialized using the
- * DictionaryInitializationPolicy class. Possible choices include the
- * RandomInitializer, which provides an entirely random dictionary, the
- * DataDependentRandomInitializer, which provides a random dictionary based
- * loosely on characteristics of the dataset, and the NothingInitializer, which
- * does not initialize the dictionary -- instead, the user should set the
- * dictionary using the Dictionary() mutator method.
- *
- * @tparam DictionaryInitializationPolicy The class to use to initialize the
- * dictionary; must have 'void Initialize(const arma::mat& data, arma::mat&
- * dictionary)' function.
- */
-template<typename DictionaryInitializer = DataDependentRandomInitializer>
-class SparseCoding
-{
- public:
- /**
- * Set the parameters to SparseCoding. lambda2 defaults to 0.
- *
- * @param data Data matrix
- * @param atoms Number of atoms in dictionary
- * @param lambda1 Regularization parameter for l1-norm penalty
- * @param lambda2 Regularization parameter for l2-norm penalty
- */
- SparseCoding(const arma::mat& data,
- const size_t atoms,
- const double lambda1,
- const double lambda2 = 0);
-
- /**
- * Run Sparse Coding with Dictionary Learning.
- *
- * @param maxIterations Maximum number of iterations to run algorithm. If 0,
- * the algorithm will run until convergence (or forever).
- * @param objTolerance Tolerance for objective function. When an iteration of
- * the algorithm produces an improvement smaller than this, the algorithm
- * will terminate.
- * @param newtonTolerance Tolerance for the Newton's method dictionary
- * optimization step.
- */
- void Encode(const size_t maxIterations = 0,
- const double objTolerance = 0.01,
- const double newtonTolerance = 1e-6);
-
- /**
- * Sparse code each point via LARS.
- */
- void OptimizeCode();
-
- /**
- * Learn dictionary via Newton method based on Lagrange dual.
- *
- * @param adjacencies Indices of entries (unrolled column by column) of
- * the coding matrix Z that are non-zero (the adjacency matrix for the
- * bipartite graph of points and atoms).
- * @param newtonTolerance Tolerance of the Newton's method optimizer.
- * @return the norm of the gradient of the Lagrange dual with respect to
- * the dual variables
- */
- double OptimizeDictionary(const arma::uvec& adjacencies,
- const double newtonTolerance = 1e-6);
-
- /**
- * Project each atom of the dictionary back onto the unit ball, if necessary.
- */
- void ProjectDictionary();
-
- /**
- * Compute the objective function.
- */
- double Objective() const;
-
- //! Access the data.
- const arma::mat& Data() const { return data; }
-
- //! Access the dictionary.
- const arma::mat& Dictionary() const { return dictionary; }
- //! Modify the dictionary.
- arma::mat& Dictionary() { return dictionary; }
-
- //! Access the sparse codes.
- const arma::mat& Codes() const { return codes; }
- //! Modify the sparse codes.
- arma::mat& Codes() { return codes; }
-
- private:
- //! Number of atoms.
- size_t atoms;
-
- //! Data matrix (columns are points).
- const arma::mat& data;
-
- //! Dictionary (columns are atoms).
- arma::mat dictionary;
-
- //! Sparse codes (columns are points).
- arma::mat codes;
-
- //! l1 regularization term.
- double lambda1;
-
- //! l2 regularization term.
- double lambda2;
-};
-
-}; // namespace sparse_coding
-}; // namespace mlpack
-
-// Include implementation.
-#include "sparse_coding_impl.hpp"
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/sparse_coding.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,217 @@
+/**
+ * @file sparse_coding.hpp
+ * @author Nishant Mehta
+ *
+ * Definition of the SparseCoding class, which performs L1 (LASSO) or
+ * L1+L2 (Elastic Net)-regularized sparse coding with dictionary learning
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/lars/lars.hpp>
+
+// Include our three simple dictionary initializers.
+#include "nothing_initializer.hpp"
+#include "data_dependent_random_initializer.hpp"
+#include "random_initializer.hpp"
+
+namespace mlpack {
+namespace sparse_coding {
+
+/**
+ * An implementation of Sparse Coding with Dictionary Learning that achieves
+ * sparsity via an l1-norm regularizer on the codes (LASSO) or an (l1+l2)-norm
+ * regularizer on the codes (the Elastic Net).
+ *
+ * Let d be the number of dimensions in the original space, m the number of
+ * training points, and k the number of atoms in the dictionary (the dimension
+ * of the learned feature space). The training data X is a d-by-m matrix where
+ * each column is a point and each row is a dimension. The dictionary D is a
+ * d-by-k matrix, and the sparse codes matrix Z is a k-by-m matrix.
+ * This program seeks to minimize the objective:
+ *
+ * \f[
+ * \min_{D,Z} 0.5 ||X - D Z||_{F}^2\ + \lambda_1 \sum_{i=1}^m ||Z_i||_1
+ * + 0.5 \lambda_2 \sum_{i=1}^m ||Z_i||_2^2
+ * \f]
+ *
+ * subject to \f$ ||D_j||_2 <= 1 \f$ for \f$ 1 <= j <= k \f$
+ * where typically \f$ lambda_1 > 0 \f$ and \f$ lambda_2 = 0 \f$.
+ *
+ * This problem is solved by an algorithm that alternates between a dictionary
+ * learning step and a sparse coding step. The dictionary learning step updates
+ * the dictionary D using a Newton method based on the Lagrange dual (see the
+ * paper below for details). The sparse coding step involves solving a large
+ * number of sparse linear regression problems; this can be done efficiently
+ * using LARS, an algorithm that can solve the LASSO or the Elastic Net (papers
+ * below).
+ *
+ * Here are those papers:
+ *
+ * @code
+ * @incollection{lee2007efficient,
+ * title = {Efficient sparse coding algorithms},
+ * author = {Honglak Lee and Alexis Battle and Rajat Raina and Andrew Y. Ng},
+ * booktitle = {Advances in Neural Information Processing Systems 19},
+ * editor = {B. Sch\"{o}lkopf and J. Platt and T. Hoffman},
+ * publisher = {MIT Press},
+ * address = {Cambridge, MA},
+ * pages = {801--808},
+ * year = {2007}
+ * }
+ * @endcode
+ *
+ * @code
+ * @article{efron2004least,
+ * title={Least angle regression},
+ * author={Efron, B. and Hastie, T. and Johnstone, I. and Tibshirani, R.},
+ * journal={The Annals of statistics},
+ * volume={32},
+ * number={2},
+ * pages={407--499},
+ * year={2004},
+ * publisher={Institute of Mathematical Statistics}
+ * }
+ * @endcode
+ *
+ * @code
+ * @article{zou2005regularization,
+ * title={Regularization and variable selection via the elastic net},
+ * author={Zou, H. and Hastie, T.},
+ * journal={Journal of the Royal Statistical Society Series B},
+ * volume={67},
+ * number={2},
+ * pages={301--320},
+ * year={2005},
+ * publisher={Royal Statistical Society}
+ * }
+ * @endcode
+ *
+ * Before the method is run, the dictionary is initialized using the
+ * DictionaryInitializationPolicy class. Possible choices include the
+ * RandomInitializer, which provides an entirely random dictionary, the
+ * DataDependentRandomInitializer, which provides a random dictionary based
+ * loosely on characteristics of the dataset, and the NothingInitializer, which
+ * does not initialize the dictionary -- instead, the user should set the
+ * dictionary using the Dictionary() mutator method.
+ *
+ * @tparam DictionaryInitializationPolicy The class to use to initialize the
+ * dictionary; must have 'void Initialize(const arma::mat& data, arma::mat&
+ * dictionary)' function.
+ */
+template<typename DictionaryInitializer = DataDependentRandomInitializer>
+class SparseCoding
+{
+ public:
+ /**
+ * Set the parameters to SparseCoding. lambda2 defaults to 0.
+ *
+ * @param data Data matrix
+ * @param atoms Number of atoms in dictionary
+ * @param lambda1 Regularization parameter for l1-norm penalty
+ * @param lambda2 Regularization parameter for l2-norm penalty
+ */
+ SparseCoding(const arma::mat& data,
+ const size_t atoms,
+ const double lambda1,
+ const double lambda2 = 0);
+
+ /**
+ * Run Sparse Coding with Dictionary Learning.
+ *
+ * @param maxIterations Maximum number of iterations to run algorithm. If 0,
+ * the algorithm will run until convergence (or forever).
+ * @param objTolerance Tolerance for objective function. When an iteration of
+ * the algorithm produces an improvement smaller than this, the algorithm
+ * will terminate.
+ * @param newtonTolerance Tolerance for the Newton's method dictionary
+ * optimization step.
+ */
+ void Encode(const size_t maxIterations = 0,
+ const double objTolerance = 0.01,
+ const double newtonTolerance = 1e-6);
+
+ /**
+ * Sparse code each point via LARS.
+ */
+ void OptimizeCode();
+
+ /**
+ * Learn dictionary via Newton method based on Lagrange dual.
+ *
+ * @param adjacencies Indices of entries (unrolled column by column) of
+ * the coding matrix Z that are non-zero (the adjacency matrix for the
+ * bipartite graph of points and atoms).
+ * @param newtonTolerance Tolerance of the Newton's method optimizer.
+ * @return the norm of the gradient of the Lagrange dual with respect to
+ * the dual variables
+ */
+ double OptimizeDictionary(const arma::uvec& adjacencies,
+ const double newtonTolerance = 1e-6);
+
+ /**
+ * Project each atom of the dictionary back onto the unit ball, if necessary.
+ */
+ void ProjectDictionary();
+
+ /**
+ * Compute the objective function.
+ */
+ double Objective() const;
+
+ //! Access the data.
+ const arma::mat& Data() const { return data; }
+
+ //! Access the dictionary.
+ const arma::mat& Dictionary() const { return dictionary; }
+ //! Modify the dictionary.
+ arma::mat& Dictionary() { return dictionary; }
+
+ //! Access the sparse codes.
+ const arma::mat& Codes() const { return codes; }
+ //! Modify the sparse codes.
+ arma::mat& Codes() { return codes; }
+
+ private:
+ //! Number of atoms.
+ size_t atoms;
+
+ //! Data matrix (columns are points).
+ const arma::mat& data;
+
+ //! Dictionary (columns are atoms).
+ arma::mat dictionary;
+
+ //! Sparse codes (columns are points).
+ arma::mat codes;
+
+ //! l1 regularization term.
+ double lambda1;
+
+ //! l2 regularization term.
+ double lambda2;
+};
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+// Include implementation.
+#include "sparse_coding_impl.hpp"
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,348 +0,0 @@
-/**
- * @file sparse_coding_impl.hpp
- * @author Nishant Mehta
- *
- * Implementation of Sparse Coding with Dictionary Learning using l1 (LASSO) or
- * l1+l2 (Elastic Net) regularization.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_IMPL_HPP
-#define __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_IMPL_HPP
-
-// In case it hasn't already been included.
-#include "sparse_coding.hpp"
-
-namespace mlpack {
-namespace sparse_coding {
-
-template<typename DictionaryInitializer>
-SparseCoding<DictionaryInitializer>::SparseCoding(const arma::mat& data,
- const size_t atoms,
- const double lambda1,
- const double lambda2) :
- atoms(atoms),
- data(data),
- codes(atoms, data.n_cols),
- lambda1(lambda1),
- lambda2(lambda2)
-{
- // Initialize the dictionary.
- DictionaryInitializer::Initialize(data, atoms, dictionary);
-}
-
-template<typename DictionaryInitializer>
-void SparseCoding<DictionaryInitializer>::Encode(const size_t maxIterations,
- const double objTolerance,
- const double newtonTolerance)
-{
- Timer::Start("sparse_coding");
-
- double lastObjVal = DBL_MAX;
-
- // Take the initial coding step, which has to happen before entering the main
- // optimization loop.
- Log::Info << "Initial Coding Step." << std::endl;
-
- OptimizeCode();
- arma::uvec adjacencies = find(codes);
-
- Log::Info << " Sparsity level: " << 100.0 * ((double) (adjacencies.n_elem))
- / ((double) (atoms * data.n_cols)) << "%." << std::endl;
- Log::Info << " Objective value: " << Objective() << "." << std::endl;
-
- for (size_t t = 1; t != maxIterations; ++t)
- {
- Log::Info << "Iteration " << t << " of " << maxIterations << "."
- << std::endl;
-
- // First step: optimize the dictionary.
- Log::Info << "Performing dictionary step... " << std::endl;
- OptimizeDictionary(adjacencies, newtonTolerance);
- Log::Info << " Objective value: " << Objective() << "." << std::endl;
-
- // Second step: perform the coding.
- Log::Info << "Performing coding step..." << std::endl;
- OptimizeCode();
- // Get the indices of all the nonzero elements in the codes.
- adjacencies = find(codes);
- Log::Info << " Sparsity level: " << 100.0 * ((double) (adjacencies.n_elem))
- / ((double) (atoms * data.n_cols)) << "%." << std::endl;
-
- // Find the new objective value and improvement so we can check for
- // convergence.
- double curObjVal = Objective();
- double improvement = lastObjVal - curObjVal;
- Log::Info << " Objective value: " << curObjVal << " (improvement "
- << std::scientific << improvement << ")." << std::endl;
-
- // Have we converged?
- if (improvement < objTolerance)
- {
- Log::Info << "Converged within tolerance " << objTolerance << ".\n";
- break;
- }
-
- lastObjVal = curObjVal;
- }
-
- Timer::Stop("sparse_coding");
-}
-
-template<typename DictionaryInitializer>
-void SparseCoding<DictionaryInitializer>::OptimizeCode()
-{
- // When using the Cholesky version of LARS, this is correct even if
- // lambda2 > 0.
- arma::mat matGram = trans(dictionary) * dictionary;
-
- for (size_t i = 0; i < data.n_cols; ++i)
- {
- // Report progress.
- if ((i % 100) == 0)
- Log::Debug << "Optimization at point " << i << "." << std::endl;
-
- bool useCholesky = true;
- regression::LARS lars(useCholesky, matGram, lambda1, lambda2);
-
- // Create an alias of the code (using the same memory), and then LARS will
- // place the result directly into that; then we will not need to have an
- // extra copy.
- arma::vec code = codes.unsafe_col(i);
- lars.Regress(dictionary, data.unsafe_col(i), code, false);
- }
-}
-
-// Dictionary step for optimization.
-template<typename DictionaryInitializer>
-double SparseCoding<DictionaryInitializer>::OptimizeDictionary(
- const arma::uvec& adjacencies,
- const double newtonTolerance)
-{
- // Count the number of atomic neighbors for each point x^i.
- arma::uvec neighborCounts = arma::zeros<arma::uvec>(data.n_cols, 1);
-
- if (adjacencies.n_elem > 0)
- {
- // This gets the column index. Intentional integer division.
- size_t curPointInd = (size_t) (adjacencies(0) / atoms);
-
- size_t nextColIndex = (curPointInd + 1) * atoms;
- for (size_t l = 1; l < adjacencies.n_elem; ++l)
- {
- // If l no longer refers to an element in this column, advance the column
- // number accordingly.
- if (adjacencies(l) >= nextColIndex)
- {
- curPointInd = (size_t) (adjacencies(l) / atoms);
- nextColIndex = (curPointInd + 1) * atoms;
- }
-
- ++neighborCounts(curPointInd);
- }
- }
-
- // Handle the case of inactive atoms (atoms not used in the given coding).
- std::vector<size_t> inactiveAtoms;
-
- for (size_t j = 0; j < atoms; ++j)
- {
- if (accu(codes.row(j) != 0) == 0)
- inactiveAtoms.push_back(j);
- }
-
- const size_t nInactiveAtoms = inactiveAtoms.size();
- const size_t nActiveAtoms = atoms - nInactiveAtoms;
-
- // Efficient construction of Z restricted to active atoms.
- arma::mat matActiveZ;
- if (nInactiveAtoms > 0)
- {
- math::RemoveRows(codes, inactiveAtoms, matActiveZ);
- }
-
- if (nInactiveAtoms > 0)
- {
- Log::Warn << "There are " << nInactiveAtoms
- << " inactive atoms. They will be re-initialized randomly.\n";
- }
-
- Log::Debug << "Solving Dual via Newton's Method.\n";
-
- // Solve using Newton's method in the dual - note that the final dot
- // multiplication with inv(A) seems to be unavoidable. Although more
- // expensive, the code written this way (we use solve()) should be more
- // numerically stable than just using inv(A) for everything.
- arma::vec dualVars = arma::zeros<arma::vec>(nActiveAtoms);
-
- //vec dualVars = 1e-14 * ones<vec>(nActiveAtoms);
-
- // Method used by feature sign code - fails miserably here. Perhaps the
- // MATLAB optimizer fmincon does something clever?
- //vec dualVars = 10.0 * randu(nActiveAtoms, 1);
-
- //vec dualVars = diagvec(solve(dictionary, data * trans(codes))
- // - codes * trans(codes));
- //for (size_t i = 0; i < dualVars.n_elem; i++)
- // if (dualVars(i) < 0)
- // dualVars(i) = 0;
-
- bool converged = false;
-
- // If we have any inactive atoms, we must construct these differently.
- arma::mat codesXT;
- arma::mat codesZT;
-
- if (inactiveAtoms.empty())
- {
- codesXT = codes * trans(data);
- codesZT = codes * trans(codes);
- }
- else
- {
- codesXT = matActiveZ * trans(data);
- codesZT = matActiveZ * trans(matActiveZ);
- }
-
- double normGradient;
- double improvement;
- for (size_t t = 1; !converged; ++t)
- {
- arma::mat A = codesZT + diagmat(dualVars);
-
- arma::mat matAInvZXT = solve(A, codesXT);
-
- arma::vec gradient = -arma::sum(arma::square(matAInvZXT), 1);
- gradient += 1;
-
- arma::mat hessian = -(-2 * (matAInvZXT * trans(matAInvZXT)) % inv(A));
-
- arma::vec searchDirection = -solve(hessian, gradient);
- //printf("%e\n", norm(searchDirection, 2));
-
- // Armijo line search.
- const double c = 1e-4;
- double alpha = 1.0;
- const double rho = 0.9;
- double sufficientDecrease = c * dot(gradient, searchDirection);
-
- while (true)
- {
- // Calculate objective.
- double sumDualVars = sum(dualVars);
- double fOld = -(-trace(trans(codesXT) * matAInvZXT) - sumDualVars);
- double fNew = -(-trace(trans(codesXT) * solve(codesZT +
- diagmat(dualVars + alpha * searchDirection), codesXT)) -
- (sumDualVars + alpha * sum(searchDirection)));
-
- if (fNew <= fOld + alpha * sufficientDecrease)
- {
- searchDirection = alpha * searchDirection;
- improvement = fOld - fNew;
- break;
- }
-
- alpha *= rho;
- }
-
- // Take step and print useful information.
- dualVars += searchDirection;
- normGradient = norm(gradient, 2);
- Log::Debug << "Newton Method iteration " << t << ":" << std::endl;
- Log::Debug << " Gradient norm: " << std::scientific << normGradient
- << "." << std::endl;
- Log::Debug << " Improvement: " << std::scientific << improvement << ".\n";
-
- if (improvement < newtonTolerance)
- converged = true;
- }
-
- if (inactiveAtoms.empty())
- {
- // Directly update dictionary.
- dictionary = trans(solve(codesZT + diagmat(dualVars), codesXT));
- }
- else
- {
- arma::mat activeDictionary = trans(solve(codesZT +
- diagmat(dualVars), codesXT));
-
- // Update all atoms.
- size_t currentInactiveIndex = 0;
- for (size_t i = 0; i < atoms; ++i)
- {
- if (inactiveAtoms[currentInactiveIndex] == i)
- {
- // This atom is inactive. Reinitialize it randomly.
- dictionary.col(i) = (data.col(math::RandInt(data.n_cols)) +
- data.col(math::RandInt(data.n_cols)) +
- data.col(math::RandInt(data.n_cols)));
-
- dictionary.col(i) /= norm(dictionary.col(i), 2);
-
- // Increment inactive index counter.
- ++currentInactiveIndex;
- }
- else
- {
- // Update estimate.
- dictionary.col(i) = activeDictionary.col(i - currentInactiveIndex);
- }
- }
- }
- //printf("final reconstruction error: %e\n", norm(data - dictionary * codes, "fro"));
- return normGradient;
-}
-
-// Project each atom of the dictionary back into the unit ball (if necessary).
-template<typename DictionaryInitializer>
-void SparseCoding<DictionaryInitializer>::ProjectDictionary()
-{
- for (size_t j = 0; j < atoms; j++)
- {
- double atomNorm = norm(dictionary.col(j), 2);
- if (atomNorm > 1)
- {
- Log::Info << "Norm of atom " << j << " exceeds 1 (" << std::scientific
- << atomNorm << "). Shrinking...\n";
- dictionary.col(j) /= atomNorm;
- }
- }
-}
-
-// Compute the objective function.
-template<typename DictionaryInitializer>
-double SparseCoding<DictionaryInitializer>::Objective() const
-{
- double l11NormZ = sum(sum(abs(codes)));
- double froNormResidual = norm(data - (dictionary * codes), "fro");
-
- if (lambda2 > 0)
- {
- double froNormZ = norm(codes, "fro");
- return 0.5 * (std::pow(froNormResidual, 2.0) + (lambda2 *
- std::pow(froNormZ, 2.0))) + (lambda1 * l11NormZ);
- }
- else // It can be simpler.
- {
- return 0.5 * std::pow(froNormResidual, 2.0) + lambda1 * l11NormZ;
- }
-}
-
-}; // namespace sparse_coding
-}; // namespace mlpack
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,348 @@
+/**
+ * @file sparse_coding_impl.hpp
+ * @author Nishant Mehta
+ *
+ * Implementation of Sparse Coding with Dictionary Learning using l1 (LASSO) or
+ * l1+l2 (Elastic Net) regularization.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_IMPL_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "sparse_coding.hpp"
+
+namespace mlpack {
+namespace sparse_coding {
+
+template<typename DictionaryInitializer>
+SparseCoding<DictionaryInitializer>::SparseCoding(const arma::mat& data,
+ const size_t atoms,
+ const double lambda1,
+ const double lambda2) :
+ atoms(atoms),
+ data(data),
+ codes(atoms, data.n_cols),
+ lambda1(lambda1),
+ lambda2(lambda2)
+{
+ // Initialize the dictionary.
+ DictionaryInitializer::Initialize(data, atoms, dictionary);
+}
+
+template<typename DictionaryInitializer>
+void SparseCoding<DictionaryInitializer>::Encode(const size_t maxIterations,
+ const double objTolerance,
+ const double newtonTolerance)
+{
+ Timer::Start("sparse_coding");
+
+ double lastObjVal = DBL_MAX;
+
+ // Take the initial coding step, which has to happen before entering the main
+ // optimization loop.
+ Log::Info << "Initial Coding Step." << std::endl;
+
+ OptimizeCode();
+ arma::uvec adjacencies = find(codes);
+
+ Log::Info << " Sparsity level: " << 100.0 * ((double) (adjacencies.n_elem))
+ / ((double) (atoms * data.n_cols)) << "%." << std::endl;
+ Log::Info << " Objective value: " << Objective() << "." << std::endl;
+
+ for (size_t t = 1; t != maxIterations; ++t)
+ {
+ Log::Info << "Iteration " << t << " of " << maxIterations << "."
+ << std::endl;
+
+ // First step: optimize the dictionary.
+ Log::Info << "Performing dictionary step... " << std::endl;
+ OptimizeDictionary(adjacencies, newtonTolerance);
+ Log::Info << " Objective value: " << Objective() << "." << std::endl;
+
+ // Second step: perform the coding.
+ Log::Info << "Performing coding step..." << std::endl;
+ OptimizeCode();
+ // Get the indices of all the nonzero elements in the codes.
+ adjacencies = find(codes);
+ Log::Info << " Sparsity level: " << 100.0 * ((double) (adjacencies.n_elem))
+ / ((double) (atoms * data.n_cols)) << "%." << std::endl;
+
+ // Find the new objective value and improvement so we can check for
+ // convergence.
+ double curObjVal = Objective();
+ double improvement = lastObjVal - curObjVal;
+ Log::Info << " Objective value: " << curObjVal << " (improvement "
+ << std::scientific << improvement << ")." << std::endl;
+
+ // Have we converged?
+ if (improvement < objTolerance)
+ {
+ Log::Info << "Converged within tolerance " << objTolerance << ".\n";
+ break;
+ }
+
+ lastObjVal = curObjVal;
+ }
+
+ Timer::Stop("sparse_coding");
+}
+
+template<typename DictionaryInitializer>
+void SparseCoding<DictionaryInitializer>::OptimizeCode()
+{
+ // When using the Cholesky version of LARS, this is correct even if
+ // lambda2 > 0.
+ arma::mat matGram = trans(dictionary) * dictionary;
+
+ for (size_t i = 0; i < data.n_cols; ++i)
+ {
+ // Report progress.
+ if ((i % 100) == 0)
+ Log::Debug << "Optimization at point " << i << "." << std::endl;
+
+ bool useCholesky = true;
+ regression::LARS lars(useCholesky, matGram, lambda1, lambda2);
+
+ // Create an alias of the code (using the same memory), and then LARS will
+ // place the result directly into that; then we will not need to have an
+ // extra copy.
+ arma::vec code = codes.unsafe_col(i);
+ lars.Regress(dictionary, data.unsafe_col(i), code, false);
+ }
+}
+
+// Dictionary step for optimization.
+template<typename DictionaryInitializer>
+double SparseCoding<DictionaryInitializer>::OptimizeDictionary(
+ const arma::uvec& adjacencies,
+ const double newtonTolerance)
+{
+ // Count the number of atomic neighbors for each point x^i.
+ arma::uvec neighborCounts = arma::zeros<arma::uvec>(data.n_cols, 1);
+
+ if (adjacencies.n_elem > 0)
+ {
+ // This gets the column index. Intentional integer division.
+ size_t curPointInd = (size_t) (adjacencies(0) / atoms);
+
+ size_t nextColIndex = (curPointInd + 1) * atoms;
+ for (size_t l = 1; l < adjacencies.n_elem; ++l)
+ {
+ // If l no longer refers to an element in this column, advance the column
+ // number accordingly.
+ if (adjacencies(l) >= nextColIndex)
+ {
+ curPointInd = (size_t) (adjacencies(l) / atoms);
+ nextColIndex = (curPointInd + 1) * atoms;
+ }
+
+ ++neighborCounts(curPointInd);
+ }
+ }
+
+ // Handle the case of inactive atoms (atoms not used in the given coding).
+ std::vector<size_t> inactiveAtoms;
+
+ for (size_t j = 0; j < atoms; ++j)
+ {
+ if (accu(codes.row(j) != 0) == 0)
+ inactiveAtoms.push_back(j);
+ }
+
+ const size_t nInactiveAtoms = inactiveAtoms.size();
+ const size_t nActiveAtoms = atoms - nInactiveAtoms;
+
+ // Efficient construction of Z restricted to active atoms.
+ arma::mat matActiveZ;
+ if (nInactiveAtoms > 0)
+ {
+ math::RemoveRows(codes, inactiveAtoms, matActiveZ);
+ }
+
+ if (nInactiveAtoms > 0)
+ {
+ Log::Warn << "There are " << nInactiveAtoms
+ << " inactive atoms. They will be re-initialized randomly.\n";
+ }
+
+ Log::Debug << "Solving Dual via Newton's Method.\n";
+
+ // Solve using Newton's method in the dual - note that the final dot
+ // multiplication with inv(A) seems to be unavoidable. Although more
+ // expensive, the code written this way (we use solve()) should be more
+ // numerically stable than just using inv(A) for everything.
+ arma::vec dualVars = arma::zeros<arma::vec>(nActiveAtoms);
+
+ //vec dualVars = 1e-14 * ones<vec>(nActiveAtoms);
+
+ // Method used by feature sign code - fails miserably here. Perhaps the
+ // MATLAB optimizer fmincon does something clever?
+ //vec dualVars = 10.0 * randu(nActiveAtoms, 1);
+
+ //vec dualVars = diagvec(solve(dictionary, data * trans(codes))
+ // - codes * trans(codes));
+ //for (size_t i = 0; i < dualVars.n_elem; i++)
+ // if (dualVars(i) < 0)
+ // dualVars(i) = 0;
+
+ bool converged = false;
+
+ // If we have any inactive atoms, we must construct these differently.
+ arma::mat codesXT;
+ arma::mat codesZT;
+
+ if (inactiveAtoms.empty())
+ {
+ codesXT = codes * trans(data);
+ codesZT = codes * trans(codes);
+ }
+ else
+ {
+ codesXT = matActiveZ * trans(data);
+ codesZT = matActiveZ * trans(matActiveZ);
+ }
+
+ double normGradient;
+ double improvement;
+ for (size_t t = 1; !converged; ++t)
+ {
+ arma::mat A = codesZT + diagmat(dualVars);
+
+ arma::mat matAInvZXT = solve(A, codesXT);
+
+ arma::vec gradient = -arma::sum(arma::square(matAInvZXT), 1);
+ gradient += 1;
+
+ arma::mat hessian = -(-2 * (matAInvZXT * trans(matAInvZXT)) % inv(A));
+
+ arma::vec searchDirection = -solve(hessian, gradient);
+ //printf("%e\n", norm(searchDirection, 2));
+
+ // Armijo line search.
+ const double c = 1e-4;
+ double alpha = 1.0;
+ const double rho = 0.9;
+ double sufficientDecrease = c * dot(gradient, searchDirection);
+
+ while (true)
+ {
+ // Calculate objective.
+ double sumDualVars = sum(dualVars);
+ double fOld = -(-trace(trans(codesXT) * matAInvZXT) - sumDualVars);
+ double fNew = -(-trace(trans(codesXT) * solve(codesZT +
+ diagmat(dualVars + alpha * searchDirection), codesXT)) -
+ (sumDualVars + alpha * sum(searchDirection)));
+
+ if (fNew <= fOld + alpha * sufficientDecrease)
+ {
+ searchDirection = alpha * searchDirection;
+ improvement = fOld - fNew;
+ break;
+ }
+
+ alpha *= rho;
+ }
+
+ // Take step and print useful information.
+ dualVars += searchDirection;
+ normGradient = norm(gradient, 2);
+ Log::Debug << "Newton Method iteration " << t << ":" << std::endl;
+ Log::Debug << " Gradient norm: " << std::scientific << normGradient
+ << "." << std::endl;
+ Log::Debug << " Improvement: " << std::scientific << improvement << ".\n";
+
+ if (improvement < newtonTolerance)
+ converged = true;
+ }
+
+ if (inactiveAtoms.empty())
+ {
+ // Directly update dictionary.
+ dictionary = trans(solve(codesZT + diagmat(dualVars), codesXT));
+ }
+ else
+ {
+ arma::mat activeDictionary = trans(solve(codesZT +
+ diagmat(dualVars), codesXT));
+
+ // Update all atoms.
+ size_t currentInactiveIndex = 0;
+ for (size_t i = 0; i < atoms; ++i)
+ {
+ if (inactiveAtoms[currentInactiveIndex] == i)
+ {
+ // This atom is inactive. Reinitialize it randomly.
+ dictionary.col(i) = (data.col(math::RandInt(data.n_cols)) +
+ data.col(math::RandInt(data.n_cols)) +
+ data.col(math::RandInt(data.n_cols)));
+
+ dictionary.col(i) /= norm(dictionary.col(i), 2);
+
+ // Increment inactive index counter.
+ ++currentInactiveIndex;
+ }
+ else
+ {
+ // Update estimate.
+ dictionary.col(i) = activeDictionary.col(i - currentInactiveIndex);
+ }
+ }
+ }
+ //printf("final reconstruction error: %e\n", norm(data - dictionary * codes, "fro"));
+ return normGradient;
+}
+
+// Project each atom of the dictionary back into the unit ball (if necessary).
+template<typename DictionaryInitializer>
+void SparseCoding<DictionaryInitializer>::ProjectDictionary()
+{
+ for (size_t j = 0; j < atoms; j++)
+ {
+ double atomNorm = norm(dictionary.col(j), 2);
+ if (atomNorm > 1)
+ {
+ Log::Info << "Norm of atom " << j << " exceeds 1 (" << std::scientific
+ << atomNorm << "). Shrinking...\n";
+ dictionary.col(j) /= atomNorm;
+ }
+ }
+}
+
+// Compute the objective function.
+template<typename DictionaryInitializer>
+double SparseCoding<DictionaryInitializer>::Objective() const
+{
+ double l11NormZ = sum(sum(abs(codes)));
+ double froNormResidual = norm(data - (dictionary * codes), "fro");
+
+ if (lambda2 > 0)
+ {
+ double froNormZ = norm(codes, "fro");
+ return 0.5 * (std::pow(froNormResidual, 2.0) + (lambda2 *
+ std::pow(froNormZ, 2.0))) + (lambda1 * l11NormZ);
+ }
+ else // It can be simpler.
+ {
+ return 0.5 * std::pow(froNormResidual, 2.0) + lambda1 * l11NormZ;
+ }
+}
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,175 +0,0 @@
-/**
- * @file sparse_coding_main.cpp
- * @author Nishant Mehta
- *
- * Executable for Sparse Coding.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include "sparse_coding.hpp"
-
-PROGRAM_INFO("Sparse Coding", "An implementation of Sparse Coding with "
- "Dictionary Learning, which achieves sparsity via an l1-norm regularizer on"
- " the codes (LASSO) or an (l1+l2)-norm regularizer on the codes (the "
- "Elastic Net). Given a dense data matrix X with n points and d dimensions,"
- " sparse coding seeks to find a dense dictionary matrix D with k atoms in "
- "d dimensions, and a sparse coding matrix Z with n points in k dimensions."
- "\n\n"
- "The original data matrix X can then be reconstructed as D * Z. Therefore,"
- " this program finds a representation of each point in X as a sparse linear"
- " combination of atoms in the dictionary D."
- "\n\n"
- "The sparse coding is found with an algorithm which alternates between a "
- "dictionary step, which updates the dictionary D, and a sparse coding step,"
- " which updates the sparse coding matrix."
- "\n\n"
- "To run this program, the input matrix X must be specified (with -i), along"
- " with the number of atoms in the dictionary (-k). An initial dictionary "
- "may also be specified with the --initial_dictionary option. The l1 and l2"
- " norm regularization parameters may be specified with -l and -L, "
- "respectively. For example, to run sparse coding on the dataset in "
- "data.csv using 200 atoms and an l1-regularization parameter of 0.1, saving"
- " the dictionary into dict.csv and the codes into codes.csv, use "
- "\n\n"
- "$ sparse_coding -i data.csv -k 200 -l 0.1 -d dict.csv -c codes.csv"
- "\n\n"
- "The maximum number of iterations may be specified with the -n option. "
- "Optionally, the input data matrix X can be normalized before coding with "
- "the -N option.");
-
-PARAM_STRING_REQ("input_file", "Filename of the input data.", "i");
-PARAM_INT_REQ("atoms", "Number of atoms in the dictionary.", "k");
-
-PARAM_DOUBLE("lambda1", "Sparse coding l1-norm regularization parameter.", "l",
- 0);
-PARAM_DOUBLE("lambda2", "Sparse coding l2-norm regularization parameter.", "L",
- 0);
-
-PARAM_INT("max_iterations", "Maximum number of iterations for sparse coding (0 "
- "indicates no limit).", "n", 0);
-
-PARAM_STRING("initial_dictionary", "Filename for optional initial dictionary.",
- "D", "");
-
-PARAM_STRING("dictionary_file", "Filename to save the output dictionary to.",
- "d", "dictionary.csv");
-PARAM_STRING("codes_file", "Filename to save the output sparse codes to.", "c",
- "codes.csv");
-
-PARAM_FLAG("normalize", "If set, the input data matrix will be normalized "
- "before coding.", "N");
-
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
-
-PARAM_DOUBLE("objective_tolerance", "Tolerance for convergence of the objective"
- " function.", "o", 0.01);
-PARAM_DOUBLE("newton_tolerance", "Tolerance for convergence of Newton method.",
- "w", 1e-6);
-
-using namespace arma;
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::math;
-using namespace mlpack::sparse_coding;
-
-int main(int argc, char* argv[])
-{
- CLI::ParseCommandLine(argc, argv);
-
- if (CLI::GetParam<int>("seed") != 0)
- RandomSeed((size_t) CLI::GetParam<int>("seed"));
- else
- RandomSeed((size_t) std::time(NULL));
-
- const double lambda1 = CLI::GetParam<double>("lambda1");
- const double lambda2 = CLI::GetParam<double>("lambda2");
-
- const string inputFile = CLI::GetParam<string>("input_file");
- const string dictionaryFile = CLI::GetParam<string>("dictionary_file");
- const string codesFile = CLI::GetParam<string>("codes_file");
- const string initialDictionaryFile =
- CLI::GetParam<string>("initial_dictionary");
-
- const size_t maxIterations = CLI::GetParam<int>("max_iterations");
- const size_t atoms = CLI::GetParam<int>("atoms");
-
- const bool normalize = CLI::HasParam("normalize");
-
- const double objTolerance = CLI::GetParam<double>("objective_tolerance");
- const double newtonTolerance = CLI::GetParam<double>("newton_tolerance");
-
- mat matX;
- data::Load(inputFile, matX, true);
-
- Log::Info << "Loaded " << matX.n_cols << " points in " << matX.n_rows <<
- " dimensions." << endl;
-
- // Normalize each point if the user asked for it.
- if (normalize)
- {
- Log::Info << "Normalizing data before coding..." << std::endl;
- for (size_t i = 0; i < matX.n_cols; ++i)
- matX.col(i) /= norm(matX.col(i), 2);
- }
-
- // If there is an initial dictionary, be sure we do not initialize one.
- if (initialDictionaryFile != "")
- {
- SparseCoding<NothingInitializer> sc(matX, atoms, lambda1, lambda2);
-
- // Load initial dictionary directly into sparse coding object.
- data::Load(initialDictionaryFile, sc.Dictionary(), true);
-
- // Validate size of initial dictionary.
- if (sc.Dictionary().n_cols != atoms)
- {
- Log::Fatal << "The initial dictionary has " << sc.Dictionary().n_cols
- << " atoms, but the number of atoms was specified to be " << atoms
- << "!" << endl;
- }
-
- if (sc.Dictionary().n_rows != matX.n_rows)
- {
- Log::Fatal << "The initial dictionary has " << sc.Dictionary().n_rows
- << " dimensions, but the data has " << matX.n_rows << " dimensions!"
- << endl;
- }
-
- // Run sparse coding.
- sc.Encode(maxIterations, objTolerance, newtonTolerance);
-
- // Save the results.
- Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
- data::Save(dictionaryFile, sc.Dictionary());
- Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
- data::Save(codesFile, sc.Codes());
- }
- else
- {
- // No initial dictionary.
- SparseCoding<> sc(matX, atoms, lambda1, lambda2);
-
- // Run sparse coding.
- sc.Encode(maxIterations, objTolerance, newtonTolerance);
-
- // Save the results.
- Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
- data::Save(dictionaryFile, sc.Dictionary());
- Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
- data::Save(codesFile, sc.Codes());
- }
-}
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,175 @@
+/**
+ * @file sparse_coding_main.cpp
+ * @author Nishant Mehta
+ *
+ * Executable for Sparse Coding.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include "sparse_coding.hpp"
+
+PROGRAM_INFO("Sparse Coding", "An implementation of Sparse Coding with "
+ "Dictionary Learning, which achieves sparsity via an l1-norm regularizer on"
+ " the codes (LASSO) or an (l1+l2)-norm regularizer on the codes (the "
+ "Elastic Net). Given a dense data matrix X with n points and d dimensions,"
+ " sparse coding seeks to find a dense dictionary matrix D with k atoms in "
+ "d dimensions, and a sparse coding matrix Z with n points in k dimensions."
+ "\n\n"
+ "The original data matrix X can then be reconstructed as D * Z. Therefore,"
+ " this program finds a representation of each point in X as a sparse linear"
+ " combination of atoms in the dictionary D."
+ "\n\n"
+ "The sparse coding is found with an algorithm which alternates between a "
+ "dictionary step, which updates the dictionary D, and a sparse coding step,"
+ " which updates the sparse coding matrix."
+ "\n\n"
+ "To run this program, the input matrix X must be specified (with -i), along"
+ " with the number of atoms in the dictionary (-k). An initial dictionary "
+ "may also be specified with the --initial_dictionary option. The l1 and l2"
+ " norm regularization parameters may be specified with -l and -L, "
+ "respectively. For example, to run sparse coding on the dataset in "
+ "data.csv using 200 atoms and an l1-regularization parameter of 0.1, saving"
+ " the dictionary into dict.csv and the codes into codes.csv, use "
+ "\n\n"
+ "$ sparse_coding -i data.csv -k 200 -l 0.1 -d dict.csv -c codes.csv"
+ "\n\n"
+ "The maximum number of iterations may be specified with the -n option. "
+ "Optionally, the input data matrix X can be normalized before coding with "
+ "the -N option.");
+
+PARAM_STRING_REQ("input_file", "Filename of the input data.", "i");
+PARAM_INT_REQ("atoms", "Number of atoms in the dictionary.", "k");
+
+PARAM_DOUBLE("lambda1", "Sparse coding l1-norm regularization parameter.", "l",
+ 0);
+PARAM_DOUBLE("lambda2", "Sparse coding l2-norm regularization parameter.", "L",
+ 0);
+
+PARAM_INT("max_iterations", "Maximum number of iterations for sparse coding (0 "
+ "indicates no limit).", "n", 0);
+
+PARAM_STRING("initial_dictionary", "Filename for optional initial dictionary.",
+ "D", "");
+
+PARAM_STRING("dictionary_file", "Filename to save the output dictionary to.",
+ "d", "dictionary.csv");
+PARAM_STRING("codes_file", "Filename to save the output sparse codes to.", "c",
+ "codes.csv");
+
+PARAM_FLAG("normalize", "If set, the input data matrix will be normalized "
+ "before coding.", "N");
+
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+
+PARAM_DOUBLE("objective_tolerance", "Tolerance for convergence of the objective"
+ " function.", "o", 0.01);
+PARAM_DOUBLE("newton_tolerance", "Tolerance for convergence of Newton method.",
+ "w", 1e-6);
+
+using namespace arma;
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::math;
+using namespace mlpack::sparse_coding;
+
+int main(int argc, char* argv[])
+{
+ CLI::ParseCommandLine(argc, argv);
+
+ if (CLI::GetParam<int>("seed") != 0)
+ RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ RandomSeed((size_t) std::time(NULL));
+
+ const double lambda1 = CLI::GetParam<double>("lambda1");
+ const double lambda2 = CLI::GetParam<double>("lambda2");
+
+ const string inputFile = CLI::GetParam<string>("input_file");
+ const string dictionaryFile = CLI::GetParam<string>("dictionary_file");
+ const string codesFile = CLI::GetParam<string>("codes_file");
+ const string initialDictionaryFile =
+ CLI::GetParam<string>("initial_dictionary");
+
+ const size_t maxIterations = CLI::GetParam<int>("max_iterations");
+ const size_t atoms = CLI::GetParam<int>("atoms");
+
+ const bool normalize = CLI::HasParam("normalize");
+
+ const double objTolerance = CLI::GetParam<double>("objective_tolerance");
+ const double newtonTolerance = CLI::GetParam<double>("newton_tolerance");
+
+ mat matX;
+ data::Load(inputFile, matX, true);
+
+ Log::Info << "Loaded " << matX.n_cols << " points in " << matX.n_rows <<
+ " dimensions." << endl;
+
+ // Normalize each point if the user asked for it.
+ if (normalize)
+ {
+ Log::Info << "Normalizing data before coding..." << std::endl;
+ for (size_t i = 0; i < matX.n_cols; ++i)
+ matX.col(i) /= norm(matX.col(i), 2);
+ }
+
+ // If there is an initial dictionary, be sure we do not initialize one.
+ if (initialDictionaryFile != "")
+ {
+ SparseCoding<NothingInitializer> sc(matX, atoms, lambda1, lambda2);
+
+ // Load initial dictionary directly into sparse coding object.
+ data::Load(initialDictionaryFile, sc.Dictionary(), true);
+
+ // Validate size of initial dictionary.
+ if (sc.Dictionary().n_cols != atoms)
+ {
+ Log::Fatal << "The initial dictionary has " << sc.Dictionary().n_cols
+ << " atoms, but the number of atoms was specified to be " << atoms
+ << "!" << endl;
+ }
+
+ if (sc.Dictionary().n_rows != matX.n_rows)
+ {
+ Log::Fatal << "The initial dictionary has " << sc.Dictionary().n_rows
+ << " dimensions, but the data has " << matX.n_rows << " dimensions!"
+ << endl;
+ }
+
+ // Run sparse coding.
+ sc.Encode(maxIterations, objTolerance, newtonTolerance);
+
+ // Save the results.
+ Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
+ data::Save(dictionaryFile, sc.Dictionary());
+ Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
+ data::Save(codesFile, sc.Codes());
+ }
+ else
+ {
+ // No initial dictionary.
+ SparseCoding<> sc(matX, atoms, lambda1, lambda2);
+
+ // Run sparse coding.
+ sc.Encode(maxIterations, objTolerance, newtonTolerance);
+
+ // Save the results.
+ Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
+ data::Save(dictionaryFile, sc.Dictionary());
+ Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
+ data::Save(codesFile, sc.Codes());
+ }
+}
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkfn_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/allkfn_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkfn_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,445 +0,0 @@
-/**
- * @file allkfntest.cpp
- *
- * Tests for AllkFN (all-k-furthest-neighbors).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::neighbor;
-
-BOOST_AUTO_TEST_SUITE(AllkFNTest);
-
-/**
- * Simple furthest-neighbors test with small, synthetic dataset. This is an
- * exhaustive test, which checks that each method for performing the calculation
- * (dual-tree, single-tree, naive) produces the correct results. An
- * eleven-point dataset and the ten furthest neighbors are taken. The dataset
- * is in one dimension for simplicity -- the correct functionality of distance
- * functions is not tested here.
- */
-BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
-{
- // Set up our data.
- arma::mat data(1, 11);
- data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
- data[1] = 0.35;
- data[2] = 0.15;
- data[3] = 1.25;
- data[4] = 5.05;
- data[5] = -0.22;
- data[6] = -2.00;
- data[7] = -1.30;
- data[8] = 0.45;
- data[9] = 0.90;
- data[10] = 1.00;
-
- // We will loop through three times, one for each method of performing the
- // calculation. We'll always use 10 neighbors, so set that parameter.
- for (int i = 0; i < 3; i++)
- {
- AllkFN* allkfn;
- arma::mat dataMutable = data;
- switch (i)
- {
- case 0: // Use the dual-tree method.
- allkfn = new AllkFN(dataMutable, false, false, 1);
- break;
- case 1: // Use the single-tree method.
- allkfn = new AllkFN(dataMutable, false, true, 1);
- break;
- case 2: // Use the naive method.
- allkfn = new AllkFN(dataMutable, true);
- break;
- }
-
- // Now perform the actual calculation.
- arma::Mat<size_t> neighbors;
- arma::mat distances;
- allkfn->Search(10, neighbors, distances);
-
- // Now the exhaustive check for correctness. This will be long. We must
- // also remember that the distances returned are squared distances. As a
- // result, distance comparisons are written out as (distance * distance) for
- // readability.
-
- // Neighbors of point 0.
- BOOST_REQUIRE(neighbors(9, 0) == 2);
- BOOST_REQUIRE_CLOSE(distances(9, 0), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(8, 0) == 5);
- BOOST_REQUIRE_CLOSE(distances(8, 0), 0.27, 1e-5);
- BOOST_REQUIRE(neighbors(7, 0) == 1);
- BOOST_REQUIRE_CLOSE(distances(7, 0), 0.30, 1e-5);
- BOOST_REQUIRE(neighbors(6, 0) == 8);
- BOOST_REQUIRE_CLOSE(distances(6, 0), 0.40, 1e-5);
- BOOST_REQUIRE(neighbors(5, 0) == 9);
- BOOST_REQUIRE_CLOSE(distances(5, 0), 0.85, 1e-5);
- BOOST_REQUIRE(neighbors(4, 0) == 10);
- BOOST_REQUIRE_CLOSE(distances(4, 0), 0.95, 1e-5);
- BOOST_REQUIRE(neighbors(3, 0) == 3);
- BOOST_REQUIRE_CLOSE(distances(3, 0), 1.20, 1e-5);
- BOOST_REQUIRE(neighbors(2, 0) == 7);
- BOOST_REQUIRE_CLOSE(distances(2, 0), 1.35, 1e-5);
- BOOST_REQUIRE(neighbors(1, 0) == 6);
- BOOST_REQUIRE_CLOSE(distances(1, 0), 2.05, 1e-5);
- BOOST_REQUIRE(neighbors(0, 0) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 0), 5.00, 1e-5);
-
- // Neighbors of point 1.
- BOOST_REQUIRE(neighbors(9, 1) == 8);
- BOOST_REQUIRE_CLOSE(distances(9, 1), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(8, 1) == 2);
- BOOST_REQUIRE_CLOSE(distances(8, 1), 0.20, 1e-5);
- BOOST_REQUIRE(neighbors(7, 1) == 0);
- BOOST_REQUIRE_CLOSE(distances(7, 1), 0.30, 1e-5);
- BOOST_REQUIRE(neighbors(6, 1) == 9);
- BOOST_REQUIRE_CLOSE(distances(6, 1), 0.55, 1e-5);
- BOOST_REQUIRE(neighbors(5, 1) == 5);
- BOOST_REQUIRE_CLOSE(distances(5, 1), 0.57, 1e-5);
- BOOST_REQUIRE(neighbors(4, 1) == 10);
- BOOST_REQUIRE_CLOSE(distances(4, 1), 0.65, 1e-5);
- BOOST_REQUIRE(neighbors(3, 1) == 3);
- BOOST_REQUIRE_CLOSE(distances(3, 1), 0.90, 1e-5);
- BOOST_REQUIRE(neighbors(2, 1) == 7);
- BOOST_REQUIRE_CLOSE(distances(2, 1), 1.65, 1e-5);
- BOOST_REQUIRE(neighbors(1, 1) == 6);
- BOOST_REQUIRE_CLOSE(distances(1, 1), 2.35, 1e-5);
- BOOST_REQUIRE(neighbors(0, 1) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 1), 4.70, 1e-5);
-
- // Neighbors of point 2.
- BOOST_REQUIRE(neighbors(9, 2) == 0);
- BOOST_REQUIRE_CLOSE(distances(9, 2), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(8, 2) == 1);
- BOOST_REQUIRE_CLOSE(distances(8, 2), 0.20, 1e-5);
- BOOST_REQUIRE(neighbors(7, 2) == 8);
- BOOST_REQUIRE_CLOSE(distances(7, 2), 0.30, 1e-5);
- BOOST_REQUIRE(neighbors(6, 2) == 5);
- BOOST_REQUIRE_CLOSE(distances(6, 2), 0.37, 1e-5);
- BOOST_REQUIRE(neighbors(5, 2) == 9);
- BOOST_REQUIRE_CLOSE(distances(5, 2), 0.75, 1e-5);
- BOOST_REQUIRE(neighbors(4, 2) == 10);
- BOOST_REQUIRE_CLOSE(distances(4, 2), 0.85, 1e-5);
- BOOST_REQUIRE(neighbors(3, 2) == 3);
- BOOST_REQUIRE_CLOSE(distances(3, 2), 1.10, 1e-5);
- BOOST_REQUIRE(neighbors(2, 2) == 7);
- BOOST_REQUIRE_CLOSE(distances(2, 2), 1.45, 1e-5);
- BOOST_REQUIRE(neighbors(1, 2) == 6);
- BOOST_REQUIRE_CLOSE(distances(1, 2), 2.15, 1e-5);
- BOOST_REQUIRE(neighbors(0, 2) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 2), 4.90, 1e-5);
-
- // Neighbors of point 3.
- BOOST_REQUIRE(neighbors(9, 3) == 10);
- BOOST_REQUIRE_CLOSE(distances(9, 3), 0.25, 1e-5);
- BOOST_REQUIRE(neighbors(8, 3) == 9);
- BOOST_REQUIRE_CLOSE(distances(8, 3), 0.35, 1e-5);
- BOOST_REQUIRE(neighbors(7, 3) == 8);
- BOOST_REQUIRE_CLOSE(distances(7, 3), 0.80, 1e-5);
- BOOST_REQUIRE(neighbors(6, 3) == 1);
- BOOST_REQUIRE_CLOSE(distances(6, 3), 0.90, 1e-5);
- BOOST_REQUIRE(neighbors(5, 3) == 2);
- BOOST_REQUIRE_CLOSE(distances(5, 3), 1.10, 1e-5);
- BOOST_REQUIRE(neighbors(4, 3) == 0);
- BOOST_REQUIRE_CLOSE(distances(4, 3), 1.20, 1e-5);
- BOOST_REQUIRE(neighbors(3, 3) == 5);
- BOOST_REQUIRE_CLOSE(distances(3, 3), 1.47, 1e-5);
- BOOST_REQUIRE(neighbors(2, 3) == 7);
- BOOST_REQUIRE_CLOSE(distances(2, 3), 2.55, 1e-5);
- BOOST_REQUIRE(neighbors(1, 3) == 6);
- BOOST_REQUIRE_CLOSE(distances(1, 3), 3.25, 1e-5);
- BOOST_REQUIRE(neighbors(0, 3) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 3), 3.80, 1e-5);
-
- // Neighbors of point 4.
- BOOST_REQUIRE(neighbors(9, 4) == 3);
- BOOST_REQUIRE_CLOSE(distances(9, 4), 3.80, 1e-5);
- BOOST_REQUIRE(neighbors(8, 4) == 10);
- BOOST_REQUIRE_CLOSE(distances(8, 4), 4.05, 1e-5);
- BOOST_REQUIRE(neighbors(7, 4) == 9);
- BOOST_REQUIRE_CLOSE(distances(7, 4), 4.15, 1e-5);
- BOOST_REQUIRE(neighbors(6, 4) == 8);
- BOOST_REQUIRE_CLOSE(distances(6, 4), 4.60, 1e-5);
- BOOST_REQUIRE(neighbors(5, 4) == 1);
- BOOST_REQUIRE_CLOSE(distances(5, 4), 4.70, 1e-5);
- BOOST_REQUIRE(neighbors(4, 4) == 2);
- BOOST_REQUIRE_CLOSE(distances(4, 4), 4.90, 1e-5);
- BOOST_REQUIRE(neighbors(3, 4) == 0);
- BOOST_REQUIRE_CLOSE(distances(3, 4), 5.00, 1e-5);
- BOOST_REQUIRE(neighbors(2, 4) == 5);
- BOOST_REQUIRE_CLOSE(distances(2, 4), 5.27, 1e-5);
- BOOST_REQUIRE(neighbors(1, 4) == 7);
- BOOST_REQUIRE_CLOSE(distances(1, 4), 6.35, 1e-5);
- BOOST_REQUIRE(neighbors(0, 4) == 6);
- BOOST_REQUIRE_CLOSE(distances(0, 4), 7.05, 1e-5);
-
- // Neighbors of point 5.
- BOOST_REQUIRE(neighbors(9, 5) == 0);
- BOOST_REQUIRE_CLOSE(distances(9, 5), 0.27, 1e-5);
- BOOST_REQUIRE(neighbors(8, 5) == 2);
- BOOST_REQUIRE_CLOSE(distances(8, 5), 0.37, 1e-5);
- BOOST_REQUIRE(neighbors(7, 5) == 1);
- BOOST_REQUIRE_CLOSE(distances(7, 5), 0.57, 1e-5);
- BOOST_REQUIRE(neighbors(6, 5) == 8);
- BOOST_REQUIRE_CLOSE(distances(6, 5), 0.67, 1e-5);
- BOOST_REQUIRE(neighbors(5, 5) == 7);
- BOOST_REQUIRE_CLOSE(distances(5, 5), 1.08, 1e-5);
- BOOST_REQUIRE(neighbors(4, 5) == 9);
- BOOST_REQUIRE_CLOSE(distances(4, 5), 1.12, 1e-5);
- BOOST_REQUIRE(neighbors(3, 5) == 10);
- BOOST_REQUIRE_CLOSE(distances(3, 5), 1.22, 1e-5);
- BOOST_REQUIRE(neighbors(2, 5) == 3);
- BOOST_REQUIRE_CLOSE(distances(2, 5), 1.47, 1e-5);
- BOOST_REQUIRE(neighbors(1, 5) == 6);
- BOOST_REQUIRE_CLOSE(distances(1, 5), 1.78, 1e-5);
- BOOST_REQUIRE(neighbors(0, 5) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 5), 5.27, 1e-5);
-
- // Neighbors of point 6.
- BOOST_REQUIRE(neighbors(9, 6) == 7);
- BOOST_REQUIRE_CLOSE(distances(9, 6), 0.70, 1e-5);
- BOOST_REQUIRE(neighbors(8, 6) == 5);
- BOOST_REQUIRE_CLOSE(distances(8, 6), 1.78, 1e-5);
- BOOST_REQUIRE(neighbors(7, 6) == 0);
- BOOST_REQUIRE_CLOSE(distances(7, 6), 2.05, 1e-5);
- BOOST_REQUIRE(neighbors(6, 6) == 2);
- BOOST_REQUIRE_CLOSE(distances(6, 6), 2.15, 1e-5);
- BOOST_REQUIRE(neighbors(5, 6) == 1);
- BOOST_REQUIRE_CLOSE(distances(5, 6), 2.35, 1e-5);
- BOOST_REQUIRE(neighbors(4, 6) == 8);
- BOOST_REQUIRE_CLOSE(distances(4, 6), 2.45, 1e-5);
- BOOST_REQUIRE(neighbors(3, 6) == 9);
- BOOST_REQUIRE_CLOSE(distances(3, 6), 2.90, 1e-5);
- BOOST_REQUIRE(neighbors(2, 6) == 10);
- BOOST_REQUIRE_CLOSE(distances(2, 6), 3.00, 1e-5);
- BOOST_REQUIRE(neighbors(1, 6) == 3);
- BOOST_REQUIRE_CLOSE(distances(1, 6), 3.25, 1e-5);
- BOOST_REQUIRE(neighbors(0, 6) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 6), 7.05, 1e-5);
-
- // Neighbors of point 7.
- BOOST_REQUIRE(neighbors(9, 7) == 6);
- BOOST_REQUIRE_CLOSE(distances(9, 7), 0.70, 1e-5);
- BOOST_REQUIRE(neighbors(8, 7) == 5);
- BOOST_REQUIRE_CLOSE(distances(8, 7), 1.08, 1e-5);
- BOOST_REQUIRE(neighbors(7, 7) == 0);
- BOOST_REQUIRE_CLOSE(distances(7, 7), 1.35, 1e-5);
- BOOST_REQUIRE(neighbors(6, 7) == 2);
- BOOST_REQUIRE_CLOSE(distances(6, 7), 1.45, 1e-5);
- BOOST_REQUIRE(neighbors(5, 7) == 1);
- BOOST_REQUIRE_CLOSE(distances(5, 7), 1.65, 1e-5);
- BOOST_REQUIRE(neighbors(4, 7) == 8);
- BOOST_REQUIRE_CLOSE(distances(4, 7), 1.75, 1e-5);
- BOOST_REQUIRE(neighbors(3, 7) == 9);
- BOOST_REQUIRE_CLOSE(distances(3, 7), 2.20, 1e-5);
- BOOST_REQUIRE(neighbors(2, 7) == 10);
- BOOST_REQUIRE_CLOSE(distances(2, 7), 2.30, 1e-5);
- BOOST_REQUIRE(neighbors(1, 7) == 3);
- BOOST_REQUIRE_CLOSE(distances(1, 7), 2.55, 1e-5);
- BOOST_REQUIRE(neighbors(0, 7) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 7), 6.35, 1e-5);
-
- // Neighbors of point 8.
- BOOST_REQUIRE(neighbors(9, 8) == 1);
- BOOST_REQUIRE_CLOSE(distances(9, 8), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(8, 8) == 2);
- BOOST_REQUIRE_CLOSE(distances(8, 8), 0.30, 1e-5);
- BOOST_REQUIRE(neighbors(7, 8) == 0);
- BOOST_REQUIRE_CLOSE(distances(7, 8), 0.40, 1e-5);
- BOOST_REQUIRE(neighbors(6, 8) == 9);
- BOOST_REQUIRE_CLOSE(distances(6, 8), 0.45, 1e-5);
- BOOST_REQUIRE(neighbors(5, 8) == 10);
- BOOST_REQUIRE_CLOSE(distances(5, 8), 0.55, 1e-5);
- BOOST_REQUIRE(neighbors(4, 8) == 5);
- BOOST_REQUIRE_CLOSE(distances(4, 8), 0.67, 1e-5);
- BOOST_REQUIRE(neighbors(3, 8) == 3);
- BOOST_REQUIRE_CLOSE(distances(3, 8), 0.80, 1e-5);
- BOOST_REQUIRE(neighbors(2, 8) == 7);
- BOOST_REQUIRE_CLOSE(distances(2, 8), 1.75, 1e-5);
- BOOST_REQUIRE(neighbors(1, 8) == 6);
- BOOST_REQUIRE_CLOSE(distances(1, 8), 2.45, 1e-5);
- BOOST_REQUIRE(neighbors(0, 8) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 8), 4.60, 1e-5);
-
- // Neighbors of point 9.
- BOOST_REQUIRE(neighbors(9, 9) == 10);
- BOOST_REQUIRE_CLOSE(distances(9, 9), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(8, 9) == 3);
- BOOST_REQUIRE_CLOSE(distances(8, 9), 0.35, 1e-5);
- BOOST_REQUIRE(neighbors(7, 9) == 8);
- BOOST_REQUIRE_CLOSE(distances(7, 9), 0.45, 1e-5);
- BOOST_REQUIRE(neighbors(6, 9) == 1);
- BOOST_REQUIRE_CLOSE(distances(6, 9), 0.55, 1e-5);
- BOOST_REQUIRE(neighbors(5, 9) == 2);
- BOOST_REQUIRE_CLOSE(distances(5, 9), 0.75, 1e-5);
- BOOST_REQUIRE(neighbors(4, 9) == 0);
- BOOST_REQUIRE_CLOSE(distances(4, 9), 0.85, 1e-5);
- BOOST_REQUIRE(neighbors(3, 9) == 5);
- BOOST_REQUIRE_CLOSE(distances(3, 9), 1.12, 1e-5);
- BOOST_REQUIRE(neighbors(2, 9) == 7);
- BOOST_REQUIRE_CLOSE(distances(2, 9), 2.20, 1e-5);
- BOOST_REQUIRE(neighbors(1, 9) == 6);
- BOOST_REQUIRE_CLOSE(distances(1, 9), 2.90, 1e-5);
- BOOST_REQUIRE(neighbors(0, 9) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 9), 4.15, 1e-5);
-
- // Neighbors of point 10.
- BOOST_REQUIRE(neighbors(9, 10) == 9);
- BOOST_REQUIRE_CLOSE(distances(9, 10), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(8, 10) == 3);
- BOOST_REQUIRE_CLOSE(distances(8, 10), 0.25, 1e-5);
- BOOST_REQUIRE(neighbors(7, 10) == 8);
- BOOST_REQUIRE_CLOSE(distances(7, 10), 0.55, 1e-5);
- BOOST_REQUIRE(neighbors(6, 10) == 1);
- BOOST_REQUIRE_CLOSE(distances(6, 10), 0.65, 1e-5);
- BOOST_REQUIRE(neighbors(5, 10) == 2);
- BOOST_REQUIRE_CLOSE(distances(5, 10), 0.85, 1e-5);
- BOOST_REQUIRE(neighbors(4, 10) == 0);
- BOOST_REQUIRE_CLOSE(distances(4, 10), 0.95, 1e-5);
- BOOST_REQUIRE(neighbors(3, 10) == 5);
- BOOST_REQUIRE_CLOSE(distances(3, 10), 1.22, 1e-5);
- BOOST_REQUIRE(neighbors(2, 10) == 7);
- BOOST_REQUIRE_CLOSE(distances(2, 10), 2.30, 1e-5);
- BOOST_REQUIRE(neighbors(1, 10) == 6);
- BOOST_REQUIRE_CLOSE(distances(1, 10), 3.00, 1e-5);
- BOOST_REQUIRE(neighbors(0, 10) == 4);
- BOOST_REQUIRE_CLOSE(distances(0, 10), 4.05, 1e-5);
-
- // Clean the memory.
- delete allkfn;
- }
-}
-
-/**
- * Test the dual-tree furthest-neighbors method with the naive method. This
- * uses both a query and reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with.
- arma::mat dualQuery(dataForTree);
- arma::mat dualReferences(dataForTree);
- arma::mat naiveQuery(dataForTree);
- arma::mat naiveReferences(dataForTree);
-
- AllkFN allkfn(dualQuery, dualReferences);
-
- AllkFN naive(naiveQuery, naiveReferences, true);
-
- arma::Mat<size_t> resultingNeighborsTree;
- arma::mat distancesTree;
- allkfn.Search(15, resultingNeighborsTree, distancesTree);
-
- arma::Mat<size_t> resultingNeighborsNaive;
- arma::mat distancesNaive;
- naive.Search(15, resultingNeighborsNaive, distancesNaive);
-
- for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
- {
- BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
- BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
- }
-}
-
-/**
- * Test the dual-tree furthest-neighbors method with the naive method. This
- * uses only a reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- // Code duplication: also bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with.
- arma::mat dualReferences(dataForTree);
- arma::mat naiveReferences(dataForTree);
-
- AllkFN allkfn(dualReferences);
-
- AllkFN naive(naiveReferences, true);
-
- arma::Mat<size_t> resultingNeighborsTree;
- arma::mat distancesTree;
- allkfn.Search(15, resultingNeighborsTree, distancesTree);
-
- arma::Mat<size_t> resultingNeighborsNaive;
- arma::mat distancesNaive;
- naive.Search(15, resultingNeighborsNaive, distancesNaive);
-
- for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
- {
- BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
- BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
- }
-}
-
-/**
- * Test the single-tree furthest-neighbors method with the naive method. This
- * uses only a reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- // Code duplication: also bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- arma::mat singleQuery(dataForTree);
- arma::mat naiveQuery(dataForTree);
-
- AllkFN allkfn(singleQuery, false, true);
-
- AllkFN naive(naiveQuery, true);
-
- arma::Mat<size_t> resultingNeighborsTree;
- arma::mat distancesTree;
- allkfn.Search(15, resultingNeighborsTree, distancesTree);
-
- arma::Mat<size_t> resultingNeighborsNaive;
- arma::mat distancesNaive;
- naive.Search(15, resultingNeighborsNaive, distancesNaive);
-
- for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
- {
- BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
- BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkfn_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/allkfn_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkfn_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkfn_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,445 @@
+/**
+ * @file allkfntest.cpp
+ *
+ * Tests for AllkFN (all-k-furthest-neighbors).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+BOOST_AUTO_TEST_SUITE(AllkFNTest);
+
+/**
+ * Simple furthest-neighbors test with small, synthetic dataset. This is an
+ * exhaustive test, which checks that each method for performing the calculation
+ * (dual-tree, single-tree, naive) produces the correct results. An
+ * eleven-point dataset and the ten furthest neighbors are taken. The dataset
+ * is in one dimension for simplicity -- the correct functionality of distance
+ * functions is not tested here.
+ */
+BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
+{
+ // Set up our data.
+ arma::mat data(1, 11);
+ data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
+ data[1] = 0.35;
+ data[2] = 0.15;
+ data[3] = 1.25;
+ data[4] = 5.05;
+ data[5] = -0.22;
+ data[6] = -2.00;
+ data[7] = -1.30;
+ data[8] = 0.45;
+ data[9] = 0.90;
+ data[10] = 1.00;
+
+ // We will loop through three times, one for each method of performing the
+ // calculation. We'll always use 10 neighbors, so set that parameter.
+ for (int i = 0; i < 3; i++)
+ {
+ AllkFN* allkfn;
+ arma::mat dataMutable = data;
+ switch (i)
+ {
+ case 0: // Use the dual-tree method.
+ allkfn = new AllkFN(dataMutable, false, false, 1);
+ break;
+ case 1: // Use the single-tree method.
+ allkfn = new AllkFN(dataMutable, false, true, 1);
+ break;
+ case 2: // Use the naive method.
+ allkfn = new AllkFN(dataMutable, true);
+ break;
+ }
+
+ // Now perform the actual calculation.
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+ allkfn->Search(10, neighbors, distances);
+
+ // Now the exhaustive check for correctness. This will be long. We must
+ // also remember that the distances returned are squared distances. As a
+ // result, distance comparisons are written out as (distance * distance) for
+ // readability.
+
+ // Neighbors of point 0.
+ BOOST_REQUIRE(neighbors(9, 0) == 2);
+ BOOST_REQUIRE_CLOSE(distances(9, 0), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 0) == 5);
+ BOOST_REQUIRE_CLOSE(distances(8, 0), 0.27, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 0) == 1);
+ BOOST_REQUIRE_CLOSE(distances(7, 0), 0.30, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 0) == 8);
+ BOOST_REQUIRE_CLOSE(distances(6, 0), 0.40, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 0) == 9);
+ BOOST_REQUIRE_CLOSE(distances(5, 0), 0.85, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 0) == 10);
+ BOOST_REQUIRE_CLOSE(distances(4, 0), 0.95, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 0) == 3);
+ BOOST_REQUIRE_CLOSE(distances(3, 0), 1.20, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 0) == 7);
+ BOOST_REQUIRE_CLOSE(distances(2, 0), 1.35, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 0) == 6);
+ BOOST_REQUIRE_CLOSE(distances(1, 0), 2.05, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 0) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 0), 5.00, 1e-5);
+
+ // Neighbors of point 1.
+ BOOST_REQUIRE(neighbors(9, 1) == 8);
+ BOOST_REQUIRE_CLOSE(distances(9, 1), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 1) == 2);
+ BOOST_REQUIRE_CLOSE(distances(8, 1), 0.20, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 1) == 0);
+ BOOST_REQUIRE_CLOSE(distances(7, 1), 0.30, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 1) == 9);
+ BOOST_REQUIRE_CLOSE(distances(6, 1), 0.55, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 1) == 5);
+ BOOST_REQUIRE_CLOSE(distances(5, 1), 0.57, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 1) == 10);
+ BOOST_REQUIRE_CLOSE(distances(4, 1), 0.65, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 1) == 3);
+ BOOST_REQUIRE_CLOSE(distances(3, 1), 0.90, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 1) == 7);
+ BOOST_REQUIRE_CLOSE(distances(2, 1), 1.65, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 1) == 6);
+ BOOST_REQUIRE_CLOSE(distances(1, 1), 2.35, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 1) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 1), 4.70, 1e-5);
+
+ // Neighbors of point 2.
+ BOOST_REQUIRE(neighbors(9, 2) == 0);
+ BOOST_REQUIRE_CLOSE(distances(9, 2), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 2) == 1);
+ BOOST_REQUIRE_CLOSE(distances(8, 2), 0.20, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 2) == 8);
+ BOOST_REQUIRE_CLOSE(distances(7, 2), 0.30, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 2) == 5);
+ BOOST_REQUIRE_CLOSE(distances(6, 2), 0.37, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 2) == 9);
+ BOOST_REQUIRE_CLOSE(distances(5, 2), 0.75, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 2) == 10);
+ BOOST_REQUIRE_CLOSE(distances(4, 2), 0.85, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 2) == 3);
+ BOOST_REQUIRE_CLOSE(distances(3, 2), 1.10, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 2) == 7);
+ BOOST_REQUIRE_CLOSE(distances(2, 2), 1.45, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 2) == 6);
+ BOOST_REQUIRE_CLOSE(distances(1, 2), 2.15, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 2) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 2), 4.90, 1e-5);
+
+ // Neighbors of point 3.
+ BOOST_REQUIRE(neighbors(9, 3) == 10);
+ BOOST_REQUIRE_CLOSE(distances(9, 3), 0.25, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 3) == 9);
+ BOOST_REQUIRE_CLOSE(distances(8, 3), 0.35, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 3) == 8);
+ BOOST_REQUIRE_CLOSE(distances(7, 3), 0.80, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 3) == 1);
+ BOOST_REQUIRE_CLOSE(distances(6, 3), 0.90, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 3) == 2);
+ BOOST_REQUIRE_CLOSE(distances(5, 3), 1.10, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 3) == 0);
+ BOOST_REQUIRE_CLOSE(distances(4, 3), 1.20, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 3) == 5);
+ BOOST_REQUIRE_CLOSE(distances(3, 3), 1.47, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 3) == 7);
+ BOOST_REQUIRE_CLOSE(distances(2, 3), 2.55, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 3) == 6);
+ BOOST_REQUIRE_CLOSE(distances(1, 3), 3.25, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 3) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 3), 3.80, 1e-5);
+
+ // Neighbors of point 4.
+ BOOST_REQUIRE(neighbors(9, 4) == 3);
+ BOOST_REQUIRE_CLOSE(distances(9, 4), 3.80, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 4) == 10);
+ BOOST_REQUIRE_CLOSE(distances(8, 4), 4.05, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 4) == 9);
+ BOOST_REQUIRE_CLOSE(distances(7, 4), 4.15, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 4) == 8);
+ BOOST_REQUIRE_CLOSE(distances(6, 4), 4.60, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 4) == 1);
+ BOOST_REQUIRE_CLOSE(distances(5, 4), 4.70, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 4) == 2);
+ BOOST_REQUIRE_CLOSE(distances(4, 4), 4.90, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 4) == 0);
+ BOOST_REQUIRE_CLOSE(distances(3, 4), 5.00, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 4) == 5);
+ BOOST_REQUIRE_CLOSE(distances(2, 4), 5.27, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 4) == 7);
+ BOOST_REQUIRE_CLOSE(distances(1, 4), 6.35, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 4) == 6);
+ BOOST_REQUIRE_CLOSE(distances(0, 4), 7.05, 1e-5);
+
+ // Neighbors of point 5.
+ BOOST_REQUIRE(neighbors(9, 5) == 0);
+ BOOST_REQUIRE_CLOSE(distances(9, 5), 0.27, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 5) == 2);
+ BOOST_REQUIRE_CLOSE(distances(8, 5), 0.37, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 5) == 1);
+ BOOST_REQUIRE_CLOSE(distances(7, 5), 0.57, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 5) == 8);
+ BOOST_REQUIRE_CLOSE(distances(6, 5), 0.67, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 5) == 7);
+ BOOST_REQUIRE_CLOSE(distances(5, 5), 1.08, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 5) == 9);
+ BOOST_REQUIRE_CLOSE(distances(4, 5), 1.12, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 5) == 10);
+ BOOST_REQUIRE_CLOSE(distances(3, 5), 1.22, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 5) == 3);
+ BOOST_REQUIRE_CLOSE(distances(2, 5), 1.47, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 5) == 6);
+ BOOST_REQUIRE_CLOSE(distances(1, 5), 1.78, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 5) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 5), 5.27, 1e-5);
+
+ // Neighbors of point 6.
+ BOOST_REQUIRE(neighbors(9, 6) == 7);
+ BOOST_REQUIRE_CLOSE(distances(9, 6), 0.70, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 6) == 5);
+ BOOST_REQUIRE_CLOSE(distances(8, 6), 1.78, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 6) == 0);
+ BOOST_REQUIRE_CLOSE(distances(7, 6), 2.05, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 6) == 2);
+ BOOST_REQUIRE_CLOSE(distances(6, 6), 2.15, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 6) == 1);
+ BOOST_REQUIRE_CLOSE(distances(5, 6), 2.35, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 6) == 8);
+ BOOST_REQUIRE_CLOSE(distances(4, 6), 2.45, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 6) == 9);
+ BOOST_REQUIRE_CLOSE(distances(3, 6), 2.90, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 6) == 10);
+ BOOST_REQUIRE_CLOSE(distances(2, 6), 3.00, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 6) == 3);
+ BOOST_REQUIRE_CLOSE(distances(1, 6), 3.25, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 6) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 6), 7.05, 1e-5);
+
+ // Neighbors of point 7.
+ BOOST_REQUIRE(neighbors(9, 7) == 6);
+ BOOST_REQUIRE_CLOSE(distances(9, 7), 0.70, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 7) == 5);
+ BOOST_REQUIRE_CLOSE(distances(8, 7), 1.08, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 7) == 0);
+ BOOST_REQUIRE_CLOSE(distances(7, 7), 1.35, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 7) == 2);
+ BOOST_REQUIRE_CLOSE(distances(6, 7), 1.45, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 7) == 1);
+ BOOST_REQUIRE_CLOSE(distances(5, 7), 1.65, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 7) == 8);
+ BOOST_REQUIRE_CLOSE(distances(4, 7), 1.75, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 7) == 9);
+ BOOST_REQUIRE_CLOSE(distances(3, 7), 2.20, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 7) == 10);
+ BOOST_REQUIRE_CLOSE(distances(2, 7), 2.30, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 7) == 3);
+ BOOST_REQUIRE_CLOSE(distances(1, 7), 2.55, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 7) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 7), 6.35, 1e-5);
+
+ // Neighbors of point 8.
+ BOOST_REQUIRE(neighbors(9, 8) == 1);
+ BOOST_REQUIRE_CLOSE(distances(9, 8), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 8) == 2);
+ BOOST_REQUIRE_CLOSE(distances(8, 8), 0.30, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 8) == 0);
+ BOOST_REQUIRE_CLOSE(distances(7, 8), 0.40, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 8) == 9);
+ BOOST_REQUIRE_CLOSE(distances(6, 8), 0.45, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 8) == 10);
+ BOOST_REQUIRE_CLOSE(distances(5, 8), 0.55, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 8) == 5);
+ BOOST_REQUIRE_CLOSE(distances(4, 8), 0.67, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 8) == 3);
+ BOOST_REQUIRE_CLOSE(distances(3, 8), 0.80, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 8) == 7);
+ BOOST_REQUIRE_CLOSE(distances(2, 8), 1.75, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 8) == 6);
+ BOOST_REQUIRE_CLOSE(distances(1, 8), 2.45, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 8) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 8), 4.60, 1e-5);
+
+ // Neighbors of point 9.
+ BOOST_REQUIRE(neighbors(9, 9) == 10);
+ BOOST_REQUIRE_CLOSE(distances(9, 9), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 9) == 3);
+ BOOST_REQUIRE_CLOSE(distances(8, 9), 0.35, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 9) == 8);
+ BOOST_REQUIRE_CLOSE(distances(7, 9), 0.45, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 9) == 1);
+ BOOST_REQUIRE_CLOSE(distances(6, 9), 0.55, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 9) == 2);
+ BOOST_REQUIRE_CLOSE(distances(5, 9), 0.75, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 9) == 0);
+ BOOST_REQUIRE_CLOSE(distances(4, 9), 0.85, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 9) == 5);
+ BOOST_REQUIRE_CLOSE(distances(3, 9), 1.12, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 9) == 7);
+ BOOST_REQUIRE_CLOSE(distances(2, 9), 2.20, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 9) == 6);
+ BOOST_REQUIRE_CLOSE(distances(1, 9), 2.90, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 9) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 9), 4.15, 1e-5);
+
+ // Neighbors of point 10.
+ BOOST_REQUIRE(neighbors(9, 10) == 9);
+ BOOST_REQUIRE_CLOSE(distances(9, 10), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 10) == 3);
+ BOOST_REQUIRE_CLOSE(distances(8, 10), 0.25, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 10) == 8);
+ BOOST_REQUIRE_CLOSE(distances(7, 10), 0.55, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 10) == 1);
+ BOOST_REQUIRE_CLOSE(distances(6, 10), 0.65, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 10) == 2);
+ BOOST_REQUIRE_CLOSE(distances(5, 10), 0.85, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 10) == 0);
+ BOOST_REQUIRE_CLOSE(distances(4, 10), 0.95, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 10) == 5);
+ BOOST_REQUIRE_CLOSE(distances(3, 10), 1.22, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 10) == 7);
+ BOOST_REQUIRE_CLOSE(distances(2, 10), 2.30, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 10) == 6);
+ BOOST_REQUIRE_CLOSE(distances(1, 10), 3.00, 1e-5);
+ BOOST_REQUIRE(neighbors(0, 10) == 4);
+ BOOST_REQUIRE_CLOSE(distances(0, 10), 4.05, 1e-5);
+
+ // Clean the memory.
+ delete allkfn;
+ }
+}
+
+/**
+ * Test the dual-tree furthest-neighbors method with the naive method. This
+ * uses both a query and reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with.
+ arma::mat dualQuery(dataForTree);
+ arma::mat dualReferences(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+ arma::mat naiveReferences(dataForTree);
+
+ AllkFN allkfn(dualQuery, dualReferences);
+
+ AllkFN naive(naiveQuery, naiveReferences, true);
+
+ arma::Mat<size_t> resultingNeighborsTree;
+ arma::mat distancesTree;
+ allkfn.Search(15, resultingNeighborsTree, distancesTree);
+
+ arma::Mat<size_t> resultingNeighborsNaive;
+ arma::mat distancesNaive;
+ naive.Search(15, resultingNeighborsNaive, distancesNaive);
+
+ for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
+ {
+ BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
+ BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
+ }
+}
+
+/**
+ * Test the dual-tree furthest-neighbors method with the naive method. This
+ * uses only a reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ // Code duplication: also bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with.
+ arma::mat dualReferences(dataForTree);
+ arma::mat naiveReferences(dataForTree);
+
+ AllkFN allkfn(dualReferences);
+
+ AllkFN naive(naiveReferences, true);
+
+ arma::Mat<size_t> resultingNeighborsTree;
+ arma::mat distancesTree;
+ allkfn.Search(15, resultingNeighborsTree, distancesTree);
+
+ arma::Mat<size_t> resultingNeighborsNaive;
+ arma::mat distancesNaive;
+ naive.Search(15, resultingNeighborsNaive, distancesNaive);
+
+ for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
+ {
+ BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
+ BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
+ }
+}
+
+/**
+ * Test the single-tree furthest-neighbors method with the naive method. This
+ * uses only a reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ // Code duplication: also bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ arma::mat singleQuery(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+
+ AllkFN allkfn(singleQuery, false, true);
+
+ AllkFN naive(naiveQuery, true);
+
+ arma::Mat<size_t> resultingNeighborsTree;
+ arma::mat distancesTree;
+ allkfn.Search(15, resultingNeighborsTree, distancesTree);
+
+ arma::Mat<size_t> resultingNeighborsNaive;
+ arma::mat distancesNaive;
+ naive.Search(15, resultingNeighborsNaive, distancesNaive);
+
+ for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
+ {
+ BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
+ BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allknn_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/allknn_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allknn_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,682 +0,0 @@
-/**
- * @file allknn_test.cpp
- *
- * Test file for AllkNN class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
-#include <mlpack/methods/neighbor_search/unmap.hpp>
-#include <mlpack/core/tree/cover_tree.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::neighbor;
-
-BOOST_AUTO_TEST_SUITE(AllkNNTest);
-
-/**
- * Test that Unmap() works in the dual-tree case (see unmap.hpp).
- */
-BOOST_AUTO_TEST_CASE(DualTreeUnmapTest)
-{
- std::vector<size_t> refMap;
- refMap.push_back(3);
- refMap.push_back(4);
- refMap.push_back(1);
- refMap.push_back(2);
- refMap.push_back(0);
-
- std::vector<size_t> queryMap;
- queryMap.push_back(2);
- queryMap.push_back(0);
- queryMap.push_back(4);
- queryMap.push_back(3);
- queryMap.push_back(1);
- queryMap.push_back(5);
-
- // Now generate some results. 6 queries, 5 references.
- arma::Mat<size_t> neighbors("3 1 2 0 4;"
- "1 0 2 3 4;"
- "0 1 2 3 4;"
- "4 1 0 3 2;"
- "3 0 4 1 2;"
- "3 0 4 1 2;");
- neighbors = neighbors.t();
-
- // Integer distances will work fine here.
- arma::mat distances("3 1 2 0 4;"
- "1 0 2 3 4;"
- "0 1 2 3 4;"
- "4 1 0 3 2;"
- "3 0 4 1 2;"
- "3 0 4 1 2;");
- distances = distances.t();
-
- // This is what the results should be when they are unmapped.
- arma::Mat<size_t> correctNeighbors("4 3 1 2 0;"
- "2 3 0 4 1;"
- "2 4 1 3 0;"
- "0 4 3 2 1;"
- "3 4 1 2 0;"
- "2 3 0 4 1;");
- correctNeighbors = correctNeighbors.t();
-
- arma::mat correctDistances("1 0 2 3 4;"
- "3 0 4 1 2;"
- "3 1 2 0 4;"
- "4 1 0 3 2;"
- "0 1 2 3 4;"
- "3 0 4 1 2;");
- correctDistances = correctDistances.t();
-
- // Perform the unmapping.
- arma::Mat<size_t> neighborsOut;
- arma::mat distancesOut;
-
- Unmap(neighbors, distances, refMap, queryMap, neighborsOut, distancesOut);
-
- for (size_t i = 0; i < correctNeighbors.n_elem; ++i)
- {
- BOOST_REQUIRE_EQUAL(neighborsOut[i], correctNeighbors[i]);
- BOOST_REQUIRE_CLOSE(distancesOut[i], correctDistances[i], 1e-5);
- }
-
- // Now try taking the square root.
- Unmap(neighbors, distances, refMap, queryMap, neighborsOut, distancesOut,
- true);
-
- for (size_t i = 0; i < correctNeighbors.n_elem; ++i)
- {
- BOOST_REQUIRE_EQUAL(neighborsOut[i], correctNeighbors[i]);
- BOOST_REQUIRE_CLOSE(distancesOut[i], sqrt(correctDistances[i]), 1e-5);
- }
-}
-
-/**
- * Check that Unmap() works in the single-tree case.
- */
-BOOST_AUTO_TEST_CASE(SingleTreeUnmapTest)
-{
- std::vector<size_t> refMap;
- refMap.push_back(3);
- refMap.push_back(4);
- refMap.push_back(1);
- refMap.push_back(2);
- refMap.push_back(0);
-
- // Now generate some results. 6 queries, 5 references.
- arma::Mat<size_t> neighbors("3 1 2 0 4;"
- "1 0 2 3 4;"
- "0 1 2 3 4;"
- "4 1 0 3 2;"
- "3 0 4 1 2;"
- "3 0 4 1 2;");
- neighbors = neighbors.t();
-
- // Integer distances will work fine here.
- arma::mat distances("3 1 2 0 4;"
- "1 0 2 3 4;"
- "0 1 2 3 4;"
- "4 1 0 3 2;"
- "3 0 4 1 2;"
- "3 0 4 1 2;");
- distances = distances.t();
-
- // This is what the results should be when they are unmapped.
- arma::Mat<size_t> correctNeighbors("2 4 1 3 0;"
- "4 3 1 2 0;"
- "3 4 1 2 0;"
- "0 4 3 2 1;"
- "2 3 0 4 1;"
- "2 3 0 4 1;");
- correctNeighbors = correctNeighbors.t();
-
- arma::mat correctDistances = distances;
-
- // Perform the unmapping.
- arma::Mat<size_t> neighborsOut;
- arma::mat distancesOut;
-
- Unmap(neighbors, distances, refMap, neighborsOut, distancesOut);
-
- for (size_t i = 0; i < correctNeighbors.n_elem; ++i)
- {
- BOOST_REQUIRE_EQUAL(neighborsOut[i], correctNeighbors[i]);
- BOOST_REQUIRE_CLOSE(distancesOut[i], correctDistances[i], 1e-5);
- }
-
- // Now try taking the square root.
- Unmap(neighbors, distances, refMap, neighborsOut, distancesOut, true);
-
- for (size_t i = 0; i < correctNeighbors.n_elem; ++i)
- {
- BOOST_REQUIRE_EQUAL(neighborsOut[i], correctNeighbors[i]);
- BOOST_REQUIRE_CLOSE(distancesOut[i], sqrt(correctDistances[i]), 1e-5);
- }
-}
-
-/**
- * Simple nearest-neighbors test with small, synthetic dataset. This is an
- * exhaustive test, which checks that each method for performing the calculation
- * (dual-tree, single-tree, naive) produces the correct results. An
- * eleven-point dataset and the ten nearest neighbors are taken. The dataset is
- * in one dimension for simplicity -- the correct functionality of distance
- * functions is not tested here.
- */
-BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
-{
- // Set up our data.
- arma::mat data(1, 11);
- data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
- data[1] = 0.35;
- data[2] = 0.15;
- data[3] = 1.25;
- data[4] = 5.05;
- data[5] = -0.22;
- data[6] = -2.00;
- data[7] = -1.30;
- data[8] = 0.45;
- data[9] = 0.90;
- data[10] = 1.00;
-
- // We will loop through three times, one for each method of performing the
- // calculation.
- for (int i = 0; i < 3; i++)
- {
- AllkNN* allknn;
- arma::mat dataMutable = data;
- switch (i)
- {
- case 0: // Use the dual-tree method.
- allknn = new AllkNN(dataMutable, false, false, 1);
- break;
- case 1: // Use the single-tree method.
- allknn = new AllkNN(dataMutable, false, true, 1);
- break;
- case 2: // Use the naive method.
- allknn = new AllkNN(dataMutable, true);
- break;
- }
-
- // Now perform the actual calculation.
- arma::Mat<size_t> neighbors;
- arma::mat distances;
- allknn->Search(10, neighbors, distances);
-
- // Now the exhaustive check for correctness. This will be long. We must
- // also remember that the distances returned are squared distances. As a
- // result, distance comparisons are written out as (distance * distance) for
- // readability.
-
- // Neighbors of point 0.
- BOOST_REQUIRE(neighbors(0, 0) == 2);
- BOOST_REQUIRE_CLOSE(distances(0, 0), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(1, 0) == 5);
- BOOST_REQUIRE_CLOSE(distances(1, 0), 0.27, 1e-5);
- BOOST_REQUIRE(neighbors(2, 0) == 1);
- BOOST_REQUIRE_CLOSE(distances(2, 0), 0.30, 1e-5);
- BOOST_REQUIRE(neighbors(3, 0) == 8);
- BOOST_REQUIRE_CLOSE(distances(3, 0), 0.40, 1e-5);
- BOOST_REQUIRE(neighbors(4, 0) == 9);
- BOOST_REQUIRE_CLOSE(distances(4, 0), 0.85, 1e-5);
- BOOST_REQUIRE(neighbors(5, 0) == 10);
- BOOST_REQUIRE_CLOSE(distances(5, 0), 0.95, 1e-5);
- BOOST_REQUIRE(neighbors(6, 0) == 3);
- BOOST_REQUIRE_CLOSE(distances(6, 0), 1.20, 1e-5);
- BOOST_REQUIRE(neighbors(7, 0) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 0), 1.35, 1e-5);
- BOOST_REQUIRE(neighbors(8, 0) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 0), 2.05, 1e-5);
- BOOST_REQUIRE(neighbors(9, 0) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 0), 5.00, 1e-5);
-
- // Neighbors of point 1.
- BOOST_REQUIRE(neighbors(0, 1) == 8);
- BOOST_REQUIRE_CLOSE(distances(0, 1), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(1, 1) == 2);
- BOOST_REQUIRE_CLOSE(distances(1, 1), 0.20, 1e-5);
- BOOST_REQUIRE(neighbors(2, 1) == 0);
- BOOST_REQUIRE_CLOSE(distances(2, 1), 0.30, 1e-5);
- BOOST_REQUIRE(neighbors(3, 1) == 9);
- BOOST_REQUIRE_CLOSE(distances(3, 1), 0.55, 1e-5);
- BOOST_REQUIRE(neighbors(4, 1) == 5);
- BOOST_REQUIRE_CLOSE(distances(4, 1), 0.57, 1e-5);
- BOOST_REQUIRE(neighbors(5, 1) == 10);
- BOOST_REQUIRE_CLOSE(distances(5, 1), 0.65, 1e-5);
- BOOST_REQUIRE(neighbors(6, 1) == 3);
- BOOST_REQUIRE_CLOSE(distances(6, 1), 0.90, 1e-5);
- BOOST_REQUIRE(neighbors(7, 1) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 1), 1.65, 1e-5);
- BOOST_REQUIRE(neighbors(8, 1) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 1), 2.35, 1e-5);
- BOOST_REQUIRE(neighbors(9, 1) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 1), 4.70, 1e-5);
-
- // Neighbors of point 2.
- BOOST_REQUIRE(neighbors(0, 2) == 0);
- BOOST_REQUIRE_CLOSE(distances(0, 2), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(1, 2) == 1);
- BOOST_REQUIRE_CLOSE(distances(1, 2), 0.20, 1e-5);
- BOOST_REQUIRE(neighbors(2, 2) == 8);
- BOOST_REQUIRE_CLOSE(distances(2, 2), 0.30, 1e-5);
- BOOST_REQUIRE(neighbors(3, 2) == 5);
- BOOST_REQUIRE_CLOSE(distances(3, 2), 0.37, 1e-5);
- BOOST_REQUIRE(neighbors(4, 2) == 9);
- BOOST_REQUIRE_CLOSE(distances(4, 2), 0.75, 1e-5);
- BOOST_REQUIRE(neighbors(5, 2) == 10);
- BOOST_REQUIRE_CLOSE(distances(5, 2), 0.85, 1e-5);
- BOOST_REQUIRE(neighbors(6, 2) == 3);
- BOOST_REQUIRE_CLOSE(distances(6, 2), 1.10, 1e-5);
- BOOST_REQUIRE(neighbors(7, 2) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 2), 1.45, 1e-5);
- BOOST_REQUIRE(neighbors(8, 2) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 2), 2.15, 1e-5);
- BOOST_REQUIRE(neighbors(9, 2) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 2), 4.90, 1e-5);
-
- // Neighbors of point 3.
- BOOST_REQUIRE(neighbors(0, 3) == 10);
- BOOST_REQUIRE_CLOSE(distances(0, 3), 0.25, 1e-5);
- BOOST_REQUIRE(neighbors(1, 3) == 9);
- BOOST_REQUIRE_CLOSE(distances(1, 3), 0.35, 1e-5);
- BOOST_REQUIRE(neighbors(2, 3) == 8);
- BOOST_REQUIRE_CLOSE(distances(2, 3), 0.80, 1e-5);
- BOOST_REQUIRE(neighbors(3, 3) == 1);
- BOOST_REQUIRE_CLOSE(distances(3, 3), 0.90, 1e-5);
- BOOST_REQUIRE(neighbors(4, 3) == 2);
- BOOST_REQUIRE_CLOSE(distances(4, 3), 1.10, 1e-5);
- BOOST_REQUIRE(neighbors(5, 3) == 0);
- BOOST_REQUIRE_CLOSE(distances(5, 3), 1.20, 1e-5);
- BOOST_REQUIRE(neighbors(6, 3) == 5);
- BOOST_REQUIRE_CLOSE(distances(6, 3), 1.47, 1e-5);
- BOOST_REQUIRE(neighbors(7, 3) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 3), 2.55, 1e-5);
- BOOST_REQUIRE(neighbors(8, 3) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 3), 3.25, 1e-5);
- BOOST_REQUIRE(neighbors(9, 3) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 3), 3.80, 1e-5);
-
- // Neighbors of point 4.
- BOOST_REQUIRE(neighbors(0, 4) == 3);
- BOOST_REQUIRE_CLOSE(distances(0, 4), 3.80, 1e-5);
- BOOST_REQUIRE(neighbors(1, 4) == 10);
- BOOST_REQUIRE_CLOSE(distances(1, 4), 4.05, 1e-5);
- BOOST_REQUIRE(neighbors(2, 4) == 9);
- BOOST_REQUIRE_CLOSE(distances(2, 4), 4.15, 1e-5);
- BOOST_REQUIRE(neighbors(3, 4) == 8);
- BOOST_REQUIRE_CLOSE(distances(3, 4), 4.60, 1e-5);
- BOOST_REQUIRE(neighbors(4, 4) == 1);
- BOOST_REQUIRE_CLOSE(distances(4, 4), 4.70, 1e-5);
- BOOST_REQUIRE(neighbors(5, 4) == 2);
- BOOST_REQUIRE_CLOSE(distances(5, 4), 4.90, 1e-5);
- BOOST_REQUIRE(neighbors(6, 4) == 0);
- BOOST_REQUIRE_CLOSE(distances(6, 4), 5.00, 1e-5);
- BOOST_REQUIRE(neighbors(7, 4) == 5);
- BOOST_REQUIRE_CLOSE(distances(7, 4), 5.27, 1e-5);
- BOOST_REQUIRE(neighbors(8, 4) == 7);
- BOOST_REQUIRE_CLOSE(distances(8, 4), 6.35, 1e-5);
- BOOST_REQUIRE(neighbors(9, 4) == 6);
- BOOST_REQUIRE_CLOSE(distances(9, 4), 7.05, 1e-5);
-
- // Neighbors of point 5.
- BOOST_REQUIRE(neighbors(0, 5) == 0);
- BOOST_REQUIRE_CLOSE(distances(0, 5), 0.27, 1e-5);
- BOOST_REQUIRE(neighbors(1, 5) == 2);
- BOOST_REQUIRE_CLOSE(distances(1, 5), 0.37, 1e-5);
- BOOST_REQUIRE(neighbors(2, 5) == 1);
- BOOST_REQUIRE_CLOSE(distances(2, 5), 0.57, 1e-5);
- BOOST_REQUIRE(neighbors(3, 5) == 8);
- BOOST_REQUIRE_CLOSE(distances(3, 5), 0.67, 1e-5);
- BOOST_REQUIRE(neighbors(4, 5) == 7);
- BOOST_REQUIRE_CLOSE(distances(4, 5), 1.08, 1e-5);
- BOOST_REQUIRE(neighbors(5, 5) == 9);
- BOOST_REQUIRE_CLOSE(distances(5, 5), 1.12, 1e-5);
- BOOST_REQUIRE(neighbors(6, 5) == 10);
- BOOST_REQUIRE_CLOSE(distances(6, 5), 1.22, 1e-5);
- BOOST_REQUIRE(neighbors(7, 5) == 3);
- BOOST_REQUIRE_CLOSE(distances(7, 5), 1.47, 1e-5);
- BOOST_REQUIRE(neighbors(8, 5) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 5), 1.78, 1e-5);
- BOOST_REQUIRE(neighbors(9, 5) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 5), 5.27, 1e-5);
-
- // Neighbors of point 6.
- BOOST_REQUIRE(neighbors(0, 6) == 7);
- BOOST_REQUIRE_CLOSE(distances(0, 6), 0.70, 1e-5);
- BOOST_REQUIRE(neighbors(1, 6) == 5);
- BOOST_REQUIRE_CLOSE(distances(1, 6), 1.78, 1e-5);
- BOOST_REQUIRE(neighbors(2, 6) == 0);
- BOOST_REQUIRE_CLOSE(distances(2, 6), 2.05, 1e-5);
- BOOST_REQUIRE(neighbors(3, 6) == 2);
- BOOST_REQUIRE_CLOSE(distances(3, 6), 2.15, 1e-5);
- BOOST_REQUIRE(neighbors(4, 6) == 1);
- BOOST_REQUIRE_CLOSE(distances(4, 6), 2.35, 1e-5);
- BOOST_REQUIRE(neighbors(5, 6) == 8);
- BOOST_REQUIRE_CLOSE(distances(5, 6), 2.45, 1e-5);
- BOOST_REQUIRE(neighbors(6, 6) == 9);
- BOOST_REQUIRE_CLOSE(distances(6, 6), 2.90, 1e-5);
- BOOST_REQUIRE(neighbors(7, 6) == 10);
- BOOST_REQUIRE_CLOSE(distances(7, 6), 3.00, 1e-5);
- BOOST_REQUIRE(neighbors(8, 6) == 3);
- BOOST_REQUIRE_CLOSE(distances(8, 6), 3.25, 1e-5);
- BOOST_REQUIRE(neighbors(9, 6) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 6), 7.05, 1e-5);
-
- // Neighbors of point 7.
- BOOST_REQUIRE(neighbors(0, 7) == 6);
- BOOST_REQUIRE_CLOSE(distances(0, 7), 0.70, 1e-5);
- BOOST_REQUIRE(neighbors(1, 7) == 5);
- BOOST_REQUIRE_CLOSE(distances(1, 7), 1.08, 1e-5);
- BOOST_REQUIRE(neighbors(2, 7) == 0);
- BOOST_REQUIRE_CLOSE(distances(2, 7), 1.35, 1e-5);
- BOOST_REQUIRE(neighbors(3, 7) == 2);
- BOOST_REQUIRE_CLOSE(distances(3, 7), 1.45, 1e-5);
- BOOST_REQUIRE(neighbors(4, 7) == 1);
- BOOST_REQUIRE_CLOSE(distances(4, 7), 1.65, 1e-5);
- BOOST_REQUIRE(neighbors(5, 7) == 8);
- BOOST_REQUIRE_CLOSE(distances(5, 7), 1.75, 1e-5);
- BOOST_REQUIRE(neighbors(6, 7) == 9);
- BOOST_REQUIRE_CLOSE(distances(6, 7), 2.20, 1e-5);
- BOOST_REQUIRE(neighbors(7, 7) == 10);
- BOOST_REQUIRE_CLOSE(distances(7, 7), 2.30, 1e-5);
- BOOST_REQUIRE(neighbors(8, 7) == 3);
- BOOST_REQUIRE_CLOSE(distances(8, 7), 2.55, 1e-5);
- BOOST_REQUIRE(neighbors(9, 7) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 7), 6.35, 1e-5);
-
- // Neighbors of point 8.
- BOOST_REQUIRE(neighbors(0, 8) == 1);
- BOOST_REQUIRE_CLOSE(distances(0, 8), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(1, 8) == 2);
- BOOST_REQUIRE_CLOSE(distances(1, 8), 0.30, 1e-5);
- BOOST_REQUIRE(neighbors(2, 8) == 0);
- BOOST_REQUIRE_CLOSE(distances(2, 8), 0.40, 1e-5);
- BOOST_REQUIRE(neighbors(3, 8) == 9);
- BOOST_REQUIRE_CLOSE(distances(3, 8), 0.45, 1e-5);
- BOOST_REQUIRE(neighbors(4, 8) == 10);
- BOOST_REQUIRE_CLOSE(distances(4, 8), 0.55, 1e-5);
- BOOST_REQUIRE(neighbors(5, 8) == 5);
- BOOST_REQUIRE_CLOSE(distances(5, 8), 0.67, 1e-5);
- BOOST_REQUIRE(neighbors(6, 8) == 3);
- BOOST_REQUIRE_CLOSE(distances(6, 8), 0.80, 1e-5);
- BOOST_REQUIRE(neighbors(7, 8) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 8), 1.75, 1e-5);
- BOOST_REQUIRE(neighbors(8, 8) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 8), 2.45, 1e-5);
- BOOST_REQUIRE(neighbors(9, 8) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 8), 4.60, 1e-5);
-
- // Neighbors of point 9.
- BOOST_REQUIRE(neighbors(0, 9) == 10);
- BOOST_REQUIRE_CLOSE(distances(0, 9), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(1, 9) == 3);
- BOOST_REQUIRE_CLOSE(distances(1, 9), 0.35, 1e-5);
- BOOST_REQUIRE(neighbors(2, 9) == 8);
- BOOST_REQUIRE_CLOSE(distances(2, 9), 0.45, 1e-5);
- BOOST_REQUIRE(neighbors(3, 9) == 1);
- BOOST_REQUIRE_CLOSE(distances(3, 9), 0.55, 1e-5);
- BOOST_REQUIRE(neighbors(4, 9) == 2);
- BOOST_REQUIRE_CLOSE(distances(4, 9), 0.75, 1e-5);
- BOOST_REQUIRE(neighbors(5, 9) == 0);
- BOOST_REQUIRE_CLOSE(distances(5, 9), 0.85, 1e-5);
- BOOST_REQUIRE(neighbors(6, 9) == 5);
- BOOST_REQUIRE_CLOSE(distances(6, 9), 1.12, 1e-5);
- BOOST_REQUIRE(neighbors(7, 9) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 9), 2.20, 1e-5);
- BOOST_REQUIRE(neighbors(8, 9) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 9), 2.90, 1e-5);
- BOOST_REQUIRE(neighbors(9, 9) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 9), 4.15, 1e-5);
-
- // Neighbors of point 10.
- BOOST_REQUIRE(neighbors(0, 10) == 9);
- BOOST_REQUIRE_CLOSE(distances(0, 10), 0.10, 1e-5);
- BOOST_REQUIRE(neighbors(1, 10) == 3);
- BOOST_REQUIRE_CLOSE(distances(1, 10), 0.25, 1e-5);
- BOOST_REQUIRE(neighbors(2, 10) == 8);
- BOOST_REQUIRE_CLOSE(distances(2, 10), 0.55, 1e-5);
- BOOST_REQUIRE(neighbors(3, 10) == 1);
- BOOST_REQUIRE_CLOSE(distances(3, 10), 0.65, 1e-5);
- BOOST_REQUIRE(neighbors(4, 10) == 2);
- BOOST_REQUIRE_CLOSE(distances(4, 10), 0.85, 1e-5);
- BOOST_REQUIRE(neighbors(5, 10) == 0);
- BOOST_REQUIRE_CLOSE(distances(5, 10), 0.95, 1e-5);
- BOOST_REQUIRE(neighbors(6, 10) == 5);
- BOOST_REQUIRE_CLOSE(distances(6, 10), 1.22, 1e-5);
- BOOST_REQUIRE(neighbors(7, 10) == 7);
- BOOST_REQUIRE_CLOSE(distances(7, 10), 2.30, 1e-5);
- BOOST_REQUIRE(neighbors(8, 10) == 6);
- BOOST_REQUIRE_CLOSE(distances(8, 10), 3.00, 1e-5);
- BOOST_REQUIRE(neighbors(9, 10) == 4);
- BOOST_REQUIRE_CLOSE(distances(9, 10), 4.05, 1e-5);
-
- // Clean the memory.
- delete allknn;
- }
-}
-
-/**
- * Test the dual-tree nearest-neighbors method with the naive method. This
- * uses both a query and reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with.
- arma::mat dualQuery(dataForTree);
- arma::mat dualReferences(dataForTree);
- arma::mat naiveQuery(dataForTree);
- arma::mat naiveReferences(dataForTree);
-
- AllkNN allknn(dualQuery, dualReferences);
-
- AllkNN naive(naiveQuery, naiveReferences, true);
-
- arma::Mat<size_t> resultingNeighborsTree;
- arma::mat distancesTree;
- allknn.Search(15, resultingNeighborsTree, distancesTree);
-
- arma::Mat<size_t> resultingNeighborsNaive;
- arma::mat distancesNaive;
- naive.Search(15, resultingNeighborsNaive, distancesNaive);
-
- for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
- {
- BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
- BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
- }
-}
-
-/**
- * Test the dual-tree nearest-neighbors method with the naive method. This uses
- * only a reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- // Code duplication: also bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
- arma::mat dualQuery(dataForTree);
- arma::mat naiveQuery(dataForTree);
-
- AllkNN allknn(dualQuery);
-
- // Set naive mode.
- AllkNN naive(naiveQuery, true);
-
- arma::Mat<size_t> resultingNeighborsTree;
- arma::mat distancesTree;
- allknn.Search(15, resultingNeighborsTree, distancesTree);
-
- arma::Mat<size_t> resultingNeighborsNaive;
- arma::mat distancesNaive;
- naive.Search(15, resultingNeighborsNaive, distancesNaive);
-
- for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
- {
- BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
- BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
- }
-}
-
-/**
- * Test the single-tree nearest-neighbors method with the naive method. This
- * uses only a reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- // Code duplication: also bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
- arma::mat singleQuery(dataForTree);
- arma::mat naiveQuery(dataForTree);
-
- AllkNN allknn(singleQuery, false, true);
-
- // Set up computation for naive mode.
- AllkNN naive(naiveQuery, true);
-
- arma::Mat<size_t> resultingNeighborsTree;
- arma::mat distancesTree;
- allknn.Search(15, resultingNeighborsTree, distancesTree);
-
- arma::Mat<size_t> resultingNeighborsNaive;
- arma::mat distancesNaive;
- naive.Search(15, resultingNeighborsNaive, distancesNaive);
-
- for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
- {
- BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
- BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
- }
-}
-
-/**
- * Test the cover tree single-tree nearest-neighbors method against the naive
- * method. This uses only a random reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
-{
- arma::mat data;
- data.randu(75, 1000); // 75 dimensional, 1000 points.
-
- arma::mat naiveQuery(data); // For naive AllkNN.
-
- tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > tree = tree::CoverTree<
- metric::LMetric<2>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> >(data);
-
- NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
- tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >
- coverTreeSearch(&tree, data, true);
-
- AllkNN naive(naiveQuery, true);
-
- arma::Mat<size_t> coverTreeNeighbors;
- arma::mat coverTreeDistances;
- coverTreeSearch.Search(15, coverTreeNeighbors, coverTreeDistances);
-
- arma::Mat<size_t> naiveNeighbors;
- arma::mat naiveDistances;
- naive.Search(15, naiveNeighbors, naiveDistances);
-
- for (size_t i = 0; i < coverTreeNeighbors.n_elem; ++i)
- {
- BOOST_REQUIRE_EQUAL(coverTreeNeighbors[i], naiveNeighbors[i]);
- BOOST_REQUIRE_CLOSE(coverTreeDistances[i], naiveDistances[i], 1e-5);
- }
-}
-
-/**
- * Test the cover tree dual-tree nearest neighbors method against the naive
- * method.
- */
-BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
-{
- arma::mat dataset;
-// srand(time(NULL));
-// dataset.randn(5, 5000);
- data::Load("test_data_3_1000.csv", dataset);
-
- arma::mat kdtreeData(dataset);
-
- AllkNN tree(kdtreeData);
-
- arma::Mat<size_t> kdNeighbors;
- arma::mat kdDistances;
- tree.Search(5, kdNeighbors, kdDistances);
-
- tree::CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > referenceTree = tree::CoverTree<
- metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> >(dataset);
-
- NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- tree::CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
- QueryStat<NearestNeighborSort> > >
- coverTreeSearch(&referenceTree, dataset);
-
- arma::Mat<size_t> coverNeighbors;
- arma::mat coverDistances;
- coverTreeSearch.Search(5, coverNeighbors, coverDistances);
-
- for (size_t i = 0; i < coverNeighbors.n_cols; ++i)
- {
- Log::Debug << "cover neighbors col " << i << "\n" <<
- trans(coverNeighbors.col(i));
- Log::Debug << "cover distances col " << i << "\n" <<
- trans(coverDistances.col(i));
- Log::Debug << "kd neighbors col " << i << "\n" <<
- trans(kdNeighbors.col(i));
- Log::Debug << "kd distances col " << i << "\n" <<
- trans(kdDistances.col(i));
- for (size_t j = 0; j < coverNeighbors.n_rows; ++j)
- {
- BOOST_REQUIRE_EQUAL(coverNeighbors(j, i), kdNeighbors(j, i));
- BOOST_REQUIRE_CLOSE(coverDistances(j, i), kdDistances(j, i), 1e-5);
- }
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allknn_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/allknn_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allknn_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allknn_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,682 @@
+/**
+ * @file allknn_test.cpp
+ *
+ * Test file for AllkNN class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/methods/neighbor_search/unmap.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+BOOST_AUTO_TEST_SUITE(AllkNNTest);
+
+/**
+ * Test that Unmap() works in the dual-tree case (see unmap.hpp).
+ */
+BOOST_AUTO_TEST_CASE(DualTreeUnmapTest)
+{
+ std::vector<size_t> refMap;
+ refMap.push_back(3);
+ refMap.push_back(4);
+ refMap.push_back(1);
+ refMap.push_back(2);
+ refMap.push_back(0);
+
+ std::vector<size_t> queryMap;
+ queryMap.push_back(2);
+ queryMap.push_back(0);
+ queryMap.push_back(4);
+ queryMap.push_back(3);
+ queryMap.push_back(1);
+ queryMap.push_back(5);
+
+ // Now generate some results. 6 queries, 5 references.
+ arma::Mat<size_t> neighbors("3 1 2 0 4;"
+ "1 0 2 3 4;"
+ "0 1 2 3 4;"
+ "4 1 0 3 2;"
+ "3 0 4 1 2;"
+ "3 0 4 1 2;");
+ neighbors = neighbors.t();
+
+ // Integer distances will work fine here.
+ arma::mat distances("3 1 2 0 4;"
+ "1 0 2 3 4;"
+ "0 1 2 3 4;"
+ "4 1 0 3 2;"
+ "3 0 4 1 2;"
+ "3 0 4 1 2;");
+ distances = distances.t();
+
+ // This is what the results should be when they are unmapped.
+ arma::Mat<size_t> correctNeighbors("4 3 1 2 0;"
+ "2 3 0 4 1;"
+ "2 4 1 3 0;"
+ "0 4 3 2 1;"
+ "3 4 1 2 0;"
+ "2 3 0 4 1;");
+ correctNeighbors = correctNeighbors.t();
+
+ arma::mat correctDistances("1 0 2 3 4;"
+ "3 0 4 1 2;"
+ "3 1 2 0 4;"
+ "4 1 0 3 2;"
+ "0 1 2 3 4;"
+ "3 0 4 1 2;");
+ correctDistances = correctDistances.t();
+
+ // Perform the unmapping.
+ arma::Mat<size_t> neighborsOut;
+ arma::mat distancesOut;
+
+ Unmap(neighbors, distances, refMap, queryMap, neighborsOut, distancesOut);
+
+ for (size_t i = 0; i < correctNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(neighborsOut[i], correctNeighbors[i]);
+ BOOST_REQUIRE_CLOSE(distancesOut[i], correctDistances[i], 1e-5);
+ }
+
+ // Now try taking the square root.
+ Unmap(neighbors, distances, refMap, queryMap, neighborsOut, distancesOut,
+ true);
+
+ for (size_t i = 0; i < correctNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(neighborsOut[i], correctNeighbors[i]);
+ BOOST_REQUIRE_CLOSE(distancesOut[i], sqrt(correctDistances[i]), 1e-5);
+ }
+}
+
+/**
+ * Check that Unmap() works in the single-tree case.
+ */
+BOOST_AUTO_TEST_CASE(SingleTreeUnmapTest)
+{
+ std::vector<size_t> refMap;
+ refMap.push_back(3);
+ refMap.push_back(4);
+ refMap.push_back(1);
+ refMap.push_back(2);
+ refMap.push_back(0);
+
+ // Now generate some results. 6 queries, 5 references.
+ arma::Mat<size_t> neighbors("3 1 2 0 4;"
+ "1 0 2 3 4;"
+ "0 1 2 3 4;"
+ "4 1 0 3 2;"
+ "3 0 4 1 2;"
+ "3 0 4 1 2;");
+ neighbors = neighbors.t();
+
+ // Integer distances will work fine here.
+ arma::mat distances("3 1 2 0 4;"
+ "1 0 2 3 4;"
+ "0 1 2 3 4;"
+ "4 1 0 3 2;"
+ "3 0 4 1 2;"
+ "3 0 4 1 2;");
+ distances = distances.t();
+
+ // This is what the results should be when they are unmapped.
+ arma::Mat<size_t> correctNeighbors("2 4 1 3 0;"
+ "4 3 1 2 0;"
+ "3 4 1 2 0;"
+ "0 4 3 2 1;"
+ "2 3 0 4 1;"
+ "2 3 0 4 1;");
+ correctNeighbors = correctNeighbors.t();
+
+ arma::mat correctDistances = distances;
+
+ // Perform the unmapping.
+ arma::Mat<size_t> neighborsOut;
+ arma::mat distancesOut;
+
+ Unmap(neighbors, distances, refMap, neighborsOut, distancesOut);
+
+ for (size_t i = 0; i < correctNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(neighborsOut[i], correctNeighbors[i]);
+ BOOST_REQUIRE_CLOSE(distancesOut[i], correctDistances[i], 1e-5);
+ }
+
+ // Now try taking the square root.
+ Unmap(neighbors, distances, refMap, neighborsOut, distancesOut, true);
+
+ for (size_t i = 0; i < correctNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(neighborsOut[i], correctNeighbors[i]);
+ BOOST_REQUIRE_CLOSE(distancesOut[i], sqrt(correctDistances[i]), 1e-5);
+ }
+}
+
+/**
+ * Simple nearest-neighbors test with small, synthetic dataset. This is an
+ * exhaustive test, which checks that each method for performing the calculation
+ * (dual-tree, single-tree, naive) produces the correct results. An
+ * eleven-point dataset and the ten nearest neighbors are taken. The dataset is
+ * in one dimension for simplicity -- the correct functionality of distance
+ * functions is not tested here.
+ */
+BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
+{
+ // Set up our data.
+ arma::mat data(1, 11);
+ data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
+ data[1] = 0.35;
+ data[2] = 0.15;
+ data[3] = 1.25;
+ data[4] = 5.05;
+ data[5] = -0.22;
+ data[6] = -2.00;
+ data[7] = -1.30;
+ data[8] = 0.45;
+ data[9] = 0.90;
+ data[10] = 1.00;
+
+ // We will loop through three times, one for each method of performing the
+ // calculation.
+ for (int i = 0; i < 3; i++)
+ {
+ AllkNN* allknn;
+ arma::mat dataMutable = data;
+ switch (i)
+ {
+ case 0: // Use the dual-tree method.
+ allknn = new AllkNN(dataMutable, false, false, 1);
+ break;
+ case 1: // Use the single-tree method.
+ allknn = new AllkNN(dataMutable, false, true, 1);
+ break;
+ case 2: // Use the naive method.
+ allknn = new AllkNN(dataMutable, true);
+ break;
+ }
+
+ // Now perform the actual calculation.
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+ allknn->Search(10, neighbors, distances);
+
+ // Now the exhaustive check for correctness. This will be long. We must
+ // also remember that the distances returned are squared distances. As a
+ // result, distance comparisons are written out as (distance * distance) for
+ // readability.
+
+ // Neighbors of point 0.
+ BOOST_REQUIRE(neighbors(0, 0) == 2);
+ BOOST_REQUIRE_CLOSE(distances(0, 0), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 0) == 5);
+ BOOST_REQUIRE_CLOSE(distances(1, 0), 0.27, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 0) == 1);
+ BOOST_REQUIRE_CLOSE(distances(2, 0), 0.30, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 0) == 8);
+ BOOST_REQUIRE_CLOSE(distances(3, 0), 0.40, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 0) == 9);
+ BOOST_REQUIRE_CLOSE(distances(4, 0), 0.85, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 0) == 10);
+ BOOST_REQUIRE_CLOSE(distances(5, 0), 0.95, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 0) == 3);
+ BOOST_REQUIRE_CLOSE(distances(6, 0), 1.20, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 0) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 0), 1.35, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 0) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 0), 2.05, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 0) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 0), 5.00, 1e-5);
+
+ // Neighbors of point 1.
+ BOOST_REQUIRE(neighbors(0, 1) == 8);
+ BOOST_REQUIRE_CLOSE(distances(0, 1), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 1) == 2);
+ BOOST_REQUIRE_CLOSE(distances(1, 1), 0.20, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 1) == 0);
+ BOOST_REQUIRE_CLOSE(distances(2, 1), 0.30, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 1) == 9);
+ BOOST_REQUIRE_CLOSE(distances(3, 1), 0.55, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 1) == 5);
+ BOOST_REQUIRE_CLOSE(distances(4, 1), 0.57, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 1) == 10);
+ BOOST_REQUIRE_CLOSE(distances(5, 1), 0.65, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 1) == 3);
+ BOOST_REQUIRE_CLOSE(distances(6, 1), 0.90, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 1) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 1), 1.65, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 1) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 1), 2.35, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 1) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 1), 4.70, 1e-5);
+
+ // Neighbors of point 2.
+ BOOST_REQUIRE(neighbors(0, 2) == 0);
+ BOOST_REQUIRE_CLOSE(distances(0, 2), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 2) == 1);
+ BOOST_REQUIRE_CLOSE(distances(1, 2), 0.20, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 2) == 8);
+ BOOST_REQUIRE_CLOSE(distances(2, 2), 0.30, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 2) == 5);
+ BOOST_REQUIRE_CLOSE(distances(3, 2), 0.37, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 2) == 9);
+ BOOST_REQUIRE_CLOSE(distances(4, 2), 0.75, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 2) == 10);
+ BOOST_REQUIRE_CLOSE(distances(5, 2), 0.85, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 2) == 3);
+ BOOST_REQUIRE_CLOSE(distances(6, 2), 1.10, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 2) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 2), 1.45, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 2) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 2), 2.15, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 2) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 2), 4.90, 1e-5);
+
+ // Neighbors of point 3.
+ BOOST_REQUIRE(neighbors(0, 3) == 10);
+ BOOST_REQUIRE_CLOSE(distances(0, 3), 0.25, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 3) == 9);
+ BOOST_REQUIRE_CLOSE(distances(1, 3), 0.35, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 3) == 8);
+ BOOST_REQUIRE_CLOSE(distances(2, 3), 0.80, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 3) == 1);
+ BOOST_REQUIRE_CLOSE(distances(3, 3), 0.90, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 3) == 2);
+ BOOST_REQUIRE_CLOSE(distances(4, 3), 1.10, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 3) == 0);
+ BOOST_REQUIRE_CLOSE(distances(5, 3), 1.20, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 3) == 5);
+ BOOST_REQUIRE_CLOSE(distances(6, 3), 1.47, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 3) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 3), 2.55, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 3) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 3), 3.25, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 3) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 3), 3.80, 1e-5);
+
+ // Neighbors of point 4.
+ BOOST_REQUIRE(neighbors(0, 4) == 3);
+ BOOST_REQUIRE_CLOSE(distances(0, 4), 3.80, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 4) == 10);
+ BOOST_REQUIRE_CLOSE(distances(1, 4), 4.05, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 4) == 9);
+ BOOST_REQUIRE_CLOSE(distances(2, 4), 4.15, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 4) == 8);
+ BOOST_REQUIRE_CLOSE(distances(3, 4), 4.60, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 4) == 1);
+ BOOST_REQUIRE_CLOSE(distances(4, 4), 4.70, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 4) == 2);
+ BOOST_REQUIRE_CLOSE(distances(5, 4), 4.90, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 4) == 0);
+ BOOST_REQUIRE_CLOSE(distances(6, 4), 5.00, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 4) == 5);
+ BOOST_REQUIRE_CLOSE(distances(7, 4), 5.27, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 4) == 7);
+ BOOST_REQUIRE_CLOSE(distances(8, 4), 6.35, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 4) == 6);
+ BOOST_REQUIRE_CLOSE(distances(9, 4), 7.05, 1e-5);
+
+ // Neighbors of point 5.
+ BOOST_REQUIRE(neighbors(0, 5) == 0);
+ BOOST_REQUIRE_CLOSE(distances(0, 5), 0.27, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 5) == 2);
+ BOOST_REQUIRE_CLOSE(distances(1, 5), 0.37, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 5) == 1);
+ BOOST_REQUIRE_CLOSE(distances(2, 5), 0.57, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 5) == 8);
+ BOOST_REQUIRE_CLOSE(distances(3, 5), 0.67, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 5) == 7);
+ BOOST_REQUIRE_CLOSE(distances(4, 5), 1.08, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 5) == 9);
+ BOOST_REQUIRE_CLOSE(distances(5, 5), 1.12, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 5) == 10);
+ BOOST_REQUIRE_CLOSE(distances(6, 5), 1.22, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 5) == 3);
+ BOOST_REQUIRE_CLOSE(distances(7, 5), 1.47, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 5) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 5), 1.78, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 5) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 5), 5.27, 1e-5);
+
+ // Neighbors of point 6.
+ BOOST_REQUIRE(neighbors(0, 6) == 7);
+ BOOST_REQUIRE_CLOSE(distances(0, 6), 0.70, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 6) == 5);
+ BOOST_REQUIRE_CLOSE(distances(1, 6), 1.78, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 6) == 0);
+ BOOST_REQUIRE_CLOSE(distances(2, 6), 2.05, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 6) == 2);
+ BOOST_REQUIRE_CLOSE(distances(3, 6), 2.15, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 6) == 1);
+ BOOST_REQUIRE_CLOSE(distances(4, 6), 2.35, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 6) == 8);
+ BOOST_REQUIRE_CLOSE(distances(5, 6), 2.45, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 6) == 9);
+ BOOST_REQUIRE_CLOSE(distances(6, 6), 2.90, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 6) == 10);
+ BOOST_REQUIRE_CLOSE(distances(7, 6), 3.00, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 6) == 3);
+ BOOST_REQUIRE_CLOSE(distances(8, 6), 3.25, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 6) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 6), 7.05, 1e-5);
+
+ // Neighbors of point 7.
+ BOOST_REQUIRE(neighbors(0, 7) == 6);
+ BOOST_REQUIRE_CLOSE(distances(0, 7), 0.70, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 7) == 5);
+ BOOST_REQUIRE_CLOSE(distances(1, 7), 1.08, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 7) == 0);
+ BOOST_REQUIRE_CLOSE(distances(2, 7), 1.35, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 7) == 2);
+ BOOST_REQUIRE_CLOSE(distances(3, 7), 1.45, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 7) == 1);
+ BOOST_REQUIRE_CLOSE(distances(4, 7), 1.65, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 7) == 8);
+ BOOST_REQUIRE_CLOSE(distances(5, 7), 1.75, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 7) == 9);
+ BOOST_REQUIRE_CLOSE(distances(6, 7), 2.20, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 7) == 10);
+ BOOST_REQUIRE_CLOSE(distances(7, 7), 2.30, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 7) == 3);
+ BOOST_REQUIRE_CLOSE(distances(8, 7), 2.55, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 7) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 7), 6.35, 1e-5);
+
+ // Neighbors of point 8.
+ BOOST_REQUIRE(neighbors(0, 8) == 1);
+ BOOST_REQUIRE_CLOSE(distances(0, 8), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 8) == 2);
+ BOOST_REQUIRE_CLOSE(distances(1, 8), 0.30, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 8) == 0);
+ BOOST_REQUIRE_CLOSE(distances(2, 8), 0.40, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 8) == 9);
+ BOOST_REQUIRE_CLOSE(distances(3, 8), 0.45, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 8) == 10);
+ BOOST_REQUIRE_CLOSE(distances(4, 8), 0.55, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 8) == 5);
+ BOOST_REQUIRE_CLOSE(distances(5, 8), 0.67, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 8) == 3);
+ BOOST_REQUIRE_CLOSE(distances(6, 8), 0.80, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 8) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 8), 1.75, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 8) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 8), 2.45, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 8) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 8), 4.60, 1e-5);
+
+ // Neighbors of point 9.
+ BOOST_REQUIRE(neighbors(0, 9) == 10);
+ BOOST_REQUIRE_CLOSE(distances(0, 9), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 9) == 3);
+ BOOST_REQUIRE_CLOSE(distances(1, 9), 0.35, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 9) == 8);
+ BOOST_REQUIRE_CLOSE(distances(2, 9), 0.45, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 9) == 1);
+ BOOST_REQUIRE_CLOSE(distances(3, 9), 0.55, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 9) == 2);
+ BOOST_REQUIRE_CLOSE(distances(4, 9), 0.75, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 9) == 0);
+ BOOST_REQUIRE_CLOSE(distances(5, 9), 0.85, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 9) == 5);
+ BOOST_REQUIRE_CLOSE(distances(6, 9), 1.12, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 9) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 9), 2.20, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 9) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 9), 2.90, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 9) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 9), 4.15, 1e-5);
+
+ // Neighbors of point 10.
+ BOOST_REQUIRE(neighbors(0, 10) == 9);
+ BOOST_REQUIRE_CLOSE(distances(0, 10), 0.10, 1e-5);
+ BOOST_REQUIRE(neighbors(1, 10) == 3);
+ BOOST_REQUIRE_CLOSE(distances(1, 10), 0.25, 1e-5);
+ BOOST_REQUIRE(neighbors(2, 10) == 8);
+ BOOST_REQUIRE_CLOSE(distances(2, 10), 0.55, 1e-5);
+ BOOST_REQUIRE(neighbors(3, 10) == 1);
+ BOOST_REQUIRE_CLOSE(distances(3, 10), 0.65, 1e-5);
+ BOOST_REQUIRE(neighbors(4, 10) == 2);
+ BOOST_REQUIRE_CLOSE(distances(4, 10), 0.85, 1e-5);
+ BOOST_REQUIRE(neighbors(5, 10) == 0);
+ BOOST_REQUIRE_CLOSE(distances(5, 10), 0.95, 1e-5);
+ BOOST_REQUIRE(neighbors(6, 10) == 5);
+ BOOST_REQUIRE_CLOSE(distances(6, 10), 1.22, 1e-5);
+ BOOST_REQUIRE(neighbors(7, 10) == 7);
+ BOOST_REQUIRE_CLOSE(distances(7, 10), 2.30, 1e-5);
+ BOOST_REQUIRE(neighbors(8, 10) == 6);
+ BOOST_REQUIRE_CLOSE(distances(8, 10), 3.00, 1e-5);
+ BOOST_REQUIRE(neighbors(9, 10) == 4);
+ BOOST_REQUIRE_CLOSE(distances(9, 10), 4.05, 1e-5);
+
+ // Clean the memory.
+ delete allknn;
+ }
+}
+
+/**
+ * Test the dual-tree nearest-neighbors method with the naive method. This
+ * uses both a query and reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with.
+ arma::mat dualQuery(dataForTree);
+ arma::mat dualReferences(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+ arma::mat naiveReferences(dataForTree);
+
+ AllkNN allknn(dualQuery, dualReferences);
+
+ AllkNN naive(naiveQuery, naiveReferences, true);
+
+ arma::Mat<size_t> resultingNeighborsTree;
+ arma::mat distancesTree;
+ allknn.Search(15, resultingNeighborsTree, distancesTree);
+
+ arma::Mat<size_t> resultingNeighborsNaive;
+ arma::mat distancesNaive;
+ naive.Search(15, resultingNeighborsNaive, distancesNaive);
+
+ for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
+ {
+ BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
+ BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
+ }
+}
+
+/**
+ * Test the dual-tree nearest-neighbors method with the naive method. This uses
+ * only a reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ // Code duplication: also bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
+ arma::mat dualQuery(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+
+ AllkNN allknn(dualQuery);
+
+ // Set naive mode.
+ AllkNN naive(naiveQuery, true);
+
+ arma::Mat<size_t> resultingNeighborsTree;
+ arma::mat distancesTree;
+ allknn.Search(15, resultingNeighborsTree, distancesTree);
+
+ arma::Mat<size_t> resultingNeighborsNaive;
+ arma::mat distancesNaive;
+ naive.Search(15, resultingNeighborsNaive, distancesNaive);
+
+ for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
+ {
+ BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
+ BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
+ }
+}
+
+/**
+ * Test the single-tree nearest-neighbors method with the naive method. This
+ * uses only a reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ // Code duplication: also bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
+ arma::mat singleQuery(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+
+ AllkNN allknn(singleQuery, false, true);
+
+ // Set up computation for naive mode.
+ AllkNN naive(naiveQuery, true);
+
+ arma::Mat<size_t> resultingNeighborsTree;
+ arma::mat distancesTree;
+ allknn.Search(15, resultingNeighborsTree, distancesTree);
+
+ arma::Mat<size_t> resultingNeighborsNaive;
+ arma::mat distancesNaive;
+ naive.Search(15, resultingNeighborsNaive, distancesNaive);
+
+ for (size_t i = 0; i < resultingNeighborsTree.n_elem; i++)
+ {
+ BOOST_REQUIRE(resultingNeighborsTree[i] == resultingNeighborsNaive[i]);
+ BOOST_REQUIRE_CLOSE(distancesTree[i], distancesNaive[i], 1e-5);
+ }
+}
+
+/**
+ * Test the cover tree single-tree nearest-neighbors method against the naive
+ * method. This uses only a random reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
+{
+ arma::mat data;
+ data.randu(75, 1000); // 75 dimensional, 1000 points.
+
+ arma::mat naiveQuery(data); // For naive AllkNN.
+
+ tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > tree = tree::CoverTree<
+ metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> >(data);
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
+ tree::CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >
+ coverTreeSearch(&tree, data, true);
+
+ AllkNN naive(naiveQuery, true);
+
+ arma::Mat<size_t> coverTreeNeighbors;
+ arma::mat coverTreeDistances;
+ coverTreeSearch.Search(15, coverTreeNeighbors, coverTreeDistances);
+
+ arma::Mat<size_t> naiveNeighbors;
+ arma::mat naiveDistances;
+ naive.Search(15, naiveNeighbors, naiveDistances);
+
+ for (size_t i = 0; i < coverTreeNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(coverTreeNeighbors[i], naiveNeighbors[i]);
+ BOOST_REQUIRE_CLOSE(coverTreeDistances[i], naiveDistances[i], 1e-5);
+ }
+}
+
+/**
+ * Test the cover tree dual-tree nearest neighbors method against the naive
+ * method.
+ */
+BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
+{
+ arma::mat dataset;
+// srand(time(NULL));
+// dataset.randn(5, 5000);
+ data::Load("test_data_3_1000.csv", dataset);
+
+ arma::mat kdtreeData(dataset);
+
+ AllkNN tree(kdtreeData);
+
+ arma::Mat<size_t> kdNeighbors;
+ arma::mat kdDistances;
+ tree.Search(5, kdNeighbors, kdDistances);
+
+ tree::CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > referenceTree = tree::CoverTree<
+ metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> >(dataset);
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ tree::CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+ QueryStat<NearestNeighborSort> > >
+ coverTreeSearch(&referenceTree, dataset);
+
+ arma::Mat<size_t> coverNeighbors;
+ arma::mat coverDistances;
+ coverTreeSearch.Search(5, coverNeighbors, coverDistances);
+
+ for (size_t i = 0; i < coverNeighbors.n_cols; ++i)
+ {
+ Log::Debug << "cover neighbors col " << i << "\n" <<
+ trans(coverNeighbors.col(i));
+ Log::Debug << "cover distances col " << i << "\n" <<
+ trans(coverDistances.col(i));
+ Log::Debug << "kd neighbors col " << i << "\n" <<
+ trans(kdNeighbors.col(i));
+ Log::Debug << "kd distances col " << i << "\n" <<
+ trans(kdDistances.col(i));
+ for (size_t j = 0; j < coverNeighbors.n_rows; ++j)
+ {
+ BOOST_REQUIRE_EQUAL(coverNeighbors(j, i), kdNeighbors(j, i));
+ BOOST_REQUIRE_CLOSE(coverDistances(j, i), kdDistances(j, i), 1e-5);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkrann_search_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/allkrann_search_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkrann_search_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,287 +0,0 @@
-/**
- * @file allkrann_search_test.cpp
- *
- * Unit tests for the 'RASearch' class and consequently the 'RASearchRules'
- * class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <time.h>
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-// So that we can test private members. This is hackish (for now).
-#define private public
-#include <mlpack/methods/rann/ra_search.hpp>
-#undef private
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-
-
-BOOST_AUTO_TEST_SUITE(AllkRANNTest);
-
-// Test AllkRANN in naive mode for exact results when the random seeds are set
-// the same. This may not be the best test; if the implementation of RANN-RS
-// gets random numbers in a different way, then this test might fail.
-BOOST_AUTO_TEST_CASE(AllkRANNNaiveSearchExact)
-{
- // First test on a small set.
- arma::mat rdata(2, 10);
- rdata << 3 << 2 << 4 << 3 << 5 << 6 << 0 << 8 << 3 << 1 << arma::endr <<
- 0 << 3 << 4 << 7 << 8 << 4 << 1 << 0 << 4 << 3 << arma::endr;
-
- arma::mat qdata(2, 3);
- qdata << 3 << 2 << 0 << arma::endr
- << 5 << 3 << 4 << arma::endr;
-
- metric::SquaredEuclideanDistance dMetric;
- double rankApproximation = 30;
- double successProb = 0.95;
-
- // Search for 1 rank-approximate nearest-neighbors in the top 30% of the point
- // (rank error of 3).
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- // Test naive rank-approximate search.
- // Predict what the actual RANN-RS result would be.
- math::RandomSeed(0);
-
- size_t numSamples = (size_t) ceil(log(1.0 / (1.0 - successProb)) /
- log(1.0 / (1.0 - (rankApproximation / 100.0))));
-
- arma::Mat<size_t> samples(qdata.n_cols, numSamples);
- for (size_t j = 0; j < qdata.n_cols; j++)
- for (size_t i = 0; i < numSamples; i++)
- samples(j, i) = (size_t) math::RandInt(10);
-
- arma::Col<size_t> rann(qdata.n_cols);
- arma::vec rannDistances(qdata.n_cols);
- rannDistances.fill(DBL_MAX);
-
- for (size_t j = 0; j < qdata.n_cols; j++)
- {
- for (size_t i = 0; i < numSamples; i++)
- {
- double dist = dMetric.Evaluate(qdata.unsafe_col(j),
- rdata.unsafe_col(samples(j, i)));
- if (dist < rannDistances[j])
- {
- rann[j] = samples(j, i);
- rannDistances[j] = dist;
- }
- }
- }
-
- // Use RANN-RS implementation.
- math::RandomSeed(0);
-
- RASearch<> naive(rdata, qdata, true);
- naive.Search(1, neighbors, distances, rankApproximation);
-
- // Things to check:
- //
- // 1. (implicitly) The minimum number of required samples for guaranteed
- // approximation.
- // 2. (implicitly) Check the samples obtained.
- // 3. Check the neighbor returned.
- for (size_t i = 0; i < qdata.n_cols; i++)
- {
- BOOST_REQUIRE(neighbors(0, i) == rann[i]);
- BOOST_REQUIRE_CLOSE(distances(0, i), rannDistances[i], 1e-5);
- }
-}
-
-// Test the correctness and guarantees of AllkRANN when in naive mode.
-BOOST_AUTO_TEST_CASE(AllkRANNNaiveGuaranteeTest)
-{
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- arma::mat refData;
- arma::mat queryData;
-
- data::Load("rann_test_r_3_900.csv", refData, true);
- data::Load("rann_test_q_3_100.csv", queryData, true);
-
- RASearch<> rsRann(refData, queryData, true);
-
- arma::mat qrRanks;
- data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
-
- size_t numRounds = 1000;
- arma::Col<size_t> numSuccessRounds(queryData.n_cols);
- numSuccessRounds.fill(0);
-
- // 1% of 900 is 9, so the rank is expected to be less than 10
- size_t expectedRankErrorUB = 10;
-
- for (size_t rounds = 0; rounds < numRounds; rounds++)
- {
- rsRann.Search(1, neighbors, distances, 1.0);
-
- for (size_t i = 0; i < queryData.n_cols; i++)
- if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
- numSuccessRounds[i]++;
-
- neighbors.reset();
- distances.reset();
- }
-
- // Find the 95%-tile threshold so that 95% of the queries should pass this
- // threshold.
- size_t threshold = floor(numRounds *
- (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
- size_t numQueriesFail = 0;
- for (size_t i = 0; i < queryData.n_cols; i++)
- if (numSuccessRounds[i] < threshold)
- numQueriesFail++;
-
- Log::Warn << "RANN-RS: RANN guarantee fails on " << numQueriesFail
- << " queries." << endl;
-
- // assert that at most 5% of the queries fall out of this threshold
- // 5% of 100 queries is 5.
- size_t maxNumQueriesFail = 6;
-
- BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
-}
-
-// Test single-tree rank-approximate search (harder to test because of
-// the randomness involved).
-BOOST_AUTO_TEST_CASE(AllkRANNSingleTreeSearch)
-{
- arma::mat refData;
- arma::mat queryData;
-
- data::Load("rann_test_r_3_900.csv", refData, true);
- data::Load("rann_test_q_3_100.csv", queryData, true);
-
- // Search for 1 rank-approximate nearest-neighbors in the top 30% of the point
- // (rank error of 3).
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- RASearch<> tssRann(refData, queryData, false, true, 5);
-
- // The relative ranks for the given query reference pair
- arma::Mat<size_t> qrRanks;
- data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
-
- size_t numRounds = 1000;
- arma::Col<size_t> numSuccessRounds(queryData.n_cols);
- numSuccessRounds.fill(0);
-
- // 1% of 900 is 9, so the rank is expected to be less than 10.
- size_t expectedRankErrorUB = 10;
-
- for (size_t rounds = 0; rounds < numRounds; rounds++)
- {
- tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
-
- for (size_t i = 0; i < queryData.n_cols; i++)
- if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
- numSuccessRounds[i]++;
-
- neighbors.reset();
- distances.reset();
- }
-
- // Find the 95%-tile threshold so that 95% of the queries should pass this
- // threshold.
- size_t threshold = floor(numRounds *
- (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
- size_t numQueriesFail = 0;
- for (size_t i = 0; i < queryData.n_cols; i++)
- if (numSuccessRounds[i] < threshold)
- numQueriesFail++;
-
- Log::Warn << "RANN-TSS: RANN guarantee fails on " << numQueriesFail
- << " queries." << endl;
-
- // Assert that at most 5% of the queries fall out of this threshold.
- // 5% of 100 queries is 5.
- size_t maxNumQueriesFail = 6;
-
- BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
-}
-
-// Test dual-tree rank-approximate search (harder to test because of the
-// randomness involved).
-BOOST_AUTO_TEST_CASE(AllkRANNDualTreeSearch)
-{
- arma::mat refData;
- arma::mat queryData;
-
- data::Load("rann_test_r_3_900.csv", refData, true);
- data::Load("rann_test_q_3_100.csv", queryData, true);
-
- // Search for 1 rank-approximate nearest-neighbors in the top 30% of the point
- // (rank error of 3).
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- RASearch<> tsdRann(refData, queryData, false, false, 5);
-
- arma::Mat<size_t> qrRanks;
- data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
-
- size_t numRounds = 1000;
- arma::Col<size_t> numSuccessRounds(queryData.n_cols);
- numSuccessRounds.fill(0);
-
- // 1% of 900 is 9, so the rank is expected to be less than 10.
- size_t expectedRankErrorUB = 10;
-
- for (size_t rounds = 0; rounds < numRounds; rounds++)
- {
- tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
-
- for (size_t i = 0; i < queryData.n_cols; i++)
- if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
- numSuccessRounds[i]++;
-
- neighbors.reset();
- distances.reset();
-
- tsdRann.ResetQueryTree();
- }
-
- // Find the 95%-tile threshold so that 95% of the queries should pass this
- // threshold.
- size_t threshold = floor(numRounds *
- (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
- size_t numQueriesFail = 0;
- for (size_t i = 0; i < queryData.n_cols; i++)
- if (numSuccessRounds[i] < threshold)
- numQueriesFail++;
-
- Log::Warn << "RANN-TSD: RANN guarantee fails on " << numQueriesFail
- << " queries." << endl;
-
- // assert that at most 5% of the queries fall out of this threshold
- // 5% of 100 queries is 5.
- size_t maxNumQueriesFail = 6;
-
- BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkrann_search_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/allkrann_search_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkrann_search_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/allkrann_search_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,287 @@
+/**
+ * @file allkrann_search_test.cpp
+ *
+ * Unit tests for the 'RASearch' class and consequently the 'RASearchRules'
+ * class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <time.h>
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+// So that we can test private members. This is hackish (for now).
+#define private public
+#include <mlpack/methods/rann/ra_search.hpp>
+#undef private
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+
+BOOST_AUTO_TEST_SUITE(AllkRANNTest);
+
+// Test AllkRANN in naive mode for exact results when the random seeds are set
+// the same. This may not be the best test; if the implementation of RANN-RS
+// gets random numbers in a different way, then this test might fail.
+BOOST_AUTO_TEST_CASE(AllkRANNNaiveSearchExact)
+{
+ // First test on a small set.
+ arma::mat rdata(2, 10);
+ rdata << 3 << 2 << 4 << 3 << 5 << 6 << 0 << 8 << 3 << 1 << arma::endr <<
+ 0 << 3 << 4 << 7 << 8 << 4 << 1 << 0 << 4 << 3 << arma::endr;
+
+ arma::mat qdata(2, 3);
+ qdata << 3 << 2 << 0 << arma::endr
+ << 5 << 3 << 4 << arma::endr;
+
+ metric::SquaredEuclideanDistance dMetric;
+ double rankApproximation = 30;
+ double successProb = 0.95;
+
+ // Search for 1 rank-approximate nearest-neighbors in the top 30% of the point
+ // (rank error of 3).
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ // Test naive rank-approximate search.
+ // Predict what the actual RANN-RS result would be.
+ math::RandomSeed(0);
+
+ size_t numSamples = (size_t) ceil(log(1.0 / (1.0 - successProb)) /
+ log(1.0 / (1.0 - (rankApproximation / 100.0))));
+
+ arma::Mat<size_t> samples(qdata.n_cols, numSamples);
+ for (size_t j = 0; j < qdata.n_cols; j++)
+ for (size_t i = 0; i < numSamples; i++)
+ samples(j, i) = (size_t) math::RandInt(10);
+
+ arma::Col<size_t> rann(qdata.n_cols);
+ arma::vec rannDistances(qdata.n_cols);
+ rannDistances.fill(DBL_MAX);
+
+ for (size_t j = 0; j < qdata.n_cols; j++)
+ {
+ for (size_t i = 0; i < numSamples; i++)
+ {
+ double dist = dMetric.Evaluate(qdata.unsafe_col(j),
+ rdata.unsafe_col(samples(j, i)));
+ if (dist < rannDistances[j])
+ {
+ rann[j] = samples(j, i);
+ rannDistances[j] = dist;
+ }
+ }
+ }
+
+ // Use RANN-RS implementation.
+ math::RandomSeed(0);
+
+ RASearch<> naive(rdata, qdata, true);
+ naive.Search(1, neighbors, distances, rankApproximation);
+
+ // Things to check:
+ //
+ // 1. (implicitly) The minimum number of required samples for guaranteed
+ // approximation.
+ // 2. (implicitly) Check the samples obtained.
+ // 3. Check the neighbor returned.
+ for (size_t i = 0; i < qdata.n_cols; i++)
+ {
+ BOOST_REQUIRE(neighbors(0, i) == rann[i]);
+ BOOST_REQUIRE_CLOSE(distances(0, i), rannDistances[i], 1e-5);
+ }
+}
+
+// Test the correctness and guarantees of AllkRANN when in naive mode.
+BOOST_AUTO_TEST_CASE(AllkRANNNaiveGuaranteeTest)
+{
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ arma::mat refData;
+ arma::mat queryData;
+
+ data::Load("rann_test_r_3_900.csv", refData, true);
+ data::Load("rann_test_q_3_100.csv", queryData, true);
+
+ RASearch<> rsRann(refData, queryData, true);
+
+ arma::mat qrRanks;
+ data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+ size_t numRounds = 1000;
+ arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+ numSuccessRounds.fill(0);
+
+ // 1% of 900 is 9, so the rank is expected to be less than 10
+ size_t expectedRankErrorUB = 10;
+
+ for (size_t rounds = 0; rounds < numRounds; rounds++)
+ {
+ rsRann.Search(1, neighbors, distances, 1.0);
+
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+ numSuccessRounds[i]++;
+
+ neighbors.reset();
+ distances.reset();
+ }
+
+ // Find the 95%-tile threshold so that 95% of the queries should pass this
+ // threshold.
+ size_t threshold = floor(numRounds *
+ (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+ size_t numQueriesFail = 0;
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (numSuccessRounds[i] < threshold)
+ numQueriesFail++;
+
+ Log::Warn << "RANN-RS: RANN guarantee fails on " << numQueriesFail
+ << " queries." << endl;
+
+ // assert that at most 5% of the queries fall out of this threshold
+ // 5% of 100 queries is 5.
+ size_t maxNumQueriesFail = 6;
+
+ BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
+}
+
+// Test single-tree rank-approximate search (harder to test because of
+// the randomness involved).
+BOOST_AUTO_TEST_CASE(AllkRANNSingleTreeSearch)
+{
+ arma::mat refData;
+ arma::mat queryData;
+
+ data::Load("rann_test_r_3_900.csv", refData, true);
+ data::Load("rann_test_q_3_100.csv", queryData, true);
+
+ // Search for 1 rank-approximate nearest-neighbors in the top 30% of the point
+ // (rank error of 3).
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ RASearch<> tssRann(refData, queryData, false, true, 5);
+
+ // The relative ranks for the given query reference pair
+ arma::Mat<size_t> qrRanks;
+ data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+ size_t numRounds = 1000;
+ arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+ numSuccessRounds.fill(0);
+
+ // 1% of 900 is 9, so the rank is expected to be less than 10.
+ size_t expectedRankErrorUB = 10;
+
+ for (size_t rounds = 0; rounds < numRounds; rounds++)
+ {
+ tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+ numSuccessRounds[i]++;
+
+ neighbors.reset();
+ distances.reset();
+ }
+
+ // Find the 95%-tile threshold so that 95% of the queries should pass this
+ // threshold.
+ size_t threshold = floor(numRounds *
+ (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+ size_t numQueriesFail = 0;
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (numSuccessRounds[i] < threshold)
+ numQueriesFail++;
+
+ Log::Warn << "RANN-TSS: RANN guarantee fails on " << numQueriesFail
+ << " queries." << endl;
+
+ // Assert that at most 5% of the queries fall out of this threshold.
+ // 5% of 100 queries is 5.
+ size_t maxNumQueriesFail = 6;
+
+ BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
+}
+
+// Test dual-tree rank-approximate search (harder to test because of the
+// randomness involved).
+BOOST_AUTO_TEST_CASE(AllkRANNDualTreeSearch)
+{
+ arma::mat refData;
+ arma::mat queryData;
+
+ data::Load("rann_test_r_3_900.csv", refData, true);
+ data::Load("rann_test_q_3_100.csv", queryData, true);
+
+ // Search for 1 rank-approximate nearest-neighbors in the top 30% of the point
+ // (rank error of 3).
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ RASearch<> tsdRann(refData, queryData, false, false, 5);
+
+ arma::Mat<size_t> qrRanks;
+ data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+ size_t numRounds = 1000;
+ arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+ numSuccessRounds.fill(0);
+
+ // 1% of 900 is 9, so the rank is expected to be less than 10.
+ size_t expectedRankErrorUB = 10;
+
+ for (size_t rounds = 0; rounds < numRounds; rounds++)
+ {
+ tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+ numSuccessRounds[i]++;
+
+ neighbors.reset();
+ distances.reset();
+
+ tsdRann.ResetQueryTree();
+ }
+
+ // Find the 95%-tile threshold so that 95% of the queries should pass this
+ // threshold.
+ size_t threshold = floor(numRounds *
+ (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+ size_t numQueriesFail = 0;
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (numSuccessRounds[i] < threshold)
+ numQueriesFail++;
+
+ Log::Warn << "RANN-TSD: RANN guarantee fails on " << numQueriesFail
+ << " queries." << endl;
+
+ // assert that at most 5% of the queries fall out of this threshold
+ // 5% of 100 queries is 5.
+ size_t maxNumQueriesFail = 6;
+
+ BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/arma_extend_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/arma_extend_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/arma_extend_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,65 +0,0 @@
-/**
- * @file arma_extend_test.cpp
- * @author Ryan Curtin
- *
- * Test of the MLPACK extensions to Armadillo.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include <mlpack/core.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-
-BOOST_AUTO_TEST_SUITE(ArmaExtendTest);
-
-/**
- * Make sure we can reshape a matrix in-place without changing anything.
- */
-BOOST_AUTO_TEST_CASE(InplaceReshapeColumnTest)
-{
- arma::mat X;
- X.randu(1, 10);
- arma::mat oldX = X;
-
- arma::inplace_reshape(X, 2, 5);
-
- BOOST_REQUIRE_EQUAL(X.n_rows, 2);
- BOOST_REQUIRE_EQUAL(X.n_cols, 5);
- for (size_t i = 0; i < 10; ++i)
- BOOST_REQUIRE_CLOSE(X[i], oldX[i], 1e-5); // Order should be preserved.
-}
-
-/**
- * Make sure we can reshape a large matrix.
- */
-BOOST_AUTO_TEST_CASE(InplaceReshapeMatrixTest)
-{
- arma::mat X;
- X.randu(8, 10);
- arma::mat oldX = X;
-
- arma::inplace_reshape(X, 10, 8);
-
- BOOST_REQUIRE_EQUAL(X.n_rows, 10);
- BOOST_REQUIRE_EQUAL(X.n_cols, 8);
- for (size_t i = 0; i < 80; ++i)
- BOOST_REQUIRE_CLOSE(X[i], oldX[i], 1e-5); // Order should be preserved.
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/arma_extend_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/arma_extend_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/arma_extend_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/arma_extend_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,65 @@
+/**
+ * @file arma_extend_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test of the MLPACK extensions to Armadillo.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <mlpack/core.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+
+BOOST_AUTO_TEST_SUITE(ArmaExtendTest);
+
+/**
+ * Make sure we can reshape a matrix in-place without changing anything.
+ */
+BOOST_AUTO_TEST_CASE(InplaceReshapeColumnTest)
+{
+ arma::mat X;
+ X.randu(1, 10);
+ arma::mat oldX = X;
+
+ arma::inplace_reshape(X, 2, 5);
+
+ BOOST_REQUIRE_EQUAL(X.n_rows, 2);
+ BOOST_REQUIRE_EQUAL(X.n_cols, 5);
+ for (size_t i = 0; i < 10; ++i)
+ BOOST_REQUIRE_CLOSE(X[i], oldX[i], 1e-5); // Order should be preserved.
+}
+
+/**
+ * Make sure we can reshape a large matrix.
+ */
+BOOST_AUTO_TEST_CASE(InplaceReshapeMatrixTest)
+{
+ arma::mat X;
+ X.randu(8, 10);
+ arma::mat oldX = X;
+
+ arma::inplace_reshape(X, 10, 8);
+
+ BOOST_REQUIRE_EQUAL(X.n_rows, 10);
+ BOOST_REQUIRE_EQUAL(X.n_cols, 8);
+ for (size_t i = 0; i < 80; ++i)
+ BOOST_REQUIRE_CLOSE(X[i], oldX[i], 1e-5); // Order should be preserved.
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/aug_lagrangian_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/aug_lagrangian_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/aug_lagrangian_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,79 +0,0 @@
-/**
- * @file aug_lagrangian_test.cpp
- * @author Ryan Curtin
- *
- * Test of the AugmentedLagrangian class using the test functions defined in
- * aug_lagrangian_test_functions.hpp.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp>
-#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::optimization;
-
-BOOST_AUTO_TEST_SUITE(AugLagrangianTest);
-
-/**
- * Tests the Augmented Lagrangian optimizer using the
- * AugmentedLagrangianTestFunction class.
- */
-BOOST_AUTO_TEST_CASE(AugLagrangianTestFunctionTest)
-{
- // The choice of 10 memory slots is arbitrary.
- AugLagrangianTestFunction f;
- AugLagrangian<AugLagrangianTestFunction> aug(f);
-
- arma::vec coords = f.GetInitialPoint();
-
- if (!aug.Optimize(coords, 0))
- BOOST_FAIL("Optimization reported failure.");
-
- double finalValue = f.Evaluate(coords);
-
- BOOST_REQUIRE_CLOSE(finalValue, 70.0, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[0], 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[1], 4.0, 1e-5);
-}
-
-/**
- * Tests the Augmented Lagrangian optimizer using the Gockenbach function.
- */
-BOOST_AUTO_TEST_CASE(GockenbachFunctionTest)
-{
- GockenbachFunction f;
- AugLagrangian<GockenbachFunction> aug(f);
-
- arma::vec coords = f.GetInitialPoint();
-
- if (!aug.Optimize(coords, 0))
- BOOST_FAIL("Optimization reported failure.");
-
- double finalValue = f.Evaluate(coords);
-
- // Higher tolerance for smaller values.
- BOOST_REQUIRE_CLOSE(finalValue, 29.633926, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[0], 0.12288178, 1e-3);
- BOOST_REQUIRE_CLOSE(coords[1], -1.10778185, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[2], 0.015099932, 1e-3);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/aug_lagrangian_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/aug_lagrangian_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/aug_lagrangian_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/aug_lagrangian_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,79 @@
+/**
+ * @file aug_lagrangian_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test of the AugmentedLagrangian class using the test functions defined in
+ * aug_lagrangian_test_functions.hpp.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp>
+#include <mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::optimization;
+
+BOOST_AUTO_TEST_SUITE(AugLagrangianTest);
+
+/**
+ * Tests the Augmented Lagrangian optimizer using the
+ * AugmentedLagrangianTestFunction class.
+ */
+BOOST_AUTO_TEST_CASE(AugLagrangianTestFunctionTest)
+{
+ // The choice of 10 memory slots is arbitrary.
+ AugLagrangianTestFunction f;
+ AugLagrangian<AugLagrangianTestFunction> aug(f);
+
+ arma::vec coords = f.GetInitialPoint();
+
+ if (!aug.Optimize(coords, 0))
+ BOOST_FAIL("Optimization reported failure.");
+
+ double finalValue = f.Evaluate(coords);
+
+ BOOST_REQUIRE_CLOSE(finalValue, 70.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[0], 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[1], 4.0, 1e-5);
+}
+
+/**
+ * Tests the Augmented Lagrangian optimizer using the Gockenbach function.
+ */
+BOOST_AUTO_TEST_CASE(GockenbachFunctionTest)
+{
+ GockenbachFunction f;
+ AugLagrangian<GockenbachFunction> aug(f);
+
+ arma::vec coords = f.GetInitialPoint();
+
+ if (!aug.Optimize(coords, 0))
+ BOOST_FAIL("Optimization reported failure.");
+
+ double finalValue = f.Evaluate(coords);
+
+ // Higher tolerance for smaller values.
+ BOOST_REQUIRE_CLOSE(finalValue, 29.633926, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[0], 0.12288178, 1e-3);
+ BOOST_REQUIRE_CLOSE(coords[1], -1.10778185, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[2], 0.015099932, 1e-3);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/cli_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/cli_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/cli_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,273 +0,0 @@
-/**
- * @file cli_test.cpp
- * @author Matthew Amidon, Ryan Curtin
- *
- * Test for the CLI input parameter system.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include <iostream>
-#include <sstream>
-#ifndef _WIN32
- #include <sys/time.h>
-#endif
-
-// For Sleep().
-#ifdef _WIN32
- #include <Windows.h>
-#endif
-
-#include <mlpack/core.hpp>
-
-#define DEFAULT_INT 42
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-#define BASH_RED "\033[0;31m"
-#define BASH_GREEN "\033[0;32m"
-#define BASH_YELLOW "\033[0;33m"
-#define BASH_CYAN "\033[0;36m"
-#define BASH_CLEAR "\033[0m"
-
-using namespace mlpack;
-using namespace mlpack::util;
-
-BOOST_AUTO_TEST_SUITE(CLITest);
-
-/**
- * Tests that CLI works as intended, namely that CLI::Add propagates
- * successfully.
- */
-BOOST_AUTO_TEST_CASE(TestCLIAdd)
-{
- // Check that the CLI::HasParam returns false if no value has been specified
- // on the commandline and ignores any programmatical assignments.
- CLI::Add<bool>("global/bool", "True or False", "alias/bool");
-
- // CLI::HasParam should return false here.
- BOOST_REQUIRE(!CLI::HasParam("global/bool"));
-
- // Check the description of our variable.
- BOOST_REQUIRE_EQUAL(CLI::GetDescription("global/bool").compare(
- std::string("True or False")) , 0);
-
- // Check that our aliasing works.
- BOOST_REQUIRE_EQUAL(CLI::HasParam("global/bool"),
- CLI::HasParam("alias/bool"));
- BOOST_REQUIRE_EQUAL(CLI::GetDescription("global/bool").compare(
- CLI::GetDescription("alias/bool")), 0);
- BOOST_REQUIRE_EQUAL(CLI::GetParam<bool>("global/bool"),
- CLI::GetParam<bool>("alias/bool"));
-}
-
-/**
- * Test the output of CLI. We will pass bogus input to a stringstream so that
- * none of it gets to the screen.
- */
-BOOST_AUTO_TEST_CASE(TestPrefixedOutStreamBasic)
-{
- std::stringstream ss;
- PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
-
- pss << "This shouldn't break anything" << std::endl;
- BOOST_REQUIRE_EQUAL(ss.str(),
- BASH_GREEN "[INFO ] " BASH_CLEAR "This shouldn't break anything\n");
-
- ss.str("");
- pss << "Test the new lines...";
- pss << "shouldn't get 'Info' here." << std::endl;
- BOOST_REQUIRE_EQUAL(ss.str(),
- BASH_GREEN "[INFO ] " BASH_CLEAR
- "Test the new lines...shouldn't get 'Info' here.\n");
-
- pss << "But now I should." << std::endl << std::endl;
- pss << "";
- BOOST_REQUIRE_EQUAL(ss.str(),
- BASH_GREEN "[INFO ] " BASH_CLEAR
- "Test the new lines...shouldn't get 'Info' here.\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR "But now I should.\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR "\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR "");
-}
-
-/**
- * Tests that the various PARAM_* macros work properly.
- */
-BOOST_AUTO_TEST_CASE(TestOption)
-{
- // This test will involve creating an option, and making sure CLI reflects
- // this.
- PARAM(int, "test_parent/test", "test desc", "", DEFAULT_INT, false);
-
- BOOST_REQUIRE_EQUAL(CLI::GetDescription("test_parent/test"), "test desc");
- BOOST_REQUIRE_EQUAL(CLI::GetParam<int>("test_parent/test"), DEFAULT_INT);
-}
-
-/**
- * Ensure that a Boolean option which we define is set correctly.
- */
-BOOST_AUTO_TEST_CASE(TestBooleanOption)
-{
- PARAM_FLAG("test_parent/flag_test", "flag test description", "");
-
- BOOST_REQUIRE_EQUAL(CLI::HasParam("test_parent/flag_test"), false);
-
- BOOST_REQUIRE_EQUAL(CLI::GetDescription("test_parent/flag_test"),
- "flag test description");
-
- // Now check that CLI reflects that it is false by default.
- BOOST_REQUIRE_EQUAL(CLI::GetParam<bool>("test_parent/flag_test"), false);
-}
-
-/**
- * Test that we can correctly output Armadillo objects to PrefixedOutStream
- * objects.
- */
-BOOST_AUTO_TEST_CASE(TestArmadilloPrefixedOutStream)
-{
- // We will test this with both a vector and a matrix.
- arma::vec test("1.0 1.5 2.0 2.5 3.0 3.5 4.0");
-
- std::stringstream ss;
- PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
-
- pss << test;
- // This should result in nothing being on the current line (since it clears
- // it).
- BOOST_REQUIRE_EQUAL(ss.str(), BASH_GREEN "[INFO ] " BASH_CLEAR " 1.0000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 1.5000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 2.0000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 2.5000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 3.0000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 3.5000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 4.0000\n");
-
- ss.str("");
- pss << trans(test);
- // This should result in there being stuff on the line.
- BOOST_REQUIRE_EQUAL(ss.str(), BASH_GREEN "[INFO ] " BASH_CLEAR
- " 1.0000 1.5000 2.0000 2.5000 3.0000 3.5000 4.0000\n");
-
- arma::mat test2("1.0 1.5 2.0; 2.5 3.0 3.5; 4.0 4.5 4.99999");
- ss.str("");
- pss << test2;
- BOOST_REQUIRE_EQUAL(ss.str(),
- BASH_GREEN "[INFO ] " BASH_CLEAR " 1.0000 1.5000 2.0000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 2.5000 3.0000 3.5000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 4.0000 4.5000 5.0000\n");
-
- // Try and throw a curveball by not clearing the line before outputting
- // something else. The PrefixedOutStream should not force Armadillo objects
- // onto their own lines.
- ss.str("");
- pss << "hello" << test2;
- BOOST_REQUIRE_EQUAL(ss.str(),
- BASH_GREEN "[INFO ] " BASH_CLEAR "hello 1.0000 1.5000 2.0000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 2.5000 3.0000 3.5000\n"
- BASH_GREEN "[INFO ] " BASH_CLEAR " 4.0000 4.5000 5.0000\n");
-}
-
-/**
- * Test that we can correctly output things in general.
- */
-BOOST_AUTO_TEST_CASE(TestPrefixedOutStream)
-{
- std::stringstream ss;
- PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
-
- pss << "hello world I am ";
- pss << 7;
-
- BOOST_REQUIRE_EQUAL(ss.str(),
- BASH_GREEN "[INFO ] " BASH_CLEAR "hello world I am 7");
-
- pss << std::endl;
- BOOST_REQUIRE_EQUAL(ss.str(),
- BASH_GREEN "[INFO ] " BASH_CLEAR "hello world I am 7\n");
-
- ss.str("");
- pss << std::endl;
- BOOST_REQUIRE_EQUAL(ss.str(),
- BASH_GREEN "[INFO ] " BASH_CLEAR "\n");
-}
-
-/**
- * Test format modifiers.
- */
-BOOST_AUTO_TEST_CASE(TestPrefixedOutStreamModifiers)
-{
- std::stringstream ss;
- PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
-
- pss << "I have a precise number which is ";
- pss << std::setw(6) << std::setfill('0') << (int)156;
-
- BOOST_REQUIRE_EQUAL(ss.str(),
- BASH_GREEN "[INFO ] " BASH_CLEAR
- "I have a precise number which is 000156");
-}
-
-/**
- * We should be able to start and then stop a timer multiple times and it should
- * save the value.
- */
-BOOST_AUTO_TEST_CASE(MultiRunTimerTest)
-{
- Timer::Start("test_timer");
-
- // On Windows (or, at least, in Windows not using VS2010) we cannot use
- // usleep() because it is not provided. Instead we will use Sleep() for a
- // number of milliseconds.
- #ifdef _WIN32
- Sleep(10);
- #else
- usleep(10000);
- #endif
-
- Timer::Stop("test_timer");
-
- BOOST_REQUIRE_GE(Timer::Get("test_timer").tv_usec, 10000);
-
- // Restart it.
- Timer::Start("test_timer");
-
- #ifdef _WIN32
- Sleep(10);
- #else
- usleep(10000);
- #endif
-
- Timer::Stop("test_timer");
-
- BOOST_REQUIRE_GE(Timer::Get("test_timer").tv_usec, 20000);
-
- // Just one more time, for good measure...
- Timer::Start("test_timer");
-
- #ifdef _WIN32
- Sleep(20);
- #else
- usleep(20000);
- #endif
-
- Timer::Stop("test_timer");
-
- BOOST_REQUIRE_GE(Timer::Get("test_timer").tv_usec, 40000);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/cli_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/cli_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/cli_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/cli_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,273 @@
+/**
+ * @file cli_test.cpp
+ * @author Matthew Amidon, Ryan Curtin
+ *
+ * Test for the CLI input parameter system.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <iostream>
+#include <sstream>
+#ifndef _WIN32
+ #include <sys/time.h>
+#endif
+
+// For Sleep().
+#ifdef _WIN32
+ #include <Windows.h>
+#endif
+
+#include <mlpack/core.hpp>
+
+#define DEFAULT_INT 42
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+#define BASH_RED "\033[0;31m"
+#define BASH_GREEN "\033[0;32m"
+#define BASH_YELLOW "\033[0;33m"
+#define BASH_CYAN "\033[0;36m"
+#define BASH_CLEAR "\033[0m"
+
+using namespace mlpack;
+using namespace mlpack::util;
+
+BOOST_AUTO_TEST_SUITE(CLITest);
+
+/**
+ * Tests that CLI works as intended, namely that CLI::Add propagates
+ * successfully.
+ */
+BOOST_AUTO_TEST_CASE(TestCLIAdd)
+{
+ // Check that the CLI::HasParam returns false if no value has been specified
+ // on the commandline and ignores any programmatical assignments.
+ CLI::Add<bool>("global/bool", "True or False", "alias/bool");
+
+ // CLI::HasParam should return false here.
+ BOOST_REQUIRE(!CLI::HasParam("global/bool"));
+
+ // Check the description of our variable.
+ BOOST_REQUIRE_EQUAL(CLI::GetDescription("global/bool").compare(
+ std::string("True or False")) , 0);
+
+ // Check that our aliasing works.
+ BOOST_REQUIRE_EQUAL(CLI::HasParam("global/bool"),
+ CLI::HasParam("alias/bool"));
+ BOOST_REQUIRE_EQUAL(CLI::GetDescription("global/bool").compare(
+ CLI::GetDescription("alias/bool")), 0);
+ BOOST_REQUIRE_EQUAL(CLI::GetParam<bool>("global/bool"),
+ CLI::GetParam<bool>("alias/bool"));
+}
+
+/**
+ * Test the output of CLI. We will pass bogus input to a stringstream so that
+ * none of it gets to the screen.
+ */
+BOOST_AUTO_TEST_CASE(TestPrefixedOutStreamBasic)
+{
+ std::stringstream ss;
+ PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
+
+ pss << "This shouldn't break anything" << std::endl;
+ BOOST_REQUIRE_EQUAL(ss.str(),
+ BASH_GREEN "[INFO ] " BASH_CLEAR "This shouldn't break anything\n");
+
+ ss.str("");
+ pss << "Test the new lines...";
+ pss << "shouldn't get 'Info' here." << std::endl;
+ BOOST_REQUIRE_EQUAL(ss.str(),
+ BASH_GREEN "[INFO ] " BASH_CLEAR
+ "Test the new lines...shouldn't get 'Info' here.\n");
+
+ pss << "But now I should." << std::endl << std::endl;
+ pss << "";
+ BOOST_REQUIRE_EQUAL(ss.str(),
+ BASH_GREEN "[INFO ] " BASH_CLEAR
+ "Test the new lines...shouldn't get 'Info' here.\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR "But now I should.\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR "\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR "");
+}
+
+/**
+ * Tests that the various PARAM_* macros work properly.
+ */
+BOOST_AUTO_TEST_CASE(TestOption)
+{
+ // This test will involve creating an option, and making sure CLI reflects
+ // this.
+ PARAM(int, "test_parent/test", "test desc", "", DEFAULT_INT, false);
+
+ BOOST_REQUIRE_EQUAL(CLI::GetDescription("test_parent/test"), "test desc");
+ BOOST_REQUIRE_EQUAL(CLI::GetParam<int>("test_parent/test"), DEFAULT_INT);
+}
+
+/**
+ * Ensure that a Boolean option which we define is set correctly.
+ */
+BOOST_AUTO_TEST_CASE(TestBooleanOption)
+{
+ PARAM_FLAG("test_parent/flag_test", "flag test description", "");
+
+ BOOST_REQUIRE_EQUAL(CLI::HasParam("test_parent/flag_test"), false);
+
+ BOOST_REQUIRE_EQUAL(CLI::GetDescription("test_parent/flag_test"),
+ "flag test description");
+
+ // Now check that CLI reflects that it is false by default.
+ BOOST_REQUIRE_EQUAL(CLI::GetParam<bool>("test_parent/flag_test"), false);
+}
+
+/**
+ * Test that we can correctly output Armadillo objects to PrefixedOutStream
+ * objects.
+ */
+BOOST_AUTO_TEST_CASE(TestArmadilloPrefixedOutStream)
+{
+ // We will test this with both a vector and a matrix.
+ arma::vec test("1.0 1.5 2.0 2.5 3.0 3.5 4.0");
+
+ std::stringstream ss;
+ PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
+
+ pss << test;
+ // This should result in nothing being on the current line (since it clears
+ // it).
+ BOOST_REQUIRE_EQUAL(ss.str(), BASH_GREEN "[INFO ] " BASH_CLEAR " 1.0000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 1.5000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 2.0000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 2.5000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 3.0000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 3.5000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 4.0000\n");
+
+ ss.str("");
+ pss << trans(test);
+ // This should result in there being stuff on the line.
+ BOOST_REQUIRE_EQUAL(ss.str(), BASH_GREEN "[INFO ] " BASH_CLEAR
+ " 1.0000 1.5000 2.0000 2.5000 3.0000 3.5000 4.0000\n");
+
+ arma::mat test2("1.0 1.5 2.0; 2.5 3.0 3.5; 4.0 4.5 4.99999");
+ ss.str("");
+ pss << test2;
+ BOOST_REQUIRE_EQUAL(ss.str(),
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 1.0000 1.5000 2.0000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 2.5000 3.0000 3.5000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 4.0000 4.5000 5.0000\n");
+
+ // Try and throw a curveball by not clearing the line before outputting
+ // something else. The PrefixedOutStream should not force Armadillo objects
+ // onto their own lines.
+ ss.str("");
+ pss << "hello" << test2;
+ BOOST_REQUIRE_EQUAL(ss.str(),
+ BASH_GREEN "[INFO ] " BASH_CLEAR "hello 1.0000 1.5000 2.0000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 2.5000 3.0000 3.5000\n"
+ BASH_GREEN "[INFO ] " BASH_CLEAR " 4.0000 4.5000 5.0000\n");
+}
+
+/**
+ * Test that we can correctly output things in general.
+ */
+BOOST_AUTO_TEST_CASE(TestPrefixedOutStream)
+{
+ std::stringstream ss;
+ PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
+
+ pss << "hello world I am ";
+ pss << 7;
+
+ BOOST_REQUIRE_EQUAL(ss.str(),
+ BASH_GREEN "[INFO ] " BASH_CLEAR "hello world I am 7");
+
+ pss << std::endl;
+ BOOST_REQUIRE_EQUAL(ss.str(),
+ BASH_GREEN "[INFO ] " BASH_CLEAR "hello world I am 7\n");
+
+ ss.str("");
+ pss << std::endl;
+ BOOST_REQUIRE_EQUAL(ss.str(),
+ BASH_GREEN "[INFO ] " BASH_CLEAR "\n");
+}
+
+/**
+ * Test format modifiers.
+ */
+BOOST_AUTO_TEST_CASE(TestPrefixedOutStreamModifiers)
+{
+ std::stringstream ss;
+ PrefixedOutStream pss(ss, BASH_GREEN "[INFO ] " BASH_CLEAR);
+
+ pss << "I have a precise number which is ";
+ pss << std::setw(6) << std::setfill('0') << (int)156;
+
+ BOOST_REQUIRE_EQUAL(ss.str(),
+ BASH_GREEN "[INFO ] " BASH_CLEAR
+ "I have a precise number which is 000156");
+}
+
+/**
+ * We should be able to start and then stop a timer multiple times and it should
+ * save the value.
+ */
+BOOST_AUTO_TEST_CASE(MultiRunTimerTest)
+{
+ Timer::Start("test_timer");
+
+ // On Windows (or, at least, in Windows not using VS2010) we cannot use
+ // usleep() because it is not provided. Instead we will use Sleep() for a
+ // number of milliseconds.
+ #ifdef _WIN32
+ Sleep(10);
+ #else
+ usleep(10000);
+ #endif
+
+ Timer::Stop("test_timer");
+
+ BOOST_REQUIRE_GE(Timer::Get("test_timer").tv_usec, 10000);
+
+ // Restart it.
+ Timer::Start("test_timer");
+
+ #ifdef _WIN32
+ Sleep(10);
+ #else
+ usleep(10000);
+ #endif
+
+ Timer::Stop("test_timer");
+
+ BOOST_REQUIRE_GE(Timer::Get("test_timer").tv_usec, 20000);
+
+ // Just one more time, for good measure...
+ Timer::Start("test_timer");
+
+ #ifdef _WIN32
+ Sleep(20);
+ #else
+ usleep(20000);
+ #endif
+
+ Timer::Stop("test_timer");
+
+ BOOST_REQUIRE_GE(Timer::Get("test_timer").tv_usec, 40000);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/det_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/det_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/det_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,363 +0,0 @@
-/**
- * @file det_test.cpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * Unit tests for the functions of the class DTree
- * and the utility functions using this class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#define protected public
-#define private public
-#include <mlpack/methods/det/dtree.hpp>
-#include <mlpack/methods/det/dt_utils.hpp>
-#undef protected
-#undef private
-
-#include <mlpack/core.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::det;
-using namespace std;
-
-BOOST_AUTO_TEST_SUITE(DETTest);
-
-// Tests for the private functions.
-
-BOOST_AUTO_TEST_CASE(TestGetMaxMinVals)
-{
- arma::mat testData(3, 5);
-
- testData << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- DTree tree(testData);
-
- BOOST_REQUIRE_EQUAL(tree.maxVals[0], 7);
- BOOST_REQUIRE_EQUAL(tree.minVals[0], 3);
- BOOST_REQUIRE_EQUAL(tree.maxVals[1], 7);
- BOOST_REQUIRE_EQUAL(tree.minVals[1], 0);
- BOOST_REQUIRE_EQUAL(tree.maxVals[2], 8);
- BOOST_REQUIRE_EQUAL(tree.minVals[2], 1);
-}
-
-BOOST_AUTO_TEST_CASE(TestComputeNodeError)
-{
- arma::vec maxVals("7 7 8");
- arma::vec minVals("3 0 1");
-
- DTree testDTree(maxVals, minVals, 5);
- double trueNodeError = -log(4.0) - log(7.0) - log(7.0);
-
- BOOST_REQUIRE_CLOSE((double) testDTree.logNegError, trueNodeError, 1e-10);
-
- testDTree.start = 3;
- testDTree.end = 5;
-
- double nodeError = testDTree.LogNegativeError(5);
- trueNodeError = 2 * log(2.0 / 5.0) - log(4.0) - log(7.0) - log(7.0);
- BOOST_REQUIRE_CLOSE(nodeError, trueNodeError, 1e-10);
-}
-
-BOOST_AUTO_TEST_CASE(TestWithinRange)
-{
- arma::vec maxVals("7 7 8");
- arma::vec minVals("3 0 1");
-
- DTree testDTree(maxVals, minVals, 5);
-
- arma::vec testQuery(3);
- testQuery << 4.5 << 2.5 << 2;
-
- BOOST_REQUIRE_EQUAL(testDTree.WithinRange(testQuery), true);
-
- testQuery << 8.5 << 2.5 << 2;
-
- BOOST_REQUIRE_EQUAL(testDTree.WithinRange(testQuery), false);
-}
-
-BOOST_AUTO_TEST_CASE(TestFindSplit)
-{
- arma::mat testData(3,5);
-
- testData << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- DTree testDTree(testData);
-
- size_t obDim, trueDim;
- double trueLeftError, obLeftError, trueRightError, obRightError,
- obSplit, trueSplit;
-
- trueDim = 2;
- trueSplit = 5.5;
- trueLeftError = 2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5));
- trueRightError = 2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
-
- testDTree.logVolume = log(7.0) + log(4.0) + log(7.0);
- BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError,
- obRightError, 2, 1));
-
- BOOST_REQUIRE(trueDim == obDim);
- BOOST_REQUIRE_CLOSE(trueSplit, obSplit, 1e-10);
-
- BOOST_REQUIRE_CLOSE(trueLeftError, obLeftError, 1e-10);
- BOOST_REQUIRE_CLOSE(trueRightError, obRightError, 1e-10);
-}
-
-BOOST_AUTO_TEST_CASE(TestSplitData)
-{
- arma::mat testData(3, 5);
-
- testData << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- DTree testDTree(testData);
-
- arma::Col<size_t> oTest(5);
- oTest << 1 << 2 << 3 << 4 << 5;
-
- size_t splitDim = 2;
- double trueSplitVal = 5.5;
-
- size_t splitInd = testDTree.SplitData(testData, splitDim, trueSplitVal,
- oTest);
-
- BOOST_REQUIRE_EQUAL(splitInd, 2); // 2 points on left side.
-
- BOOST_REQUIRE_EQUAL(oTest[0], 1);
- BOOST_REQUIRE_EQUAL(oTest[1], 4);
- BOOST_REQUIRE_EQUAL(oTest[2], 3);
- BOOST_REQUIRE_EQUAL(oTest[3], 2);
- BOOST_REQUIRE_EQUAL(oTest[4], 5);
-}
-
-// Tests for the public functions.
-
-BOOST_AUTO_TEST_CASE(TestGrow)
-{
- arma::mat testData(3, 5);
-
- testData << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- arma::Col<size_t> oTest(5);
- oTest << 0 << 1 << 2 << 3 << 4;
-
- double rootError, lError, rError, rlError, rrError;
-
- rootError = -log(4.0) - log(7.0) - log(7.0);
-
- lError = 2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5));
- rError = 2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
-
- rlError = 2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5));
- rrError = 2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5));
-
- DTree testDTree(testData);
- double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
-
- BOOST_REQUIRE_EQUAL(oTest[0], 0);
- BOOST_REQUIRE_EQUAL(oTest[1], 3);
- BOOST_REQUIRE_EQUAL(oTest[2], 1);
- BOOST_REQUIRE_EQUAL(oTest[3], 2);
- BOOST_REQUIRE_EQUAL(oTest[4], 4);
-
- // Test the structure of the tree.
- BOOST_REQUIRE(testDTree.Left()->Left() == NULL);
- BOOST_REQUIRE(testDTree.Left()->Right() == NULL);
- BOOST_REQUIRE(testDTree.Right()->Left()->Left() == NULL);
- BOOST_REQUIRE(testDTree.Right()->Left()->Right() == NULL);
- BOOST_REQUIRE(testDTree.Right()->Right()->Left() == NULL);
- BOOST_REQUIRE(testDTree.Right()->Right()->Right() == NULL);
-
- BOOST_REQUIRE(testDTree.SubtreeLeaves() == 3);
-
- BOOST_REQUIRE(testDTree.SplitDim() == 2);
- BOOST_REQUIRE_CLOSE(testDTree.SplitValue(), 5.5, 1e-5);
- BOOST_REQUIRE(testDTree.Right()->SplitDim() == 1);
- BOOST_REQUIRE_CLOSE(testDTree.Right()->SplitValue(), 0.5, 1e-5);
-
- // Test node errors for every node.
- BOOST_REQUIRE_CLOSE(testDTree.logNegError, rootError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.Left()->logNegError, lError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.Right()->logNegError, rError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.Right()->Left()->logNegError, rlError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.Right()->Right()->logNegError, rrError, 1e-10);
-
- // Test alpha.
- double rootAlpha, rAlpha;
- rootAlpha = std::log(-((std::exp(rootError) - (std::exp(lError) +
- std::exp(rlError) + std::exp(rrError))) / 2));
- rAlpha = std::log(-(std::exp(rError) - (std::exp(rlError) +
- std::exp(rrError))));
-
- BOOST_REQUIRE_CLOSE(alpha, min(rootAlpha, rAlpha), 1e-10);
-}
-
-BOOST_AUTO_TEST_CASE(TestPruneAndUpdate)
-{
- arma::mat testData(3, 5);
-
- testData << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- arma::Col<size_t> oTest(5);
- oTest << 0 << 1 << 2 << 3 << 4;
- DTree testDTree(testData);
- double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
- alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
-
- BOOST_REQUIRE_CLOSE(alpha, numeric_limits<double>::max(), 1e-10);
- BOOST_REQUIRE(testDTree.SubtreeLeaves() == 1);
-
- double rootError = -log(4.0) - log(7.0) - log(7.0);
-
- BOOST_REQUIRE_CLOSE(testDTree.LogNegError(), rootError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.SubtreeLeavesLogNegError(), rootError, 1e-10);
- BOOST_REQUIRE(testDTree.Left() == NULL);
- BOOST_REQUIRE(testDTree.Right() == NULL);
-}
-
-BOOST_AUTO_TEST_CASE(TestComputeValue)
-{
- arma::mat testData(3, 5);
-
- testData << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- arma::vec q1(3), q2(3), q3(3), q4(3);
-
- q1 << 4 << 2 << 2;
- q2 << 5 << 0.25 << 6;
- q3 << 5 << 3 << 7;
- q4 << 2 << 3 << 3;
-
- arma::Col<size_t> oTest(5);
- oTest << 0 << 1 << 2 << 3 << 4;
-
- DTree testDTree(testData);
- double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
-
- double d1 = (2.0 / 5.0) / exp(log(4.0) + log(7.0) + log(4.5));
- double d2 = (1.0 / 5.0) / exp(log(4.0) + log(0.5) + log(2.5));
- double d3 = (2.0 / 5.0) / exp(log(4.0) + log(6.5) + log(2.5));
-
- BOOST_REQUIRE_CLOSE(d1, testDTree.ComputeValue(q1), 1e-10);
- BOOST_REQUIRE_CLOSE(d2, testDTree.ComputeValue(q2), 1e-10);
- BOOST_REQUIRE_CLOSE(d3, testDTree.ComputeValue(q3), 1e-10);
- BOOST_REQUIRE_CLOSE(0.0, testDTree.ComputeValue(q4), 1e-10);
-
- alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
-
- double d = 1.0 / exp(log(4.0) + log(7.0) + log(7.0));
-
- BOOST_REQUIRE_CLOSE(d, testDTree.ComputeValue(q1), 1e-10);
- BOOST_REQUIRE_CLOSE(d, testDTree.ComputeValue(q2), 1e-10);
- BOOST_REQUIRE_CLOSE(d, testDTree.ComputeValue(q3), 1e-10);
- BOOST_REQUIRE_CLOSE(0.0, testDTree.ComputeValue(q4), 1e-10);
-}
-
-BOOST_AUTO_TEST_CASE(TestVariableImportance)
-{
- arma::mat testData(3, 5);
-
- testData << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- double rootError, lError, rError, rlError, rrError;
-
- rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
-
- lError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
- rError = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
-
- rlError = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
- rrError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
-
- arma::Col<size_t> oTest(5);
- oTest << 0 << 1 << 2 << 3 << 4;
-
- DTree testDTree(testData);
- testDTree.Grow(testData, oTest, false, 2, 1);
-
- arma::vec imps;
-
- testDTree.ComputeVariableImportance(imps);
-
- BOOST_REQUIRE_CLOSE((double) 0.0, imps[0], 1e-10);
- BOOST_REQUIRE_CLOSE((double) (rError - (rlError + rrError)), imps[1], 1e-10);
- BOOST_REQUIRE_CLOSE((double) (rootError - (lError + rError)), imps[2], 1e-10);
-}
-
-/**
- * These are not yet implemented.
- *
-BOOST_AUTO_TEST_CASE(TestTagTree)
-{
- MatType testData(3, 5);
-
- testData << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- DTree<>* testDTree = new DTree<>(&testData);
-
- delete testDTree;
-}
-
-BOOST_AUTO_TEST_CASE(TestFindBucket)
-{
- MatType testData(3, 5);
-
- testData << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- DTree<>* testDTree = new DTree<>(&testData);
-
- delete testDTree;
-}
-
-// Test functions in dt_utils.hpp
-
-BOOST_AUTO_TEST_CASE(TestTrainer)
-{
-
-}
-
-BOOST_AUTO_TEST_CASE(TestPrintVariableImportance)
-{
-
-}
-
-BOOST_AUTO_TEST_CASE(TestPrintLeafMembership)
-{
-
-}
-*/
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/det_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/det_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/det_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/det_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,363 @@
+/**
+ * @file det_test.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * Unit tests for the functions of the class DTree
+ * and the utility functions using this class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#define protected public
+#define private public
+#include <mlpack/methods/det/dtree.hpp>
+#include <mlpack/methods/det/dt_utils.hpp>
+#undef protected
+#undef private
+
+#include <mlpack/core.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::det;
+using namespace std;
+
+BOOST_AUTO_TEST_SUITE(DETTest);
+
+// Tests for the private functions.
+
+BOOST_AUTO_TEST_CASE(TestGetMaxMinVals)
+{
+ arma::mat testData(3, 5);
+
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ DTree tree(testData);
+
+ BOOST_REQUIRE_EQUAL(tree.maxVals[0], 7);
+ BOOST_REQUIRE_EQUAL(tree.minVals[0], 3);
+ BOOST_REQUIRE_EQUAL(tree.maxVals[1], 7);
+ BOOST_REQUIRE_EQUAL(tree.minVals[1], 0);
+ BOOST_REQUIRE_EQUAL(tree.maxVals[2], 8);
+ BOOST_REQUIRE_EQUAL(tree.minVals[2], 1);
+}
+
+BOOST_AUTO_TEST_CASE(TestComputeNodeError)
+{
+ arma::vec maxVals("7 7 8");
+ arma::vec minVals("3 0 1");
+
+ DTree testDTree(maxVals, minVals, 5);
+ double trueNodeError = -log(4.0) - log(7.0) - log(7.0);
+
+ BOOST_REQUIRE_CLOSE((double) testDTree.logNegError, trueNodeError, 1e-10);
+
+ testDTree.start = 3;
+ testDTree.end = 5;
+
+ double nodeError = testDTree.LogNegativeError(5);
+ trueNodeError = 2 * log(2.0 / 5.0) - log(4.0) - log(7.0) - log(7.0);
+ BOOST_REQUIRE_CLOSE(nodeError, trueNodeError, 1e-10);
+}
+
+BOOST_AUTO_TEST_CASE(TestWithinRange)
+{
+ arma::vec maxVals("7 7 8");
+ arma::vec minVals("3 0 1");
+
+ DTree testDTree(maxVals, minVals, 5);
+
+ arma::vec testQuery(3);
+ testQuery << 4.5 << 2.5 << 2;
+
+ BOOST_REQUIRE_EQUAL(testDTree.WithinRange(testQuery), true);
+
+ testQuery << 8.5 << 2.5 << 2;
+
+ BOOST_REQUIRE_EQUAL(testDTree.WithinRange(testQuery), false);
+}
+
+BOOST_AUTO_TEST_CASE(TestFindSplit)
+{
+ arma::mat testData(3,5);
+
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ DTree testDTree(testData);
+
+ size_t obDim, trueDim;
+ double trueLeftError, obLeftError, trueRightError, obRightError,
+ obSplit, trueSplit;
+
+ trueDim = 2;
+ trueSplit = 5.5;
+ trueLeftError = 2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5));
+ trueRightError = 2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
+
+ testDTree.logVolume = log(7.0) + log(4.0) + log(7.0);
+ BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError,
+ obRightError, 2, 1));
+
+ BOOST_REQUIRE(trueDim == obDim);
+ BOOST_REQUIRE_CLOSE(trueSplit, obSplit, 1e-10);
+
+ BOOST_REQUIRE_CLOSE(trueLeftError, obLeftError, 1e-10);
+ BOOST_REQUIRE_CLOSE(trueRightError, obRightError, 1e-10);
+}
+
+BOOST_AUTO_TEST_CASE(TestSplitData)
+{
+ arma::mat testData(3, 5);
+
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ DTree testDTree(testData);
+
+ arma::Col<size_t> oTest(5);
+ oTest << 1 << 2 << 3 << 4 << 5;
+
+ size_t splitDim = 2;
+ double trueSplitVal = 5.5;
+
+ size_t splitInd = testDTree.SplitData(testData, splitDim, trueSplitVal,
+ oTest);
+
+ BOOST_REQUIRE_EQUAL(splitInd, 2); // 2 points on left side.
+
+ BOOST_REQUIRE_EQUAL(oTest[0], 1);
+ BOOST_REQUIRE_EQUAL(oTest[1], 4);
+ BOOST_REQUIRE_EQUAL(oTest[2], 3);
+ BOOST_REQUIRE_EQUAL(oTest[3], 2);
+ BOOST_REQUIRE_EQUAL(oTest[4], 5);
+}
+
+// Tests for the public functions.
+
+BOOST_AUTO_TEST_CASE(TestGrow)
+{
+ arma::mat testData(3, 5);
+
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ arma::Col<size_t> oTest(5);
+ oTest << 0 << 1 << 2 << 3 << 4;
+
+ double rootError, lError, rError, rlError, rrError;
+
+ rootError = -log(4.0) - log(7.0) - log(7.0);
+
+ lError = 2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5));
+ rError = 2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
+
+ rlError = 2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5));
+ rrError = 2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5));
+
+ DTree testDTree(testData);
+ double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
+
+ BOOST_REQUIRE_EQUAL(oTest[0], 0);
+ BOOST_REQUIRE_EQUAL(oTest[1], 3);
+ BOOST_REQUIRE_EQUAL(oTest[2], 1);
+ BOOST_REQUIRE_EQUAL(oTest[3], 2);
+ BOOST_REQUIRE_EQUAL(oTest[4], 4);
+
+ // Test the structure of the tree.
+ BOOST_REQUIRE(testDTree.Left()->Left() == NULL);
+ BOOST_REQUIRE(testDTree.Left()->Right() == NULL);
+ BOOST_REQUIRE(testDTree.Right()->Left()->Left() == NULL);
+ BOOST_REQUIRE(testDTree.Right()->Left()->Right() == NULL);
+ BOOST_REQUIRE(testDTree.Right()->Right()->Left() == NULL);
+ BOOST_REQUIRE(testDTree.Right()->Right()->Right() == NULL);
+
+ BOOST_REQUIRE(testDTree.SubtreeLeaves() == 3);
+
+ BOOST_REQUIRE(testDTree.SplitDim() == 2);
+ BOOST_REQUIRE_CLOSE(testDTree.SplitValue(), 5.5, 1e-5);
+ BOOST_REQUIRE(testDTree.Right()->SplitDim() == 1);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->SplitValue(), 0.5, 1e-5);
+
+ // Test node errors for every node.
+ BOOST_REQUIRE_CLOSE(testDTree.logNegError, rootError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Left()->logNegError, lError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->logNegError, rError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->Left()->logNegError, rlError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->Right()->logNegError, rrError, 1e-10);
+
+ // Test alpha.
+ double rootAlpha, rAlpha;
+ rootAlpha = std::log(-((std::exp(rootError) - (std::exp(lError) +
+ std::exp(rlError) + std::exp(rrError))) / 2));
+ rAlpha = std::log(-(std::exp(rError) - (std::exp(rlError) +
+ std::exp(rrError))));
+
+ BOOST_REQUIRE_CLOSE(alpha, min(rootAlpha, rAlpha), 1e-10);
+}
+
+BOOST_AUTO_TEST_CASE(TestPruneAndUpdate)
+{
+ arma::mat testData(3, 5);
+
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ arma::Col<size_t> oTest(5);
+ oTest << 0 << 1 << 2 << 3 << 4;
+ DTree testDTree(testData);
+ double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
+ alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
+
+ BOOST_REQUIRE_CLOSE(alpha, numeric_limits<double>::max(), 1e-10);
+ BOOST_REQUIRE(testDTree.SubtreeLeaves() == 1);
+
+ double rootError = -log(4.0) - log(7.0) - log(7.0);
+
+ BOOST_REQUIRE_CLOSE(testDTree.LogNegError(), rootError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.SubtreeLeavesLogNegError(), rootError, 1e-10);
+ BOOST_REQUIRE(testDTree.Left() == NULL);
+ BOOST_REQUIRE(testDTree.Right() == NULL);
+}
+
+BOOST_AUTO_TEST_CASE(TestComputeValue)
+{
+ arma::mat testData(3, 5);
+
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ arma::vec q1(3), q2(3), q3(3), q4(3);
+
+ q1 << 4 << 2 << 2;
+ q2 << 5 << 0.25 << 6;
+ q3 << 5 << 3 << 7;
+ q4 << 2 << 3 << 3;
+
+ arma::Col<size_t> oTest(5);
+ oTest << 0 << 1 << 2 << 3 << 4;
+
+ DTree testDTree(testData);
+ double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
+
+ double d1 = (2.0 / 5.0) / exp(log(4.0) + log(7.0) + log(4.5));
+ double d2 = (1.0 / 5.0) / exp(log(4.0) + log(0.5) + log(2.5));
+ double d3 = (2.0 / 5.0) / exp(log(4.0) + log(6.5) + log(2.5));
+
+ BOOST_REQUIRE_CLOSE(d1, testDTree.ComputeValue(q1), 1e-10);
+ BOOST_REQUIRE_CLOSE(d2, testDTree.ComputeValue(q2), 1e-10);
+ BOOST_REQUIRE_CLOSE(d3, testDTree.ComputeValue(q3), 1e-10);
+ BOOST_REQUIRE_CLOSE(0.0, testDTree.ComputeValue(q4), 1e-10);
+
+ alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
+
+ double d = 1.0 / exp(log(4.0) + log(7.0) + log(7.0));
+
+ BOOST_REQUIRE_CLOSE(d, testDTree.ComputeValue(q1), 1e-10);
+ BOOST_REQUIRE_CLOSE(d, testDTree.ComputeValue(q2), 1e-10);
+ BOOST_REQUIRE_CLOSE(d, testDTree.ComputeValue(q3), 1e-10);
+ BOOST_REQUIRE_CLOSE(0.0, testDTree.ComputeValue(q4), 1e-10);
+}
+
+BOOST_AUTO_TEST_CASE(TestVariableImportance)
+{
+ arma::mat testData(3, 5);
+
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ double rootError, lError, rError, rlError, rrError;
+
+ rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+
+ lError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
+ rError = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
+
+ rlError = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
+ rrError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
+
+ arma::Col<size_t> oTest(5);
+ oTest << 0 << 1 << 2 << 3 << 4;
+
+ DTree testDTree(testData);
+ testDTree.Grow(testData, oTest, false, 2, 1);
+
+ arma::vec imps;
+
+ testDTree.ComputeVariableImportance(imps);
+
+ BOOST_REQUIRE_CLOSE((double) 0.0, imps[0], 1e-10);
+ BOOST_REQUIRE_CLOSE((double) (rError - (rlError + rrError)), imps[1], 1e-10);
+ BOOST_REQUIRE_CLOSE((double) (rootError - (lError + rError)), imps[2], 1e-10);
+}
+
+/**
+ * These are not yet implemented.
+ *
+BOOST_AUTO_TEST_CASE(TestTagTree)
+{
+ MatType testData(3, 5);
+
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ DTree<>* testDTree = new DTree<>(&testData);
+
+ delete testDTree;
+}
+
+BOOST_AUTO_TEST_CASE(TestFindBucket)
+{
+ MatType testData(3, 5);
+
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ DTree<>* testDTree = new DTree<>(&testData);
+
+ delete testDTree;
+}
+
+// Test functions in dt_utils.hpp
+
+BOOST_AUTO_TEST_CASE(TestTrainer)
+{
+
+}
+
+BOOST_AUTO_TEST_CASE(TestPrintVariableImportance)
+{
+
+}
+
+BOOST_AUTO_TEST_CASE(TestPrintLeafMembership)
+{
+
+}
+*/
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/distribution_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/distribution_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/distribution_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,256 +0,0 @@
-/**
- * @file distribution_test.cpp
- * @author Ryan Curtin
- *
- * Test for the mlpack::distribution::DiscreteDistribution class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::distribution;
-
-BOOST_AUTO_TEST_SUITE(DistributionTest);
-
-/**
- * Make sure we initialize correctly.
- */
-BOOST_AUTO_TEST_CASE(DiscreteDistributionConstructorTest)
-{
- DiscreteDistribution d(5);
-
- BOOST_REQUIRE_EQUAL(d.Probabilities().n_elem, 5);
- BOOST_REQUIRE_CLOSE(d.Probability("0"), 0.2, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("1"), 0.2, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("2"), 0.2, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("3"), 0.2, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("4"), 0.2, 1e-5);
-}
-
-/**
- * Make sure we get the probabilities of observations right.
- */
-BOOST_AUTO_TEST_CASE(DiscreteDistributionProbabilityTest)
-{
- DiscreteDistribution d(5);
-
- d.Probabilities() = "0.2 0.4 0.1 0.1 0.2";
-
- BOOST_REQUIRE_CLOSE(d.Probability("0"), 0.2, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("1"), 0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("2"), 0.1, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("3"), 0.1, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("4"), 0.2, 1e-5);
-}
-
-/**
- * Make sure we get random observations correct.
- */
-BOOST_AUTO_TEST_CASE(DiscreteDistributionRandomTest)
-{
- DiscreteDistribution d(3);
-
- d.Probabilities() = "0.3 0.6 0.1";
-
- arma::vec actualProb(3);
- actualProb.zeros();
-
- for (size_t i = 0; i < 10000; i++)
- actualProb((size_t) (d.Random()[0] + 0.5))++;
-
- // Normalize.
- actualProb /= accu(actualProb);
-
- // 8% tolerance, because this can be a noisy process.
- BOOST_REQUIRE_CLOSE(actualProb(0), 0.3, 8.0);
- BOOST_REQUIRE_CLOSE(actualProb(1), 0.6, 8.0);
- BOOST_REQUIRE_CLOSE(actualProb(2), 0.1, 8.0);
-}
-
-/**
- * Make sure we can estimate from observations correctly.
- */
-BOOST_AUTO_TEST_CASE(DiscreteDistributionEstimateTest)
-{
- DiscreteDistribution d(4);
-
- arma::mat obs("0 0 1 1 2 2 2 3");
-
- d.Estimate(obs);
-
- BOOST_REQUIRE_CLOSE(d.Probability("0"), 0.25, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("1"), 0.25, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("2"), 0.375, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("3"), 0.125, 1e-5);
-}
-
-/**
- * Estimate from observations with probabilities.
- */
-BOOST_AUTO_TEST_CASE(DiscreteDistributionEstimateProbTest)
-{
- DiscreteDistribution d(3);
-
- arma::mat obs("0 0 1 2");
-
- arma::vec prob("0.25 0.25 0.5 1.0");
-
- d.Estimate(obs, prob);
-
- BOOST_REQUIRE_CLOSE(d.Probability("0"), 0.25, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("1"), 0.25, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("2"), 0.5, 1e-5);
-}
-
-/**
- * Make sure Gaussian distributions are initialized correctly.
- */
-BOOST_AUTO_TEST_CASE(GaussianDistributionEmptyConstructor)
-{
- GaussianDistribution d;
-
- BOOST_REQUIRE_EQUAL(d.Mean().n_elem, 0);
- BOOST_REQUIRE_EQUAL(d.Covariance().n_elem, 0);
-}
-
-/**
- * Make sure Gaussian distributions are initialized to the correct
- * dimensionality.
- */
-BOOST_AUTO_TEST_CASE(GaussianDistributionDimensionalityConstructor)
-{
- GaussianDistribution d(4);
-
- BOOST_REQUIRE_EQUAL(d.Mean().n_elem, 4);
- BOOST_REQUIRE_EQUAL(d.Covariance().n_rows, 4);
- BOOST_REQUIRE_EQUAL(d.Covariance().n_cols, 4);
-}
-
-/**
- * Make sure Gaussian distributions are initialized correctly when we give a
- * mean and covariance.
- */
-BOOST_AUTO_TEST_CASE(GaussianDistributionDistributionConstructor)
-{
- arma::vec mean(3);
- arma::mat covariance(3, 3);
-
- mean.randu();
- covariance.randu();
-
- GaussianDistribution d(mean, covariance);
-
- for (size_t i = 0; i < 3; i++)
- BOOST_REQUIRE_CLOSE(d.Mean()[i], mean[i], 1e-5);
-
- for (size_t i = 0; i < 3; i++)
- for (size_t j = 0; j < 3; j++)
- BOOST_REQUIRE_CLOSE(d.Covariance()(i, j), covariance(i, j), 1e-5);
-}
-
-/**
- * Make sure the probability of observations is correct.
- */
-BOOST_AUTO_TEST_CASE(GaussianDistributionProbabilityTest)
-{
- arma::vec mean("5 6 3 3 2");
- arma::mat cov("6 1 1 0 2;"
- "0 7 1 0 1;"
- "1 1 4 1 1;"
- "1 0 1 7 0;"
- "2 0 1 1 6");
-
- GaussianDistribution d(mean, cov);
-
- BOOST_REQUIRE_CLOSE(d.Probability("0 1 2 3 4"), 1.02531207499358e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("3 2 3 7 8"), 1.82353695848039e-7, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("2 2 0 8 1"), 1.29759261892949e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("2 1 5 0 1"), 1.33218060268258e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("3 0 5 1 0"), 1.12120427975708e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Probability("4 0 6 1 0"), 4.57951032485297e-7, 1e-5);
-}
-
-/**
- * Make sure random observations follow the probability distribution correctly.
- */
-BOOST_AUTO_TEST_CASE(GaussianDistributionRandomTest)
-{
- arma::vec mean("1.0 2.25");
- arma::mat cov("0.85 0.60;"
- "0.60 1.45");
-
- GaussianDistribution d(mean, cov);
-
- arma::mat obs(2, 5000);
-
- for (size_t i = 0; i < 5000; i++)
- obs.col(i) = d.Random();
-
- // Now make sure that reflects the actual distribution.
- arma::vec obsMean = arma::mean(obs, 1);
- arma::mat obsCov = ccov(obs);
-
- // 10% tolerance because this can be noisy.
- BOOST_REQUIRE_CLOSE(obsMean[0], mean[0], 10.0);
- BOOST_REQUIRE_CLOSE(obsMean[1], mean[1], 10.0);
-
- BOOST_REQUIRE_CLOSE(obsCov(0, 0), cov(0, 0), 10.0);
- BOOST_REQUIRE_CLOSE(obsCov(0, 1), cov(0, 1), 10.0);
- BOOST_REQUIRE_CLOSE(obsCov(1, 0), cov(1, 0), 10.0);
- BOOST_REQUIRE_CLOSE(obsCov(1, 1), cov(1, 1), 10.0);
-}
-
-/**
- * Make sure that we can properly estimate from given observations.
- */
-BOOST_AUTO_TEST_CASE(GaussianDistributionEstimateTest)
-{
- arma::vec mean("1.0 3.0 0.0 2.5");
- arma::mat cov("3.0 0.0 1.0 4.0;"
- "0.0 2.4 0.5 0.1;"
- "1.0 0.5 6.3 0.0;"
- "4.0 0.1 0.0 9.1");
-
- // Now generate the observations.
- arma::mat observations(4, 10000);
-
- arma::mat transChol = trans(chol(cov));
- for (size_t i = 0; i < 10000; i++)
- observations.col(i) = transChol * arma::randn<arma::vec>(4) + mean;
-
- // Now estimate.
- GaussianDistribution d;
-
- // Find actual mean and covariance of data.
- arma::vec actualMean = arma::mean(observations, 1);
- arma::mat actualCov = ccov(observations);
-
- d.Estimate(observations);
-
- // Check that everything is estimated right.
- for (size_t i = 0; i < 4; i++)
- BOOST_REQUIRE_SMALL(d.Mean()[i] - actualMean[i], 1e-5);
-
- for (size_t i = 0; i < 4; i++)
- for (size_t j = 0; j < 4; j++)
- BOOST_REQUIRE_SMALL(d.Covariance()(i, j) - actualCov(i, j), 1e-5);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/distribution_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/distribution_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/distribution_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/distribution_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,256 @@
+/**
+ * @file distribution_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test for the mlpack::distribution::DiscreteDistribution class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::distribution;
+
+BOOST_AUTO_TEST_SUITE(DistributionTest);
+
+/**
+ * Make sure we initialize correctly.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteDistributionConstructorTest)
+{
+ DiscreteDistribution d(5);
+
+ BOOST_REQUIRE_EQUAL(d.Probabilities().n_elem, 5);
+ BOOST_REQUIRE_CLOSE(d.Probability("0"), 0.2, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("1"), 0.2, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("2"), 0.2, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("3"), 0.2, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("4"), 0.2, 1e-5);
+}
+
+/**
+ * Make sure we get the probabilities of observations right.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteDistributionProbabilityTest)
+{
+ DiscreteDistribution d(5);
+
+ d.Probabilities() = "0.2 0.4 0.1 0.1 0.2";
+
+ BOOST_REQUIRE_CLOSE(d.Probability("0"), 0.2, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("1"), 0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("2"), 0.1, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("3"), 0.1, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("4"), 0.2, 1e-5);
+}
+
+/**
+ * Make sure we get random observations correct.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteDistributionRandomTest)
+{
+ DiscreteDistribution d(3);
+
+ d.Probabilities() = "0.3 0.6 0.1";
+
+ arma::vec actualProb(3);
+ actualProb.zeros();
+
+ for (size_t i = 0; i < 10000; i++)
+ actualProb((size_t) (d.Random()[0] + 0.5))++;
+
+ // Normalize.
+ actualProb /= accu(actualProb);
+
+ // 8% tolerance, because this can be a noisy process.
+ BOOST_REQUIRE_CLOSE(actualProb(0), 0.3, 8.0);
+ BOOST_REQUIRE_CLOSE(actualProb(1), 0.6, 8.0);
+ BOOST_REQUIRE_CLOSE(actualProb(2), 0.1, 8.0);
+}
+
+/**
+ * Make sure we can estimate from observations correctly.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteDistributionEstimateTest)
+{
+ DiscreteDistribution d(4);
+
+ arma::mat obs("0 0 1 1 2 2 2 3");
+
+ d.Estimate(obs);
+
+ BOOST_REQUIRE_CLOSE(d.Probability("0"), 0.25, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("1"), 0.25, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("2"), 0.375, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("3"), 0.125, 1e-5);
+}
+
+/**
+ * Estimate from observations with probabilities.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteDistributionEstimateProbTest)
+{
+ DiscreteDistribution d(3);
+
+ arma::mat obs("0 0 1 2");
+
+ arma::vec prob("0.25 0.25 0.5 1.0");
+
+ d.Estimate(obs, prob);
+
+ BOOST_REQUIRE_CLOSE(d.Probability("0"), 0.25, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("1"), 0.25, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("2"), 0.5, 1e-5);
+}
+
+/**
+ * Make sure Gaussian distributions are initialized correctly.
+ */
+BOOST_AUTO_TEST_CASE(GaussianDistributionEmptyConstructor)
+{
+ GaussianDistribution d;
+
+ BOOST_REQUIRE_EQUAL(d.Mean().n_elem, 0);
+ BOOST_REQUIRE_EQUAL(d.Covariance().n_elem, 0);
+}
+
+/**
+ * Make sure Gaussian distributions are initialized to the correct
+ * dimensionality.
+ */
+BOOST_AUTO_TEST_CASE(GaussianDistributionDimensionalityConstructor)
+{
+ GaussianDistribution d(4);
+
+ BOOST_REQUIRE_EQUAL(d.Mean().n_elem, 4);
+ BOOST_REQUIRE_EQUAL(d.Covariance().n_rows, 4);
+ BOOST_REQUIRE_EQUAL(d.Covariance().n_cols, 4);
+}
+
+/**
+ * Make sure Gaussian distributions are initialized correctly when we give a
+ * mean and covariance.
+ */
+BOOST_AUTO_TEST_CASE(GaussianDistributionDistributionConstructor)
+{
+ arma::vec mean(3);
+ arma::mat covariance(3, 3);
+
+ mean.randu();
+ covariance.randu();
+
+ GaussianDistribution d(mean, covariance);
+
+ for (size_t i = 0; i < 3; i++)
+ BOOST_REQUIRE_CLOSE(d.Mean()[i], mean[i], 1e-5);
+
+ for (size_t i = 0; i < 3; i++)
+ for (size_t j = 0; j < 3; j++)
+ BOOST_REQUIRE_CLOSE(d.Covariance()(i, j), covariance(i, j), 1e-5);
+}
+
+/**
+ * Make sure the probability of observations is correct.
+ */
+BOOST_AUTO_TEST_CASE(GaussianDistributionProbabilityTest)
+{
+ arma::vec mean("5 6 3 3 2");
+ arma::mat cov("6 1 1 0 2;"
+ "0 7 1 0 1;"
+ "1 1 4 1 1;"
+ "1 0 1 7 0;"
+ "2 0 1 1 6");
+
+ GaussianDistribution d(mean, cov);
+
+ BOOST_REQUIRE_CLOSE(d.Probability("0 1 2 3 4"), 1.02531207499358e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("3 2 3 7 8"), 1.82353695848039e-7, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("2 2 0 8 1"), 1.29759261892949e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("2 1 5 0 1"), 1.33218060268258e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("3 0 5 1 0"), 1.12120427975708e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Probability("4 0 6 1 0"), 4.57951032485297e-7, 1e-5);
+}
+
+/**
+ * Make sure random observations follow the probability distribution correctly.
+ */
+BOOST_AUTO_TEST_CASE(GaussianDistributionRandomTest)
+{
+ arma::vec mean("1.0 2.25");
+ arma::mat cov("0.85 0.60;"
+ "0.60 1.45");
+
+ GaussianDistribution d(mean, cov);
+
+ arma::mat obs(2, 5000);
+
+ for (size_t i = 0; i < 5000; i++)
+ obs.col(i) = d.Random();
+
+ // Now make sure that reflects the actual distribution.
+ arma::vec obsMean = arma::mean(obs, 1);
+ arma::mat obsCov = ccov(obs);
+
+ // 10% tolerance because this can be noisy.
+ BOOST_REQUIRE_CLOSE(obsMean[0], mean[0], 10.0);
+ BOOST_REQUIRE_CLOSE(obsMean[1], mean[1], 10.0);
+
+ BOOST_REQUIRE_CLOSE(obsCov(0, 0), cov(0, 0), 10.0);
+ BOOST_REQUIRE_CLOSE(obsCov(0, 1), cov(0, 1), 10.0);
+ BOOST_REQUIRE_CLOSE(obsCov(1, 0), cov(1, 0), 10.0);
+ BOOST_REQUIRE_CLOSE(obsCov(1, 1), cov(1, 1), 10.0);
+}
+
+/**
+ * Make sure that we can properly estimate from given observations.
+ */
+BOOST_AUTO_TEST_CASE(GaussianDistributionEstimateTest)
+{
+ arma::vec mean("1.0 3.0 0.0 2.5");
+ arma::mat cov("3.0 0.0 1.0 4.0;"
+ "0.0 2.4 0.5 0.1;"
+ "1.0 0.5 6.3 0.0;"
+ "4.0 0.1 0.0 9.1");
+
+ // Now generate the observations.
+ arma::mat observations(4, 10000);
+
+ arma::mat transChol = trans(chol(cov));
+ for (size_t i = 0; i < 10000; i++)
+ observations.col(i) = transChol * arma::randn<arma::vec>(4) + mean;
+
+ // Now estimate.
+ GaussianDistribution d;
+
+ // Find actual mean and covariance of data.
+ arma::vec actualMean = arma::mean(observations, 1);
+ arma::mat actualCov = ccov(observations);
+
+ d.Estimate(observations);
+
+ // Check that everything is estimated right.
+ for (size_t i = 0; i < 4; i++)
+ BOOST_REQUIRE_SMALL(d.Mean()[i] - actualMean[i], 1e-5);
+
+ for (size_t i = 0; i < 4; i++)
+ for (size_t j = 0; j < 4; j++)
+ BOOST_REQUIRE_SMALL(d.Covariance()(i, j) - actualCov(i, j), 1e-5);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/emst_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/emst_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/emst_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,143 +0,0 @@
-/**
- * @file emst_test.cpp
- *
- * Test file for EMST methods.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/emst/dtb.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::emst;
-
-BOOST_AUTO_TEST_SUITE(EMSTTest);
-
-/**
- * Simple emst test with small, synthetic dataset. This is an
- * exhaustive test, which checks that each method for performing the calculation
- * (dual-tree, single-tree, naive) produces the correct results. The dataset is
- * in one dimension for simplicity -- the correct functionality of distance
- * functions is not tested here.
- */
-BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
-{
- // Set up our data.
- arma::mat data(1, 11);
- data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
- data[1] = 0.37;
- data[2] = 0.15;
- data[3] = 1.25;
- data[4] = 5.05;
- data[5] = -0.22;
- data[6] = -2.00;
- data[7] = -1.30;
- data[8] = 0.45;
- data[9] = 0.91;
- data[10] = 1.00;
-
- // Now perform the actual calculation.
- arma::mat results;
-
- DualTreeBoruvka<> dtb(data);
- dtb.ComputeMST(results);
-
- // Now the exhaustive check for correctness.
- BOOST_REQUIRE(results(0, 0) == 1);
- BOOST_REQUIRE(results(1, 0) == 8);
- BOOST_REQUIRE_CLOSE(results(2, 0), 0.08, 1e-5);
-
- BOOST_REQUIRE(results(0, 1) == 9);
- BOOST_REQUIRE(results(1, 1) == 10);
- BOOST_REQUIRE_CLOSE(results(2, 1), 0.09, 1e-5);
-
- BOOST_REQUIRE(results(0, 2) == 0);
- BOOST_REQUIRE(results(1, 2) == 2);
- BOOST_REQUIRE_CLOSE(results(2, 2), 0.1, 1e-5);
-
- BOOST_REQUIRE(results(0, 3) == 1);
- BOOST_REQUIRE(results(1, 3) == 2);
- BOOST_REQUIRE_CLOSE(results(2, 3), 0.22, 1e-5);
-
- BOOST_REQUIRE(results(0, 4) == 3);
- BOOST_REQUIRE(results(1, 4) == 10);
- BOOST_REQUIRE_CLOSE(results(2, 4), 0.25, 1e-5);
-
- BOOST_REQUIRE(results(0, 5) == 0);
- BOOST_REQUIRE(results(1, 5) == 5);
- BOOST_REQUIRE_CLOSE(results(2, 5), 0.27, 1e-5);
-
- BOOST_REQUIRE(results(0, 6) == 8);
- BOOST_REQUIRE(results(1, 6) == 9);
- BOOST_REQUIRE_CLOSE(results(2, 6), 0.46, 1e-5);
-
- BOOST_REQUIRE(results(0, 7) == 6);
- BOOST_REQUIRE(results(1, 7) == 7);
- BOOST_REQUIRE_CLOSE(results(2, 7), 0.7, 1e-5);
-
- BOOST_REQUIRE(results(0, 8) == 5);
- BOOST_REQUIRE(results(1, 8) == 7);
- BOOST_REQUIRE_CLOSE(results(2, 8), 1.08, 1e-5);
-
- BOOST_REQUIRE(results(0, 9) == 3);
- BOOST_REQUIRE(results(1, 9) == 4);
- BOOST_REQUIRE_CLOSE(results(2, 9), 3.8, 1e-5);
-}
-
-/**
- * Test the dual tree method against the naive computation.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive)
-{
- arma::mat inputData;
-
- // Hard-coded filename: bad!
- // Code duplication: also bad!
- if (!data::Load("test_data_3_1000.csv", inputData))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
- arma::mat dualData = arma::trans(inputData);
- arma::mat naiveData = arma::trans(inputData);
-
- // Reset parameters from last test.
- DualTreeBoruvka<> dtb(dualData);
-
- arma::mat dualResults;
- dtb.ComputeMST(dualResults);
-
- // Set naive mode.
- DualTreeBoruvka<> dtbNaive(naiveData, true);
-
- arma::mat naiveResults;
- dtbNaive.ComputeMST(naiveResults);
-
- BOOST_REQUIRE(dualResults.n_cols == naiveResults.n_cols);
- BOOST_REQUIRE(dualResults.n_rows == naiveResults.n_rows);
-
- for (size_t i = 0; i < dualResults.n_cols; i++)
- {
- BOOST_REQUIRE(dualResults(0, i) == naiveResults(0, i));
- BOOST_REQUIRE(dualResults(1, i) == naiveResults(1, i));
- BOOST_REQUIRE_CLOSE(dualResults(2, i), naiveResults(2, i), 1e-5);
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/emst_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/emst_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/emst_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/emst_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,143 @@
+/**
+ * @file emst_test.cpp
+ *
+ * Test file for EMST methods.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/emst/dtb.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::emst;
+
+BOOST_AUTO_TEST_SUITE(EMSTTest);
+
+/**
+ * Simple emst test with small, synthetic dataset. This is an
+ * exhaustive test, which checks that each method for performing the calculation
+ * (dual-tree, single-tree, naive) produces the correct results. The dataset is
+ * in one dimension for simplicity -- the correct functionality of distance
+ * functions is not tested here.
+ */
+BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
+{
+ // Set up our data.
+ arma::mat data(1, 11);
+ data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
+ data[1] = 0.37;
+ data[2] = 0.15;
+ data[3] = 1.25;
+ data[4] = 5.05;
+ data[5] = -0.22;
+ data[6] = -2.00;
+ data[7] = -1.30;
+ data[8] = 0.45;
+ data[9] = 0.91;
+ data[10] = 1.00;
+
+ // Now perform the actual calculation.
+ arma::mat results;
+
+ DualTreeBoruvka<> dtb(data);
+ dtb.ComputeMST(results);
+
+ // Now the exhaustive check for correctness.
+ BOOST_REQUIRE(results(0, 0) == 1);
+ BOOST_REQUIRE(results(1, 0) == 8);
+ BOOST_REQUIRE_CLOSE(results(2, 0), 0.08, 1e-5);
+
+ BOOST_REQUIRE(results(0, 1) == 9);
+ BOOST_REQUIRE(results(1, 1) == 10);
+ BOOST_REQUIRE_CLOSE(results(2, 1), 0.09, 1e-5);
+
+ BOOST_REQUIRE(results(0, 2) == 0);
+ BOOST_REQUIRE(results(1, 2) == 2);
+ BOOST_REQUIRE_CLOSE(results(2, 2), 0.1, 1e-5);
+
+ BOOST_REQUIRE(results(0, 3) == 1);
+ BOOST_REQUIRE(results(1, 3) == 2);
+ BOOST_REQUIRE_CLOSE(results(2, 3), 0.22, 1e-5);
+
+ BOOST_REQUIRE(results(0, 4) == 3);
+ BOOST_REQUIRE(results(1, 4) == 10);
+ BOOST_REQUIRE_CLOSE(results(2, 4), 0.25, 1e-5);
+
+ BOOST_REQUIRE(results(0, 5) == 0);
+ BOOST_REQUIRE(results(1, 5) == 5);
+ BOOST_REQUIRE_CLOSE(results(2, 5), 0.27, 1e-5);
+
+ BOOST_REQUIRE(results(0, 6) == 8);
+ BOOST_REQUIRE(results(1, 6) == 9);
+ BOOST_REQUIRE_CLOSE(results(2, 6), 0.46, 1e-5);
+
+ BOOST_REQUIRE(results(0, 7) == 6);
+ BOOST_REQUIRE(results(1, 7) == 7);
+ BOOST_REQUIRE_CLOSE(results(2, 7), 0.7, 1e-5);
+
+ BOOST_REQUIRE(results(0, 8) == 5);
+ BOOST_REQUIRE(results(1, 8) == 7);
+ BOOST_REQUIRE_CLOSE(results(2, 8), 1.08, 1e-5);
+
+ BOOST_REQUIRE(results(0, 9) == 3);
+ BOOST_REQUIRE(results(1, 9) == 4);
+ BOOST_REQUIRE_CLOSE(results(2, 9), 3.8, 1e-5);
+}
+
+/**
+ * Test the dual tree method against the naive computation.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive)
+{
+ arma::mat inputData;
+
+ // Hard-coded filename: bad!
+ // Code duplication: also bad!
+ if (!data::Load("test_data_3_1000.csv", inputData))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
+ arma::mat dualData = arma::trans(inputData);
+ arma::mat naiveData = arma::trans(inputData);
+
+ // Reset parameters from last test.
+ DualTreeBoruvka<> dtb(dualData);
+
+ arma::mat dualResults;
+ dtb.ComputeMST(dualResults);
+
+ // Set naive mode.
+ DualTreeBoruvka<> dtbNaive(naiveData, true);
+
+ arma::mat naiveResults;
+ dtbNaive.ComputeMST(naiveResults);
+
+ BOOST_REQUIRE(dualResults.n_cols == naiveResults.n_cols);
+ BOOST_REQUIRE(dualResults.n_rows == naiveResults.n_rows);
+
+ for (size_t i = 0; i < dualResults.n_cols; i++)
+ {
+ BOOST_REQUIRE(dualResults(0, i) == naiveResults(0, i));
+ BOOST_REQUIRE(dualResults(1, i) == naiveResults(1, i));
+ BOOST_REQUIRE_CLOSE(dualResults(2, i), naiveResults(2, i), 1e-5);
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/fastmks_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/fastmks_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/fastmks_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,139 +0,0 @@
-/**
- * @file fastmks_test.cpp
- * @author Ryan Curtin
- *
- * Ensure that fast max-kernel search is correct.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/fastmks/fastmks.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
-#include <mlpack/core/kernels/polynomial_kernel.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::tree;
-using namespace mlpack::fastmks;
-using namespace mlpack::kernel;
-
-BOOST_AUTO_TEST_SUITE(FastMKSTest);
-
-/**
- * Compare single-tree and naive.
- */
-BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
-{
- // First create a random dataset.
- arma::mat data;
- data.randn(5, 1000);
- LinearKernel lk;
-
- // Now run FastMKS naively.
- FastMKS<LinearKernel> naive(data, lk, false, true);
-
- arma::Mat<size_t> naiveIndices;
- arma::mat naiveProducts;
- naive.Search(10, naiveIndices, naiveProducts);
-
- // Now run it in single-tree mode.
- FastMKS<LinearKernel> single(data, lk, true);
-
- arma::Mat<size_t> singleIndices;
- arma::mat singleProducts;
- single.Search(10, singleIndices, singleProducts);
-
- // Compare the results.
- for (size_t q = 0; q < singleIndices.n_cols; ++q)
- {
- for (size_t r = 0; r < singleIndices.n_rows; ++r)
- {
- BOOST_REQUIRE_EQUAL(singleIndices(r, q), naiveIndices(r, q));
- BOOST_REQUIRE_CLOSE(singleProducts(r, q), naiveProducts(r, q), 1e-5);
- }
- }
-}
-
-/**
- * Compare dual-tree and naive.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive)
-{
- // First create a random dataset.
- arma::mat data;
- data.randn(10, 5000);
- LinearKernel lk;
-
- // Now run FastMKS naively.
- FastMKS<LinearKernel> naive(data, lk, false, true);
-
- arma::Mat<size_t> naiveIndices;
- arma::mat naiveProducts;
- naive.Search(10, naiveIndices, naiveProducts);
-
- // Now run it in dual-tree mode.
- FastMKS<LinearKernel> tree(data, lk);
-
- arma::Mat<size_t> treeIndices;
- arma::mat treeProducts;
- tree.Search(10, treeIndices, treeProducts);
-
- for (size_t q = 0; q < treeIndices.n_cols; ++q)
- {
- for (size_t r = 0; r < treeIndices.n_rows; ++r)
- {
- BOOST_REQUIRE_EQUAL(treeIndices(r, q), naiveIndices(r, q));
- BOOST_REQUIRE_CLOSE(treeProducts(r, q), naiveProducts(r, q), 1e-5);
- }
- }
-}
-
-/**
- * Compare dual-tree and single-tree on a larger dataset.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsSingleTree)
-{
- // First create a random dataset.
- arma::mat data;
- data.randu(20, 15000);
- PolynomialKernel pk(5.0, 2.5);
-
- FastMKS<PolynomialKernel> single(data, pk, true);
-
- arma::Mat<size_t> singleIndices;
- arma::mat singleProducts;
- single.Search(10, singleIndices, singleProducts);
-
- // Now run it in dual-tree mode.
- FastMKS<PolynomialKernel> tree(data, pk);
-
- arma::Mat<size_t> treeIndices;
- arma::mat treeProducts;
- tree.Search(10, treeIndices, treeProducts);
-
- for (size_t q = 0; q < treeIndices.n_cols; ++q)
- {
- for (size_t r = 0; r < treeIndices.n_rows; ++r)
- {
- BOOST_REQUIRE_EQUAL(treeIndices(r, q), singleIndices(r, q));
- BOOST_REQUIRE_CLOSE(treeProducts(r, q), singleProducts(r, q), 1e-5);
- }
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/fastmks_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/fastmks_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/fastmks_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/fastmks_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,139 @@
+/**
+ * @file fastmks_test.cpp
+ * @author Ryan Curtin
+ *
+ * Ensure that fast max-kernel search is correct.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/fastmks/fastmks.hpp>
+#include <mlpack/core/kernels/linear_kernel.hpp>
+#include <mlpack/core/kernels/polynomial_kernel.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::tree;
+using namespace mlpack::fastmks;
+using namespace mlpack::kernel;
+
+BOOST_AUTO_TEST_SUITE(FastMKSTest);
+
+/**
+ * Compare single-tree and naive.
+ */
+BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
+{
+ // First create a random dataset.
+ arma::mat data;
+ data.randn(5, 1000);
+ LinearKernel lk;
+
+ // Now run FastMKS naively.
+ FastMKS<LinearKernel> naive(data, lk, false, true);
+
+ arma::Mat<size_t> naiveIndices;
+ arma::mat naiveProducts;
+ naive.Search(10, naiveIndices, naiveProducts);
+
+ // Now run it in single-tree mode.
+ FastMKS<LinearKernel> single(data, lk, true);
+
+ arma::Mat<size_t> singleIndices;
+ arma::mat singleProducts;
+ single.Search(10, singleIndices, singleProducts);
+
+ // Compare the results.
+ for (size_t q = 0; q < singleIndices.n_cols; ++q)
+ {
+ for (size_t r = 0; r < singleIndices.n_rows; ++r)
+ {
+ BOOST_REQUIRE_EQUAL(singleIndices(r, q), naiveIndices(r, q));
+ BOOST_REQUIRE_CLOSE(singleProducts(r, q), naiveProducts(r, q), 1e-5);
+ }
+ }
+}
+
+/**
+ * Compare dual-tree and naive.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive)
+{
+ // First create a random dataset.
+ arma::mat data;
+ data.randn(10, 5000);
+ LinearKernel lk;
+
+ // Now run FastMKS naively.
+ FastMKS<LinearKernel> naive(data, lk, false, true);
+
+ arma::Mat<size_t> naiveIndices;
+ arma::mat naiveProducts;
+ naive.Search(10, naiveIndices, naiveProducts);
+
+ // Now run it in dual-tree mode.
+ FastMKS<LinearKernel> tree(data, lk);
+
+ arma::Mat<size_t> treeIndices;
+ arma::mat treeProducts;
+ tree.Search(10, treeIndices, treeProducts);
+
+ for (size_t q = 0; q < treeIndices.n_cols; ++q)
+ {
+ for (size_t r = 0; r < treeIndices.n_rows; ++r)
+ {
+ BOOST_REQUIRE_EQUAL(treeIndices(r, q), naiveIndices(r, q));
+ BOOST_REQUIRE_CLOSE(treeProducts(r, q), naiveProducts(r, q), 1e-5);
+ }
+ }
+}
+
+/**
+ * Compare dual-tree and single-tree on a larger dataset.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsSingleTree)
+{
+ // First create a random dataset.
+ arma::mat data;
+ data.randu(20, 15000);
+ PolynomialKernel pk(5.0, 2.5);
+
+ FastMKS<PolynomialKernel> single(data, pk, true);
+
+ arma::Mat<size_t> singleIndices;
+ arma::mat singleProducts;
+ single.Search(10, singleIndices, singleProducts);
+
+ // Now run it in dual-tree mode.
+ FastMKS<PolynomialKernel> tree(data, pk);
+
+ arma::Mat<size_t> treeIndices;
+ arma::mat treeProducts;
+ tree.Search(10, treeIndices, treeProducts);
+
+ for (size_t q = 0; q < treeIndices.n_cols; ++q)
+ {
+ for (size_t r = 0; r < treeIndices.n_rows; ++r)
+ {
+ BOOST_REQUIRE_EQUAL(treeIndices(r, q), singleIndices(r, q));
+ BOOST_REQUIRE_CLOSE(treeProducts(r, q), singleProducts(r, q), 1e-5);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/gmm_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/gmm_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/gmm_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,637 +0,0 @@
-/**
- * @file gmm_test.cpp
- * @author Ryan Curtin
- *
- * Test for the Gaussian Mixture Model class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include <mlpack/methods/gmm/gmm.hpp>
-#include <mlpack/methods/gmm/phi.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::gmm;
-
-BOOST_AUTO_TEST_SUITE(GMMTest);
-/**
- * Test the phi() function, in the univariate Gaussian case.
- */
-BOOST_AUTO_TEST_CASE(UnivariatePhiTest)
-{
- // Simple case.
- BOOST_REQUIRE_CLOSE(phi(0.0, 0.0, 1.0), 0.398942280401433, 1e-5);
-
- // A few more cases...
- BOOST_REQUIRE_CLOSE(phi(0.0, 0.0, 2.0), 0.282094791773878, 1e-5);
-
- BOOST_REQUIRE_CLOSE(phi(1.0, 0.0, 1.0), 0.241970724519143, 1e-5);
- BOOST_REQUIRE_CLOSE(phi(-1.0, 0.0, 1.0), 0.241970724519143, 1e-5);
-
- BOOST_REQUIRE_CLOSE(phi(1.0, 0.0, 2.0), 0.219695644733861, 1e-5);
- BOOST_REQUIRE_CLOSE(phi(-1.0, 0.0, 2.0), 0.219695644733861, 1e-5);
-
- BOOST_REQUIRE_CLOSE(phi(1.0, 1.0, 1.0), 0.398942280401433, 1e-5);
-
- BOOST_REQUIRE_CLOSE(phi(-1.0, 1.0, 2.0), 0.103776874355149, 1e-5);
-}
-
-/**
- * Test the phi() function, in the multivariate Gaussian case.
- */
-BOOST_AUTO_TEST_CASE(MultivariatePhiTest)
-{
- // Simple case.
- arma::vec mean = "0 0";
- arma::mat cov = "1 0; 0 1";
- arma::vec x = "0 0";
-
- BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.159154943091895, 1e-5);
-
- cov = "2 0; 0 2";
-
- BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.0795774715459477, 1e-5);
-
- x = "1 1";
-
- BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.0482661763150270, 1e-5);
- BOOST_REQUIRE_CLOSE(phi(-x, mean, cov), 0.0482661763150270, 1e-5);
-
- mean = "1 1";
-
- BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.0795774715459477, 1e-5);
- BOOST_REQUIRE_CLOSE(phi(-x, -mean, cov), 0.0795774715459477, 1e-5);
-
- cov = "2 1.5; 1 4";
-
- BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.0624257046546403, 1e-5);
- BOOST_REQUIRE_CLOSE(phi(-x, -mean, cov), 0.0624257046546403, 1e-5);
-
- x = "-1 4";
-
- BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.00144014867515135, 1e-5);
- BOOST_REQUIRE_CLOSE(phi(-x, mean, cov), 0.00133352162064845, 1e-5);
-
- // Higher-dimensional case.
- x = "0 1 2 3 4";
- mean = "5 6 3 3 2";
- cov = "6 1 1 0 2;"
- "0 7 1 0 1;"
- "1 1 4 1 1;"
- "1 0 1 7 0;"
- "2 0 1 1 6";
-
- BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 1.02531207499358e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(phi(-x, -mean, cov), 1.02531207499358e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(phi(x, -mean, cov), 1.06784794079363e-8, 1e-5);
- BOOST_REQUIRE_CLOSE(phi(-x, mean, cov), 1.06784794079363e-8, 1e-5);
-}
-
-/**
- * Test the phi() function, for multiple points in the multivariate Gaussian
- * case.
- */
-BOOST_AUTO_TEST_CASE(MultipointMultivariatePhiTest)
-{
- // Same case as before.
- arma::vec mean = "5 6 3 3 2";
- arma::mat cov = "6 1 1 0 2; 0 7 1 0 1; 1 1 4 1 1; 1 0 1 7 0; 2 0 1 1 6";
-
- arma::mat points = "0 3 2 2 3 4;"
- "1 2 2 1 0 0;"
- "2 3 0 5 5 6;"
- "3 7 8 0 1 1;"
- "4 8 1 1 0 0;";
-
- arma::vec phis;
- phi(points, mean, cov, phis);
-
- BOOST_REQUIRE_EQUAL(phis.n_elem, 6);
-
- BOOST_REQUIRE_CLOSE(phis(0), 1.02531207499358e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(phis(1), 1.82353695848039e-7, 1e-5);
- BOOST_REQUIRE_CLOSE(phis(2), 1.29759261892949e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(phis(3), 1.33218060268258e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(phis(4), 1.12120427975708e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(phis(5), 4.57951032485297e-7, 1e-5);
-}
-
-/**
- * Test GMM::Probability() for a single observation for a few cases.
- */
-BOOST_AUTO_TEST_CASE(GMMProbabilityTest)
-{
- // Create a GMM.
- GMM<> gmm(2, 2);
- gmm.Means()[0] = "0 0";
- gmm.Means()[1] = "3 3";
- gmm.Covariances()[0] = "1 0; 0 1";
- gmm.Covariances()[1] = "2 1; 1 2";
- gmm.Weights() = "0.3 0.7";
-
- // Now test a couple observations. These comparisons are calculated by hand.
- BOOST_REQUIRE_CLOSE(gmm.Probability("0 0"), 0.05094887202, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("1 1"), 0.03451996667, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("2 2"), 0.04696302254, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("3 3"), 0.06432759685, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("-1 5.3"), 2.503171278804e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("1.4 0"), 0.024676682176, 1e-5);
-}
-
-/**
- * Test GMM::Probability() for a single observation being from a particular
- * component.
- */
-BOOST_AUTO_TEST_CASE(GMMProbabilityComponentTest)
-{
- // Create a GMM (same as the last test).
- GMM<> gmm(2, 2);
- gmm.Means()[0] = "0 0";
- gmm.Means()[1] = "3 3";
- gmm.Covariances()[0] = "1 0; 0 1";
- gmm.Covariances()[1] = "2 1; 1 2";
- gmm.Weights() = "0.3 0.7";
-
- // Now test a couple observations. These comparisons are calculated by hand.
- BOOST_REQUIRE_CLOSE(gmm.Probability("0 0", 0), 0.0477464829276, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("0 0", 1), 0.0032023890978, 1e-5);
-
- BOOST_REQUIRE_CLOSE(gmm.Probability("1 1", 0), 0.0175649494573, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("1 1", 1), 0.0169550172159, 1e-5);
-
- BOOST_REQUIRE_CLOSE(gmm.Probability("2 2", 0), 8.7450733951e-4, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("2 2", 1), 0.0460885151993, 1e-5);
-
- BOOST_REQUIRE_CLOSE(gmm.Probability("3 3", 0), 5.8923841039e-6, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("3 3", 1), 0.0643217044658, 1e-5);
-
- BOOST_REQUIRE_CLOSE(gmm.Probability("-1 5.3", 0), 2.30212100302e-8, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("-1 5.3", 1), 2.48015006877e-6, 1e-5);
-
- BOOST_REQUIRE_CLOSE(gmm.Probability("1.4 0", 0), 0.0179197849738, 1e-5);
- BOOST_REQUIRE_CLOSE(gmm.Probability("1.4 0", 1), 0.0067568972024, 1e-5);
-}
-
-/**
- * Test training a model on only one Gaussian (randomly generated) in two
- * dimensions. We will vary the dataset size from small to large. The EM
- * algorithm is used for training the GMM.
- */
-BOOST_AUTO_TEST_CASE(GMMTrainEMOneGaussian)
-{
- for (size_t iterations = 0; iterations < 4; iterations++)
- {
- // Determine random covariance and mean.
- arma::vec mean;
- mean.randu(2);
- arma::vec covar;
- covar.randu(2);
-
- arma::mat data;
- data.randn(2 /* dimension */, 150 * pow(10, (iterations / 3.0)));
-
- // Now apply mean and covariance.
- data.row(0) *= covar(0);
- data.row(1) *= covar(1);
-
- data.row(0) += mean(0);
- data.row(1) += mean(1);
-
- // Now, train the model.
- GMM<> gmm(1, 2);
- double likelihood = gmm.Estimate(data, 10);
-
- arma::vec actualMean = arma::mean(data, 1);
- arma::mat actualCovar = ccov(data, 1 /* biased estimator */);
-
- // Check the model to see that it is correct.
- BOOST_REQUIRE_CLOSE((gmm.Means()[0])[0], actualMean(0), 1e-5);
- BOOST_REQUIRE_CLOSE((gmm.Means()[0])[1], actualMean(1), 1e-5);
-
- BOOST_REQUIRE_CLOSE((gmm.Covariances()[0])(0, 0), actualCovar(0, 0), 1e-5);
- BOOST_REQUIRE_CLOSE((gmm.Covariances()[0])(0, 1), actualCovar(0, 1), 1e-5);
- BOOST_REQUIRE_CLOSE((gmm.Covariances()[0])(1, 0), actualCovar(1, 0), 1e-5);
- BOOST_REQUIRE_CLOSE((gmm.Covariances()[0])(1, 1), actualCovar(1, 1), 1e-5);
-
- BOOST_REQUIRE_CLOSE(gmm.Weights()[0], 1.0, 1e-5);
- }
-}
-
-/**
- * Test a training model on multiple Gaussians in higher dimensionality than
- * two. We will hold the dataset size constant at 10k points. The EM algorithm
- * is used for training the GMM.
- */
-BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussians)
-{
- // Higher dimensionality gives us a greater chance of having separated
- // Gaussians.
- size_t dims = 8;
- size_t gaussians = 3;
-
- // Generate dataset.
- arma::mat data;
- data.zeros(dims, 500);
-
- std::vector<arma::vec> means(gaussians);
- std::vector<arma::mat> covars(gaussians);
- arma::vec weights(gaussians);
- arma::Col<size_t> counts(gaussians);
-
- // Choose weights randomly.
- weights.zeros();
- while (weights.min() < 0.02)
- {
- weights.randu(gaussians);
- weights /= accu(weights);
- }
-
- for (size_t i = 0; i < gaussians; i++)
- counts[i] = round(weights[i] * (data.n_cols - gaussians));
- // Ensure one point minimum in each.
- counts += 1;
-
- // Account for rounding errors (possibly necessary).
- counts[gaussians - 1] += (data.n_cols - arma::accu(counts));
-
- // Build each Gaussian individually.
- size_t point = 0;
- for (size_t i = 0; i < gaussians; i++)
- {
- arma::mat gaussian;
- gaussian.randn(dims, counts[i]);
-
- // Randomly generate mean and covariance.
- means[i].randu(dims);
- means[i] -= 0.5;
- means[i] *= 50;
-
- // We need to make sure the covariance is positive definite. We will take a
- // random matrix C and then set our covariance to 4 * C * C', which will be
- // positive semidefinite.
- covars[i].randu(dims, dims);
- covars[i] *= 4 * trans(covars[i]);
-
- data.cols(point, point + counts[i] - 1) = (covars[i] * gaussian + means[i]
- * arma::ones<arma::rowvec>(counts[i]));
-
- // Calculate the actual means and covariances because they will probably
- // be different (this is easier to do before we shuffle the points).
- means[i] = arma::mean(data.cols(point, point + counts[i] - 1), 1);
- covars[i] = ccov(data.cols(point, point + counts[i] - 1), 1 /* biased */);
-
- point += counts[i];
- }
-
- // Calculate actual weights.
- for (size_t i = 0; i < gaussians; i++)
- weights[i] = (double) counts[i] / data.n_cols;
-
- // Now train the model.
- GMM<> gmm(gaussians, dims);
- double likelihood = gmm.Estimate(data, 10);
-
- arma::uvec sortRef = sort_index(weights);
- arma::uvec sortTry = sort_index(gmm.Weights());
-
- // Check the model to see that it is correct.
- for (size_t i = 0; i < gaussians; i++)
- {
- // Check the mean.
- for (size_t j = 0; j < dims; j++)
- BOOST_REQUIRE_CLOSE((gmm.Means()[sortTry[i]])[j],
- (means[sortRef[i]])[j], 1e-5);
-
- // Check the covariance.
- for (size_t row = 0; row < dims; row++)
- for (size_t col = 0; col < dims; col++)
- BOOST_REQUIRE_CLOSE((gmm.Covariances()[sortTry[i]])(row, col),
- (covars[sortRef[i]])(row, col), 1e-5);
-
- // Check the weight.
- BOOST_REQUIRE_CLOSE(gmm.Weights()[sortTry[i]], weights[sortRef[i]],
- 1e-5);
- }
-}
-
-/**
- * Train a single-gaussian mixture, but using the overload of Estimate() where
- * probabilities of the observation are given.
- */
-BOOST_AUTO_TEST_CASE(GMMTrainEMSingleGaussianWithProbability)
-{
- // Generate observations from a Gaussian distribution.
- distribution::GaussianDistribution d("0.5 1.0", "1.0 0.3; 0.3 1.0");
-
- // 10000 observations, each with random probability.
- arma::mat observations(2, 20000);
- for (size_t i = 0; i < 20000; i++)
- observations.col(i) = d.Random();
- arma::vec probabilities;
- probabilities.randu(20000); // Random probabilities.
-
- // Now train the model.
- GMM<> g(1, 2);
- double likelihood = g.Estimate(observations, probabilities, 10);
-
- // Check that it is trained correctly. 7% tolerance because of random error
- // present in observations.
- BOOST_REQUIRE_CLOSE(g.Means()[0][0], 0.5, 7.0);
- BOOST_REQUIRE_CLOSE(g.Means()[0][1], 1.0, 7.0);
-
- // 9% tolerance on the large numbers, 12% on the smaller numbers.
- BOOST_REQUIRE_CLOSE(g.Covariances()[0](0, 0), 1.0, 9.0);
- BOOST_REQUIRE_CLOSE(g.Covariances()[0](0, 1), 0.3, 12.0);
- BOOST_REQUIRE_CLOSE(g.Covariances()[0](1, 0), 0.3, 12.0);
- BOOST_REQUIRE_CLOSE(g.Covariances()[0](1, 1), 1.0, 9.0);
-
- BOOST_REQUIRE_CLOSE(g.Weights()[0], 1.0, 1e-5);
-}
-
-/**
- * Train a multi-Gaussian mixture, using the overload of Estimate() where
- * probabilities of the observation are given.
- */
-BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussiansWithProbability)
-{
- srand(time(NULL));
-
- // We'll have three Gaussian distributions from this mixture, and one Gaussian
- // not from this mixture (but we'll put some observations from it in).
- distribution::GaussianDistribution d1("0.0 1.0 0.0", "1.0 0.0 0.5;"
- "0.0 0.8 0.1;"
- "0.5 0.1 1.0");
- distribution::GaussianDistribution d2("2.0 -1.0 5.0", "3.0 0.0 0.5;"
- "0.0 1.2 0.2;"
- "0.5 0.2 1.3");
- distribution::GaussianDistribution d3("0.0 5.0 -3.0", "2.0 0.0 0.0;"
- "0.0 0.3 0.0;"
- "0.0 0.0 1.0");
- distribution::GaussianDistribution d4("4.0 2.0 2.0", "1.5 0.6 0.5;"
- "0.6 1.1 0.1;"
- "0.5 0.1 1.0");
-
- // Now we'll generate points and probabilities. 1500 points. Slower than I
- // would like...
- arma::mat points(3, 2000);
- arma::vec probabilities(2000);
-
- for (size_t i = 0; i < 2000; i++)
- {
- double randValue = math::Random();
-
- if (randValue <= 0.20) // p(d1) = 0.20
- points.col(i) = d1.Random();
- else if (randValue <= 0.50) // p(d2) = 0.30
- points.col(i) = d2.Random();
- else if (randValue <= 0.90) // p(d3) = 0.40
- points.col(i) = d3.Random();
- else // p(d4) = 0.10
- points.col(i) = d4.Random();
-
- // Set the probability right. If it came from this mixture, it should be
- // 0.97 plus or minus a little bit of noise. If not, then it should be 0.03
- // plus or minus a little bit of noise. The base probability (minus the
- // noise) is parameterizable for easy modification of the test.
- double confidence = 0.995;
- double perturbation = math::Random(-0.005, 0.005);
-
- if (randValue <= 0.90)
- probabilities(i) = confidence + perturbation;
- else
- probabilities(i) = (1 - confidence) + perturbation;
- }
-
- // Now train the model.
- GMM<> g(4, 3); // 3 dimensions, 4 components.
-
- double likelihood = g.Estimate(points, probabilities, 8);
-
- // Now check the results. We need to order by weights so that when we do the
- // checking, things will be correct.
- arma::uvec sortedIndices = sort_index(g.Weights());
-
- // The tolerances in our checks are quite large, but it is good to remember
- // that we introduced a fair amount of random noise into this whole process.
-
- // First Gaussian (d4).
- BOOST_REQUIRE_SMALL(g.Weights()[sortedIndices[0]] - 0.1, 0.075);
-
- for (size_t i = 0; i < 3; i++)
- BOOST_REQUIRE_SMALL((g.Means()[sortedIndices[0]][i] - d4.Mean()[i]), 0.30);
-
- for (size_t row = 0; row < 3; row++)
- for (size_t col = 0; col < 3; col++)
- BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[0]](row, col) -
- d4.Covariance()(row, col)), 0.60); // Big tolerance! Lots of noise.
-
- // Second Gaussian (d1).
- BOOST_REQUIRE_SMALL(g.Weights()[sortedIndices[1]] - 0.2, 0.075);
-
- for (size_t i = 0; i < 3; i++)
- BOOST_REQUIRE_SMALL((g.Means()[sortedIndices[1]][i] - d1.Mean()[i]), 0.25);
-
- for (size_t row = 0; row < 3; row++)
- for (size_t col = 0; col < 3; col++)
- BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[1]](row, col) -
- d1.Covariance()(row, col)), 0.55); // Big tolerance! Lots of noise.
-
- // Third Gaussian (d2).
- BOOST_REQUIRE_SMALL(g.Weights()[sortedIndices[2]] - 0.3, 0.1);
-
- for (size_t i = 0; i < 3; i++)
- BOOST_REQUIRE_SMALL((g.Means()[sortedIndices[2]][i] - d2.Mean()[i]), 0.25);
-
- for (size_t row = 0; row < 3; row++)
- for (size_t col = 0; col < 3; col++)
- BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[2]](row, col) -
- d2.Covariance()(row, col)), 0.50); // Big tolerance! Lots of noise.
-
- // Fourth gaussian (d3).
- BOOST_REQUIRE_SMALL(g.Weights()[sortedIndices[3]] - 0.4, 0.1);
-
- for (size_t i = 0; i < 3; ++i)
- BOOST_REQUIRE_SMALL((g.Means()[sortedIndices[3]][i] - d3.Mean()[i]), 0.25);
-
- for (size_t row = 0; row < 3; ++row)
- for (size_t col = 0; col < 3; ++col)
- BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[3]](row, col) -
- d3.Covariance()(row, col)), 0.50);
-}
-
-/**
- * Make sure generating observations randomly works. We'll do this by
- * generating a bunch of random observations and then re-training on them, and
- * hope that our model is the same.
- */
-BOOST_AUTO_TEST_CASE(GMMRandomTest)
-{
- // Simple GMM distribution.
- GMM<> gmm(2, 2);
- gmm.Weights() = arma::vec("0.40 0.60");
-
- // N([2.25 3.10], [1.00 0.20; 0.20 0.89])
- gmm.Means()[0] = arma::vec("2.25 3.10");
- gmm.Covariances()[0] = arma::mat("1.00 0.60; 0.60 0.89");
-
- // N([4.10 1.01], [1.00 0.00; 0.00 1.01])
- gmm.Means()[1] = arma::vec("4.10 1.01");
- gmm.Covariances()[1] = arma::mat("1.00 0.70; 0.70 1.01");
-
- // Now generate a bunch of observations.
- arma::mat observations(2, 4000);
- for (size_t i = 0; i < 4000; i++)
- observations.col(i) = gmm.Random();
-
- // A new one which we'll train.
- GMM<> gmm2(2, 2);
- double likelihood = gmm2.Estimate(observations, 10);
-
- // Now check the results. We need to order by weights so that when we do the
- // checking, things will be correct.
- arma::uvec sortedIndices = sort_index(gmm2.Weights());
-
- // Now check that the parameters are the same. Tolerances are kind of big
- // because we only used 2000 observations.
- BOOST_REQUIRE_CLOSE(gmm.Weights()[0], gmm2.Weights()[sortedIndices[0]], 7.0);
- BOOST_REQUIRE_CLOSE(gmm.Weights()[1], gmm2.Weights()[sortedIndices[1]], 7.0);
-
- BOOST_REQUIRE_CLOSE(gmm.Means()[0][0], gmm2.Means()[sortedIndices[0]][0],
- 6.5);
- BOOST_REQUIRE_CLOSE(gmm.Means()[0][1], gmm2.Means()[sortedIndices[0]][1],
- 6.5);
-
- BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](0, 0),
- gmm2.Covariances()[sortedIndices[0]](0, 0), 13.0);
- BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](0, 1),
- gmm2.Covariances()[sortedIndices[0]](0, 1), 22.0);
- BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](1, 0),
- gmm2.Covariances()[sortedIndices[0]](1, 0), 22.0);
- BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](1, 1),
- gmm2.Covariances()[sortedIndices[0]](1, 1), 13.0);
-
- BOOST_REQUIRE_CLOSE(gmm.Means()[1][0], gmm2.Means()[sortedIndices[1]][0],
- 6.5);
- BOOST_REQUIRE_CLOSE(gmm.Means()[1][1], gmm2.Means()[sortedIndices[1]][1],
- 6.5);
-
- BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](0, 0),
- gmm2.Covariances()[sortedIndices[1]](0, 0), 13.0);
- BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](0, 1),
- gmm2.Covariances()[sortedIndices[1]](0, 1), 22.0);
- BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](1, 0),
- gmm2.Covariances()[sortedIndices[1]](1, 0), 22.0);
- BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](1, 1),
- gmm2.Covariances()[sortedIndices[1]](1, 1), 13.0);
-}
-
-/**
- * Test classification of observations by component.
- */
-BOOST_AUTO_TEST_CASE(GMMClassifyTest)
-{
- // First create a Gaussian with a few components.
- GMM<> gmm(3, 2);
- gmm.Means()[0] = "0 0";
- gmm.Means()[1] = "1 3";
- gmm.Means()[2] = "-2 -2";
- gmm.Covariances()[0] = "1 0; 0 1";
- gmm.Covariances()[1] = "3 2; 2 3";
- gmm.Covariances()[2] = "2.2 1.4; 1.4 5.1";
- gmm.Weights() = "0.6 0.25 0.15";
-
- arma::mat observations = arma::trans(arma::mat(
- " 0 0;"
- " 0 1;"
- " 0 2;"
- " 1 -2;"
- " 2 -2;"
- "-2 0;"
- " 5 5;"
- "-2 -2;"
- " 3 3;"
- "25 25;"
- "-1 -1;"
- "-3 -3;"
- "-5 1"));
-
- arma::Col<size_t> classes;
-
- gmm.Classify(observations, classes);
-
- // Test classification of points. Classifications produced by hand.
- BOOST_REQUIRE_EQUAL(classes[ 0], 0);
- BOOST_REQUIRE_EQUAL(classes[ 1], 0);
- BOOST_REQUIRE_EQUAL(classes[ 2], 1);
- BOOST_REQUIRE_EQUAL(classes[ 3], 0);
- BOOST_REQUIRE_EQUAL(classes[ 4], 0);
- BOOST_REQUIRE_EQUAL(classes[ 5], 0);
- BOOST_REQUIRE_EQUAL(classes[ 6], 1);
- BOOST_REQUIRE_EQUAL(classes[ 7], 2);
- BOOST_REQUIRE_EQUAL(classes[ 8], 1);
- BOOST_REQUIRE_EQUAL(classes[ 9], 1);
- BOOST_REQUIRE_EQUAL(classes[10], 0);
- BOOST_REQUIRE_EQUAL(classes[11], 2);
- BOOST_REQUIRE_EQUAL(classes[12], 2);
-}
-
-BOOST_AUTO_TEST_CASE(GMMLoadSaveTest)
-{
- // Create a GMM, save it, and load it.
- GMM<> gmm(10, 4);
- gmm.Weights().randu();
-
- for (size_t i = 0; i < gmm.Gaussians(); ++i)
- {
- gmm.Means()[i].randu();
- gmm.Covariances()[i].randu();
- }
-
- gmm.Save("test-gmm-save.xml");
-
- GMM<> gmm2;
- gmm2.Load("test-gmm-save.xml");
-
- // Remove clutter.
- remove("test-gmm-save.xml");
-
- BOOST_REQUIRE_EQUAL(gmm.Gaussians(), gmm2.Gaussians());
- BOOST_REQUIRE_EQUAL(gmm.Dimensionality(), gmm2.Dimensionality());
-
- for (size_t i = 0; i < gmm.Dimensionality(); ++i)
- BOOST_REQUIRE_CLOSE(gmm.Weights()[i], gmm2.Weights()[i], 1e-3);
-
- for (size_t i = 0; i < gmm.Gaussians(); ++i)
- {
- for (size_t j = 0; j < gmm.Dimensionality(); ++j)
- BOOST_REQUIRE_CLOSE(gmm.Means()[i][j], gmm2.Means()[i][j], 1e-3);
-
- for (size_t j = 0; j < gmm.Dimensionality(); ++j)
- {
- for (size_t k = 0; k < gmm.Dimensionality(); ++k)
- {
- BOOST_REQUIRE_CLOSE(gmm.Covariances()[i](j, k),
- gmm2.Covariances()[i](j, k), 1e-3);
- }
- }
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/gmm_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/gmm_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/gmm_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/gmm_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,637 @@
+/**
+ * @file gmm_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test for the Gaussian Mixture Model class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include <mlpack/methods/gmm/gmm.hpp>
+#include <mlpack/methods/gmm/phi.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::gmm;
+
+BOOST_AUTO_TEST_SUITE(GMMTest);
+/**
+ * Test the phi() function, in the univariate Gaussian case.
+ */
+BOOST_AUTO_TEST_CASE(UnivariatePhiTest)
+{
+ // Simple case.
+ BOOST_REQUIRE_CLOSE(phi(0.0, 0.0, 1.0), 0.398942280401433, 1e-5);
+
+ // A few more cases...
+ BOOST_REQUIRE_CLOSE(phi(0.0, 0.0, 2.0), 0.282094791773878, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(phi(1.0, 0.0, 1.0), 0.241970724519143, 1e-5);
+ BOOST_REQUIRE_CLOSE(phi(-1.0, 0.0, 1.0), 0.241970724519143, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(phi(1.0, 0.0, 2.0), 0.219695644733861, 1e-5);
+ BOOST_REQUIRE_CLOSE(phi(-1.0, 0.0, 2.0), 0.219695644733861, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(phi(1.0, 1.0, 1.0), 0.398942280401433, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(phi(-1.0, 1.0, 2.0), 0.103776874355149, 1e-5);
+}
+
+/**
+ * Test the phi() function, in the multivariate Gaussian case.
+ */
+BOOST_AUTO_TEST_CASE(MultivariatePhiTest)
+{
+ // Simple case.
+ arma::vec mean = "0 0";
+ arma::mat cov = "1 0; 0 1";
+ arma::vec x = "0 0";
+
+ BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.159154943091895, 1e-5);
+
+ cov = "2 0; 0 2";
+
+ BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.0795774715459477, 1e-5);
+
+ x = "1 1";
+
+ BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.0482661763150270, 1e-5);
+ BOOST_REQUIRE_CLOSE(phi(-x, mean, cov), 0.0482661763150270, 1e-5);
+
+ mean = "1 1";
+
+ BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.0795774715459477, 1e-5);
+ BOOST_REQUIRE_CLOSE(phi(-x, -mean, cov), 0.0795774715459477, 1e-5);
+
+ cov = "2 1.5; 1 4";
+
+ BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.0624257046546403, 1e-5);
+ BOOST_REQUIRE_CLOSE(phi(-x, -mean, cov), 0.0624257046546403, 1e-5);
+
+ x = "-1 4";
+
+ BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 0.00144014867515135, 1e-5);
+ BOOST_REQUIRE_CLOSE(phi(-x, mean, cov), 0.00133352162064845, 1e-5);
+
+ // Higher-dimensional case.
+ x = "0 1 2 3 4";
+ mean = "5 6 3 3 2";
+ cov = "6 1 1 0 2;"
+ "0 7 1 0 1;"
+ "1 1 4 1 1;"
+ "1 0 1 7 0;"
+ "2 0 1 1 6";
+
+ BOOST_REQUIRE_CLOSE(phi(x, mean, cov), 1.02531207499358e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(phi(-x, -mean, cov), 1.02531207499358e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(phi(x, -mean, cov), 1.06784794079363e-8, 1e-5);
+ BOOST_REQUIRE_CLOSE(phi(-x, mean, cov), 1.06784794079363e-8, 1e-5);
+}
+
+/**
+ * Test the phi() function, for multiple points in the multivariate Gaussian
+ * case.
+ */
+BOOST_AUTO_TEST_CASE(MultipointMultivariatePhiTest)
+{
+ // Same case as before.
+ arma::vec mean = "5 6 3 3 2";
+ arma::mat cov = "6 1 1 0 2; 0 7 1 0 1; 1 1 4 1 1; 1 0 1 7 0; 2 0 1 1 6";
+
+ arma::mat points = "0 3 2 2 3 4;"
+ "1 2 2 1 0 0;"
+ "2 3 0 5 5 6;"
+ "3 7 8 0 1 1;"
+ "4 8 1 1 0 0;";
+
+ arma::vec phis;
+ phi(points, mean, cov, phis);
+
+ BOOST_REQUIRE_EQUAL(phis.n_elem, 6);
+
+ BOOST_REQUIRE_CLOSE(phis(0), 1.02531207499358e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(phis(1), 1.82353695848039e-7, 1e-5);
+ BOOST_REQUIRE_CLOSE(phis(2), 1.29759261892949e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(phis(3), 1.33218060268258e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(phis(4), 1.12120427975708e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(phis(5), 4.57951032485297e-7, 1e-5);
+}
+
+/**
+ * Test GMM::Probability() for a single observation for a few cases.
+ */
+BOOST_AUTO_TEST_CASE(GMMProbabilityTest)
+{
+ // Create a GMM.
+ GMM<> gmm(2, 2);
+ gmm.Means()[0] = "0 0";
+ gmm.Means()[1] = "3 3";
+ gmm.Covariances()[0] = "1 0; 0 1";
+ gmm.Covariances()[1] = "2 1; 1 2";
+ gmm.Weights() = "0.3 0.7";
+
+ // Now test a couple observations. These comparisons are calculated by hand.
+ BOOST_REQUIRE_CLOSE(gmm.Probability("0 0"), 0.05094887202, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("1 1"), 0.03451996667, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("2 2"), 0.04696302254, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("3 3"), 0.06432759685, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("-1 5.3"), 2.503171278804e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("1.4 0"), 0.024676682176, 1e-5);
+}
+
+/**
+ * Test GMM::Probability() for a single observation being from a particular
+ * component.
+ */
+BOOST_AUTO_TEST_CASE(GMMProbabilityComponentTest)
+{
+ // Create a GMM (same as the last test).
+ GMM<> gmm(2, 2);
+ gmm.Means()[0] = "0 0";
+ gmm.Means()[1] = "3 3";
+ gmm.Covariances()[0] = "1 0; 0 1";
+ gmm.Covariances()[1] = "2 1; 1 2";
+ gmm.Weights() = "0.3 0.7";
+
+ // Now test a couple observations. These comparisons are calculated by hand.
+ BOOST_REQUIRE_CLOSE(gmm.Probability("0 0", 0), 0.0477464829276, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("0 0", 1), 0.0032023890978, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(gmm.Probability("1 1", 0), 0.0175649494573, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("1 1", 1), 0.0169550172159, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(gmm.Probability("2 2", 0), 8.7450733951e-4, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("2 2", 1), 0.0460885151993, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(gmm.Probability("3 3", 0), 5.8923841039e-6, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("3 3", 1), 0.0643217044658, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(gmm.Probability("-1 5.3", 0), 2.30212100302e-8, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("-1 5.3", 1), 2.48015006877e-6, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(gmm.Probability("1.4 0", 0), 0.0179197849738, 1e-5);
+ BOOST_REQUIRE_CLOSE(gmm.Probability("1.4 0", 1), 0.0067568972024, 1e-5);
+}
+
+/**
+ * Test training a model on only one Gaussian (randomly generated) in two
+ * dimensions. We will vary the dataset size from small to large. The EM
+ * algorithm is used for training the GMM.
+ */
+BOOST_AUTO_TEST_CASE(GMMTrainEMOneGaussian)
+{
+ for (size_t iterations = 0; iterations < 4; iterations++)
+ {
+ // Determine random covariance and mean.
+ arma::vec mean;
+ mean.randu(2);
+ arma::vec covar;
+ covar.randu(2);
+
+ arma::mat data;
+ data.randn(2 /* dimension */, 150 * pow(10, (iterations / 3.0)));
+
+ // Now apply mean and covariance.
+ data.row(0) *= covar(0);
+ data.row(1) *= covar(1);
+
+ data.row(0) += mean(0);
+ data.row(1) += mean(1);
+
+ // Now, train the model.
+ GMM<> gmm(1, 2);
+ double likelihood = gmm.Estimate(data, 10);
+
+ arma::vec actualMean = arma::mean(data, 1);
+ arma::mat actualCovar = ccov(data, 1 /* biased estimator */);
+
+ // Check the model to see that it is correct.
+ BOOST_REQUIRE_CLOSE((gmm.Means()[0])[0], actualMean(0), 1e-5);
+ BOOST_REQUIRE_CLOSE((gmm.Means()[0])[1], actualMean(1), 1e-5);
+
+ BOOST_REQUIRE_CLOSE((gmm.Covariances()[0])(0, 0), actualCovar(0, 0), 1e-5);
+ BOOST_REQUIRE_CLOSE((gmm.Covariances()[0])(0, 1), actualCovar(0, 1), 1e-5);
+ BOOST_REQUIRE_CLOSE((gmm.Covariances()[0])(1, 0), actualCovar(1, 0), 1e-5);
+ BOOST_REQUIRE_CLOSE((gmm.Covariances()[0])(1, 1), actualCovar(1, 1), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(gmm.Weights()[0], 1.0, 1e-5);
+ }
+}
+
+/**
+ * Test a training model on multiple Gaussians in higher dimensionality than
+ * two. We will hold the dataset size constant at 10k points. The EM algorithm
+ * is used for training the GMM.
+ */
+BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussians)
+{
+ // Higher dimensionality gives us a greater chance of having separated
+ // Gaussians.
+ size_t dims = 8;
+ size_t gaussians = 3;
+
+ // Generate dataset.
+ arma::mat data;
+ data.zeros(dims, 500);
+
+ std::vector<arma::vec> means(gaussians);
+ std::vector<arma::mat> covars(gaussians);
+ arma::vec weights(gaussians);
+ arma::Col<size_t> counts(gaussians);
+
+ // Choose weights randomly.
+ weights.zeros();
+ while (weights.min() < 0.02)
+ {
+ weights.randu(gaussians);
+ weights /= accu(weights);
+ }
+
+ for (size_t i = 0; i < gaussians; i++)
+ counts[i] = round(weights[i] * (data.n_cols - gaussians));
+ // Ensure one point minimum in each.
+ counts += 1;
+
+ // Account for rounding errors (possibly necessary).
+ counts[gaussians - 1] += (data.n_cols - arma::accu(counts));
+
+ // Build each Gaussian individually.
+ size_t point = 0;
+ for (size_t i = 0; i < gaussians; i++)
+ {
+ arma::mat gaussian;
+ gaussian.randn(dims, counts[i]);
+
+ // Randomly generate mean and covariance.
+ means[i].randu(dims);
+ means[i] -= 0.5;
+ means[i] *= 50;
+
+ // We need to make sure the covariance is positive definite. We will take a
+ // random matrix C and then set our covariance to 4 * C * C', which will be
+ // positive semidefinite.
+ covars[i].randu(dims, dims);
+ covars[i] *= 4 * trans(covars[i]);
+
+ data.cols(point, point + counts[i] - 1) = (covars[i] * gaussian + means[i]
+ * arma::ones<arma::rowvec>(counts[i]));
+
+ // Calculate the actual means and covariances because they will probably
+ // be different (this is easier to do before we shuffle the points).
+ means[i] = arma::mean(data.cols(point, point + counts[i] - 1), 1);
+ covars[i] = ccov(data.cols(point, point + counts[i] - 1), 1 /* biased */);
+
+ point += counts[i];
+ }
+
+ // Calculate actual weights.
+ for (size_t i = 0; i < gaussians; i++)
+ weights[i] = (double) counts[i] / data.n_cols;
+
+ // Now train the model.
+ GMM<> gmm(gaussians, dims);
+ double likelihood = gmm.Estimate(data, 10);
+
+ arma::uvec sortRef = sort_index(weights);
+ arma::uvec sortTry = sort_index(gmm.Weights());
+
+ // Check the model to see that it is correct.
+ for (size_t i = 0; i < gaussians; i++)
+ {
+ // Check the mean.
+ for (size_t j = 0; j < dims; j++)
+ BOOST_REQUIRE_CLOSE((gmm.Means()[sortTry[i]])[j],
+ (means[sortRef[i]])[j], 1e-5);
+
+ // Check the covariance.
+ for (size_t row = 0; row < dims; row++)
+ for (size_t col = 0; col < dims; col++)
+ BOOST_REQUIRE_CLOSE((gmm.Covariances()[sortTry[i]])(row, col),
+ (covars[sortRef[i]])(row, col), 1e-5);
+
+ // Check the weight.
+ BOOST_REQUIRE_CLOSE(gmm.Weights()[sortTry[i]], weights[sortRef[i]],
+ 1e-5);
+ }
+}
+
+/**
+ * Train a single-gaussian mixture, but using the overload of Estimate() where
+ * probabilities of the observation are given.
+ */
+BOOST_AUTO_TEST_CASE(GMMTrainEMSingleGaussianWithProbability)
+{
+ // Generate observations from a Gaussian distribution.
+ distribution::GaussianDistribution d("0.5 1.0", "1.0 0.3; 0.3 1.0");
+
+ // 10000 observations, each with random probability.
+ arma::mat observations(2, 20000);
+ for (size_t i = 0; i < 20000; i++)
+ observations.col(i) = d.Random();
+ arma::vec probabilities;
+ probabilities.randu(20000); // Random probabilities.
+
+ // Now train the model.
+ GMM<> g(1, 2);
+ double likelihood = g.Estimate(observations, probabilities, 10);
+
+ // Check that it is trained correctly. 7% tolerance because of random error
+ // present in observations.
+ BOOST_REQUIRE_CLOSE(g.Means()[0][0], 0.5, 7.0);
+ BOOST_REQUIRE_CLOSE(g.Means()[0][1], 1.0, 7.0);
+
+ // 9% tolerance on the large numbers, 12% on the smaller numbers.
+ BOOST_REQUIRE_CLOSE(g.Covariances()[0](0, 0), 1.0, 9.0);
+ BOOST_REQUIRE_CLOSE(g.Covariances()[0](0, 1), 0.3, 12.0);
+ BOOST_REQUIRE_CLOSE(g.Covariances()[0](1, 0), 0.3, 12.0);
+ BOOST_REQUIRE_CLOSE(g.Covariances()[0](1, 1), 1.0, 9.0);
+
+ BOOST_REQUIRE_CLOSE(g.Weights()[0], 1.0, 1e-5);
+}
+
+/**
+ * Train a multi-Gaussian mixture, using the overload of Estimate() where
+ * probabilities of the observation are given.
+ */
+BOOST_AUTO_TEST_CASE(GMMTrainEMMultipleGaussiansWithProbability)
+{
+ srand(time(NULL));
+
+ // We'll have three Gaussian distributions from this mixture, and one Gaussian
+ // not from this mixture (but we'll put some observations from it in).
+ distribution::GaussianDistribution d1("0.0 1.0 0.0", "1.0 0.0 0.5;"
+ "0.0 0.8 0.1;"
+ "0.5 0.1 1.0");
+ distribution::GaussianDistribution d2("2.0 -1.0 5.0", "3.0 0.0 0.5;"
+ "0.0 1.2 0.2;"
+ "0.5 0.2 1.3");
+ distribution::GaussianDistribution d3("0.0 5.0 -3.0", "2.0 0.0 0.0;"
+ "0.0 0.3 0.0;"
+ "0.0 0.0 1.0");
+ distribution::GaussianDistribution d4("4.0 2.0 2.0", "1.5 0.6 0.5;"
+ "0.6 1.1 0.1;"
+ "0.5 0.1 1.0");
+
+ // Now we'll generate points and probabilities. 1500 points. Slower than I
+ // would like...
+ arma::mat points(3, 2000);
+ arma::vec probabilities(2000);
+
+ for (size_t i = 0; i < 2000; i++)
+ {
+ double randValue = math::Random();
+
+ if (randValue <= 0.20) // p(d1) = 0.20
+ points.col(i) = d1.Random();
+ else if (randValue <= 0.50) // p(d2) = 0.30
+ points.col(i) = d2.Random();
+ else if (randValue <= 0.90) // p(d3) = 0.40
+ points.col(i) = d3.Random();
+ else // p(d4) = 0.10
+ points.col(i) = d4.Random();
+
+ // Set the probability right. If it came from this mixture, it should be
+ // 0.97 plus or minus a little bit of noise. If not, then it should be 0.03
+ // plus or minus a little bit of noise. The base probability (minus the
+ // noise) is parameterizable for easy modification of the test.
+ double confidence = 0.995;
+ double perturbation = math::Random(-0.005, 0.005);
+
+ if (randValue <= 0.90)
+ probabilities(i) = confidence + perturbation;
+ else
+ probabilities(i) = (1 - confidence) + perturbation;
+ }
+
+ // Now train the model.
+ GMM<> g(4, 3); // 3 dimensions, 4 components.
+
+ double likelihood = g.Estimate(points, probabilities, 8);
+
+ // Now check the results. We need to order by weights so that when we do the
+ // checking, things will be correct.
+ arma::uvec sortedIndices = sort_index(g.Weights());
+
+ // The tolerances in our checks are quite large, but it is good to remember
+ // that we introduced a fair amount of random noise into this whole process.
+
+ // First Gaussian (d4).
+ BOOST_REQUIRE_SMALL(g.Weights()[sortedIndices[0]] - 0.1, 0.075);
+
+ for (size_t i = 0; i < 3; i++)
+ BOOST_REQUIRE_SMALL((g.Means()[sortedIndices[0]][i] - d4.Mean()[i]), 0.30);
+
+ for (size_t row = 0; row < 3; row++)
+ for (size_t col = 0; col < 3; col++)
+ BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[0]](row, col) -
+ d4.Covariance()(row, col)), 0.60); // Big tolerance! Lots of noise.
+
+ // Second Gaussian (d1).
+ BOOST_REQUIRE_SMALL(g.Weights()[sortedIndices[1]] - 0.2, 0.075);
+
+ for (size_t i = 0; i < 3; i++)
+ BOOST_REQUIRE_SMALL((g.Means()[sortedIndices[1]][i] - d1.Mean()[i]), 0.25);
+
+ for (size_t row = 0; row < 3; row++)
+ for (size_t col = 0; col < 3; col++)
+ BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[1]](row, col) -
+ d1.Covariance()(row, col)), 0.55); // Big tolerance! Lots of noise.
+
+ // Third Gaussian (d2).
+ BOOST_REQUIRE_SMALL(g.Weights()[sortedIndices[2]] - 0.3, 0.1);
+
+ for (size_t i = 0; i < 3; i++)
+ BOOST_REQUIRE_SMALL((g.Means()[sortedIndices[2]][i] - d2.Mean()[i]), 0.25);
+
+ for (size_t row = 0; row < 3; row++)
+ for (size_t col = 0; col < 3; col++)
+ BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[2]](row, col) -
+ d2.Covariance()(row, col)), 0.50); // Big tolerance! Lots of noise.
+
+ // Fourth gaussian (d3).
+ BOOST_REQUIRE_SMALL(g.Weights()[sortedIndices[3]] - 0.4, 0.1);
+
+ for (size_t i = 0; i < 3; ++i)
+ BOOST_REQUIRE_SMALL((g.Means()[sortedIndices[3]][i] - d3.Mean()[i]), 0.25);
+
+ for (size_t row = 0; row < 3; ++row)
+ for (size_t col = 0; col < 3; ++col)
+ BOOST_REQUIRE_SMALL((g.Covariances()[sortedIndices[3]](row, col) -
+ d3.Covariance()(row, col)), 0.50);
+}
+
+/**
+ * Make sure generating observations randomly works. We'll do this by
+ * generating a bunch of random observations and then re-training on them, and
+ * hope that our model is the same.
+ */
+BOOST_AUTO_TEST_CASE(GMMRandomTest)
+{
+ // Simple GMM distribution.
+ GMM<> gmm(2, 2);
+ gmm.Weights() = arma::vec("0.40 0.60");
+
+ // N([2.25 3.10], [1.00 0.20; 0.20 0.89])
+ gmm.Means()[0] = arma::vec("2.25 3.10");
+ gmm.Covariances()[0] = arma::mat("1.00 0.60; 0.60 0.89");
+
+ // N([4.10 1.01], [1.00 0.00; 0.00 1.01])
+ gmm.Means()[1] = arma::vec("4.10 1.01");
+ gmm.Covariances()[1] = arma::mat("1.00 0.70; 0.70 1.01");
+
+ // Now generate a bunch of observations.
+ arma::mat observations(2, 4000);
+ for (size_t i = 0; i < 4000; i++)
+ observations.col(i) = gmm.Random();
+
+ // A new one which we'll train.
+ GMM<> gmm2(2, 2);
+ double likelihood = gmm2.Estimate(observations, 10);
+
+ // Now check the results. We need to order by weights so that when we do the
+ // checking, things will be correct.
+ arma::uvec sortedIndices = sort_index(gmm2.Weights());
+
+ // Now check that the parameters are the same. Tolerances are kind of big
+ // because we only used 2000 observations.
+ BOOST_REQUIRE_CLOSE(gmm.Weights()[0], gmm2.Weights()[sortedIndices[0]], 7.0);
+ BOOST_REQUIRE_CLOSE(gmm.Weights()[1], gmm2.Weights()[sortedIndices[1]], 7.0);
+
+ BOOST_REQUIRE_CLOSE(gmm.Means()[0][0], gmm2.Means()[sortedIndices[0]][0],
+ 6.5);
+ BOOST_REQUIRE_CLOSE(gmm.Means()[0][1], gmm2.Means()[sortedIndices[0]][1],
+ 6.5);
+
+ BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](0, 0),
+ gmm2.Covariances()[sortedIndices[0]](0, 0), 13.0);
+ BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](0, 1),
+ gmm2.Covariances()[sortedIndices[0]](0, 1), 22.0);
+ BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](1, 0),
+ gmm2.Covariances()[sortedIndices[0]](1, 0), 22.0);
+ BOOST_REQUIRE_CLOSE(gmm.Covariances()[0](1, 1),
+ gmm2.Covariances()[sortedIndices[0]](1, 1), 13.0);
+
+ BOOST_REQUIRE_CLOSE(gmm.Means()[1][0], gmm2.Means()[sortedIndices[1]][0],
+ 6.5);
+ BOOST_REQUIRE_CLOSE(gmm.Means()[1][1], gmm2.Means()[sortedIndices[1]][1],
+ 6.5);
+
+ BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](0, 0),
+ gmm2.Covariances()[sortedIndices[1]](0, 0), 13.0);
+ BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](0, 1),
+ gmm2.Covariances()[sortedIndices[1]](0, 1), 22.0);
+ BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](1, 0),
+ gmm2.Covariances()[sortedIndices[1]](1, 0), 22.0);
+ BOOST_REQUIRE_CLOSE(gmm.Covariances()[1](1, 1),
+ gmm2.Covariances()[sortedIndices[1]](1, 1), 13.0);
+}
+
+/**
+ * Test classification of observations by component.
+ */
+BOOST_AUTO_TEST_CASE(GMMClassifyTest)
+{
+ // First create a Gaussian with a few components.
+ GMM<> gmm(3, 2);
+ gmm.Means()[0] = "0 0";
+ gmm.Means()[1] = "1 3";
+ gmm.Means()[2] = "-2 -2";
+ gmm.Covariances()[0] = "1 0; 0 1";
+ gmm.Covariances()[1] = "3 2; 2 3";
+ gmm.Covariances()[2] = "2.2 1.4; 1.4 5.1";
+ gmm.Weights() = "0.6 0.25 0.15";
+
+ arma::mat observations = arma::trans(arma::mat(
+ " 0 0;"
+ " 0 1;"
+ " 0 2;"
+ " 1 -2;"
+ " 2 -2;"
+ "-2 0;"
+ " 5 5;"
+ "-2 -2;"
+ " 3 3;"
+ "25 25;"
+ "-1 -1;"
+ "-3 -3;"
+ "-5 1"));
+
+ arma::Col<size_t> classes;
+
+ gmm.Classify(observations, classes);
+
+ // Test classification of points. Classifications produced by hand.
+ BOOST_REQUIRE_EQUAL(classes[ 0], 0);
+ BOOST_REQUIRE_EQUAL(classes[ 1], 0);
+ BOOST_REQUIRE_EQUAL(classes[ 2], 1);
+ BOOST_REQUIRE_EQUAL(classes[ 3], 0);
+ BOOST_REQUIRE_EQUAL(classes[ 4], 0);
+ BOOST_REQUIRE_EQUAL(classes[ 5], 0);
+ BOOST_REQUIRE_EQUAL(classes[ 6], 1);
+ BOOST_REQUIRE_EQUAL(classes[ 7], 2);
+ BOOST_REQUIRE_EQUAL(classes[ 8], 1);
+ BOOST_REQUIRE_EQUAL(classes[ 9], 1);
+ BOOST_REQUIRE_EQUAL(classes[10], 0);
+ BOOST_REQUIRE_EQUAL(classes[11], 2);
+ BOOST_REQUIRE_EQUAL(classes[12], 2);
+}
+
+BOOST_AUTO_TEST_CASE(GMMLoadSaveTest)
+{
+ // Create a GMM, save it, and load it.
+ GMM<> gmm(10, 4);
+ gmm.Weights().randu();
+
+ for (size_t i = 0; i < gmm.Gaussians(); ++i)
+ {
+ gmm.Means()[i].randu();
+ gmm.Covariances()[i].randu();
+ }
+
+ gmm.Save("test-gmm-save.xml");
+
+ GMM<> gmm2;
+ gmm2.Load("test-gmm-save.xml");
+
+ // Remove clutter.
+ remove("test-gmm-save.xml");
+
+ BOOST_REQUIRE_EQUAL(gmm.Gaussians(), gmm2.Gaussians());
+ BOOST_REQUIRE_EQUAL(gmm.Dimensionality(), gmm2.Dimensionality());
+
+ for (size_t i = 0; i < gmm.Dimensionality(); ++i)
+ BOOST_REQUIRE_CLOSE(gmm.Weights()[i], gmm2.Weights()[i], 1e-3);
+
+ for (size_t i = 0; i < gmm.Gaussians(); ++i)
+ {
+ for (size_t j = 0; j < gmm.Dimensionality(); ++j)
+ BOOST_REQUIRE_CLOSE(gmm.Means()[i][j], gmm2.Means()[i][j], 1e-3);
+
+ for (size_t j = 0; j < gmm.Dimensionality(); ++j)
+ {
+ for (size_t k = 0; k < gmm.Dimensionality(); ++k)
+ {
+ BOOST_REQUIRE_CLOSE(gmm.Covariances()[i](j, k),
+ gmm2.Covariances()[i](j, k), 1e-3);
+ }
+ }
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/hmm_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/hmm_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/hmm_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,955 +0,0 @@
-/**
- * @file hmm_test.cpp
- *
- * Test file for HMMs.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/hmm/hmm.hpp>
-#include <mlpack/methods/gmm/gmm.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::hmm;
-using namespace mlpack::distribution;
-using namespace mlpack::gmm;
-
-BOOST_AUTO_TEST_SUITE(HMMTest);
-
-/**
- * We will use the simple case proposed by Russell and Norvig in Artificial
- * Intelligence: A Modern Approach, 2nd Edition, around p.549.
- */
-BOOST_AUTO_TEST_CASE(SimpleDiscreteHMMTestViterbi)
-{
- // We have two hidden states: rain/dry. Two emission states: umbrella/no
- // umbrella.
- // In this example, the transition matrix is
- // rain dry
- // [[0.7 0.3] rain
- // [0.3 0.7]] dry
- // and the emission probability is
- // rain dry
- // [[0.9 0.2] umbrella
- // [0.1 0.8]] no umbrella
- arma::mat transition("0.7 0.3; 0.3 0.7");
- std::vector<DiscreteDistribution> emission(2);
- emission[0] = DiscreteDistribution("0.9 0.2");
- emission[1] = DiscreteDistribution("0.1 0.8");
-
- HMM<DiscreteDistribution> hmm(transition, emission);
-
- // Now let's take a sequence and find what the most likely state is.
- // We'll use the sequence [U U N U U] (U = umbrella, N = no umbrella) like on
- // p. 547.
- arma::mat observation = "0 0 1 0 0";
- arma::Col<size_t> states;
- hmm.Predict(observation, states);
-
- // Check each state.
- BOOST_REQUIRE_EQUAL(states[0], 0); // Rain.
- BOOST_REQUIRE_EQUAL(states[1], 0); // Rain.
- BOOST_REQUIRE_EQUAL(states[2], 1); // No rain.
- BOOST_REQUIRE_EQUAL(states[3], 0); // Rain.
- BOOST_REQUIRE_EQUAL(states[4], 0); // Rain.
-}
-
-/**
- * This example is from Borodovsky & Ekisheva, p. 80-81. It is just slightly
- * more complex.
- */
-BOOST_AUTO_TEST_CASE(BorodovskyHMMTestViterbi)
-{
- // Two hidden states: H (high GC content) and L (low GC content), as well as a
- // start state.
- arma::mat transition("0.0 0.0 0.0;"
- "0.5 0.5 0.4;"
- "0.5 0.5 0.6");
- // Four emission states: A, C, G, T. Start state doesn't emit...
- std::vector<DiscreteDistribution> emission(3);
- emission[0] = DiscreteDistribution("0.25 0.25 0.25 0.25");
- emission[1] = DiscreteDistribution("0.20 0.30 0.30 0.20");
- emission[2] = DiscreteDistribution("0.30 0.20 0.20 0.30");
-
- HMM<DiscreteDistribution> hmm(transition, emission);
-
- // GGCACTGAA.
- arma::mat observation("2 2 1 0 1 3 2 0 0");
- arma::Col<size_t> states;
- hmm.Predict(observation, states);
-
- // Most probable path is HHHLLLLLL.
- BOOST_REQUIRE_EQUAL(states[0], 1);
- BOOST_REQUIRE_EQUAL(states[1], 1);
- BOOST_REQUIRE_EQUAL(states[2], 1);
- BOOST_REQUIRE_EQUAL(states[3], 2);
- // This could actually be one of two states (equal probability).
- BOOST_REQUIRE((states[4] == 1) || (states[4] == 2));
- BOOST_REQUIRE_EQUAL(states[5], 2);
- // This could also be one of two states.
- BOOST_REQUIRE((states[6] == 1) || (states[6] == 2));
- BOOST_REQUIRE_EQUAL(states[7], 2);
- BOOST_REQUIRE_EQUAL(states[8], 2);
-}
-
-/**
- * Ensure that the forward-backward algorithm is correct.
- */
-BOOST_AUTO_TEST_CASE(ForwardBackwardTwoState)
-{
- arma::mat obs("3 3 2 1 1 1 1 3 3 1");
-
- arma::mat transition("0.1 0.9; 0.4 0.6");
- std::vector<DiscreteDistribution> emis(2);
- emis[0] = DiscreteDistribution("0.85 0.15 0.00 0.00");
- emis[1] = DiscreteDistribution("0.00 0.00 0.50 0.50");
-
- HMM<DiscreteDistribution> hmm(transition, emis);
-
- // Now check we are getting the same results as MATLAB for this sequence.
- arma::mat stateProb;
- arma::mat forwardProb;
- arma::mat backwardProb;
- arma::vec scales;
-
- double log = hmm.Estimate(obs, stateProb, forwardProb, backwardProb, scales);
-
- // All values obtained from MATLAB hmmdecode().
- BOOST_REQUIRE_CLOSE(log, -23.4349, 1e-3);
-
- BOOST_REQUIRE_SMALL(stateProb(0, 0), 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(1, 0), 1.0, 1e-5);
- BOOST_REQUIRE_SMALL(stateProb(0, 1), 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(1, 1), 1.0, 1e-5);
- BOOST_REQUIRE_SMALL(stateProb(0, 2), 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(1, 2), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(0, 3), 1.0, 1e-5);
- BOOST_REQUIRE_SMALL(stateProb(1, 3), 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(0, 4), 1.0, 1e-5);
- BOOST_REQUIRE_SMALL(stateProb(1, 4), 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(0, 5), 1.0, 1e-5);
- BOOST_REQUIRE_SMALL(stateProb(1, 5), 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(0, 6), 1.0, 1e-5);
- BOOST_REQUIRE_SMALL(stateProb(1, 6), 1e-5);
- BOOST_REQUIRE_SMALL(stateProb(0, 7), 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(1, 7), 1.0, 1e-5);
- BOOST_REQUIRE_SMALL(stateProb(0, 8), 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(1, 8), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(stateProb(0, 9), 1.0, 1e-5);
- BOOST_REQUIRE_SMALL(stateProb(1, 9), 1e-5);
-}
-
-/**
- * In this example we try to estimate the transmission and emission matrices
- * based on some observations. We use the simplest possible model.
- */
-BOOST_AUTO_TEST_CASE(SimplestBaumWelchDiscreteHMM)
-{
- // Don't yet require a useful distribution. 1 state, 1 emission.
- HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
-
- std::vector<arma::mat> observations;
- // Different lengths for each observation sequence.
- observations.push_back("0 0 0 0 0 0 0 0"); // 8 zeros.
- observations.push_back("0 0 0 0 0 0 0"); // 7 zeros.
- observations.push_back("0 0 0 0 0 0 0 0 0 0 0 0"); // 12 zeros.
- observations.push_back("0 0 0 0 0 0 0 0 0 0"); // 10 zeros.
-
- hmm.Train(observations);
-
- BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 1.0, 1e-5);
-}
-
-/**
- * A slightly more complex model to estimate.
- */
-BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM)
-{
- HMM<DiscreteDistribution> hmm(1, 2); // 1 state, 2 emissions.
- // Randomize the emission matrix.
- hmm.Emission()[0].Probabilities() = arma::randu<arma::vec>(2);
- hmm.Emission()[0].Probabilities() /= accu(hmm.Emission()[0].Probabilities());
-
- // P(each emission) = 0.5.
- // I've been careful to make P(first emission = 0) = P(first emission = 1).
- std::vector<arma::mat> observations;
- observations.push_back("0 1 0 1 0 1 0 1 0 1 0 1");
- observations.push_back("0 0 0 0 0 0 1 1 1 1 1 1");
- observations.push_back("1 1 1 1 1 1 0 0 0 0 0 0");
- observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
- observations.push_back("0 0 1 1 0 0 0 0 1 1 1 1");
- observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
- observations.push_back("0 1 0 1 0 1 0 1 0 1 0 1");
- observations.push_back("0 0 0 0 0 0 1 1 1 1 1 1");
- observations.push_back("1 1 1 1 1 1 0 0 0 0 0 0");
- observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
- observations.push_back("0 0 1 1 0 0 0 0 1 1 1 1");
- observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
-
- hmm.Train(observations);
-
- BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 0.5, 1e-5);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("1"), 0.5, 1e-5);
- BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 1.0, 1e-5);
-}
-
-/**
- * Increasing complexity, but still simple; 4 emissions, 2 states; the state can
- * be determined directly by the emission.
- */
-BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM_2)
-{
- HMM<DiscreteDistribution> hmm(2, DiscreteDistribution(4));
-
- // A little bit of obfuscation to the solution.
- hmm.Transition() = arma::mat("0.1 0.4; 0.9 0.6");
- hmm.Emission()[0].Probabilities() = "0.85 0.15 0.00 0.00";
- hmm.Emission()[1].Probabilities() = "0.00 0.00 0.50 0.50";
-
- // True emission matrix:
- // [[0.4 0 ]
- // [0.6 0 ]
- // [0 0.2]
- // [0 0.8]]
-
- // True transmission matrix:
- // [[0.5 0.5]
- // [0.5 0.5]]
-
- // Generate observations randomly by hand. This is kinda ugly, but it works.
- std::vector<arma::mat> observations;
- size_t obsNum = 250; // Number of observations.
- size_t obsLen = 500; // Number of elements in each observation.
- for (size_t i = 0; i < obsNum; i++)
- {
- arma::mat observation(1, obsLen);
-
- size_t state = 0;
- size_t emission = 0;
-
- for (size_t obs = 0; obs < obsLen; obs++)
- {
- // See if state changed.
- double r = math::Random();
-
- if (r <= 0.5)
- state = 0;
- else
- state = 1;
-
- // Now set the observation.
- r = math::Random();
-
- switch (state)
- {
-        // Choose the emission conditioned on the current state.
- case 0:
- if (r <= 0.4)
- emission = 0;
- else
- emission = 1;
- break;
- case 1:
- if (r <= 0.2)
- emission = 2;
- else
- emission = 3;
- break;
- }
-
- observation(0, obs) = emission;
- }
-
- observations.push_back(observation);
- }
-
- hmm.Train(observations);
-
- // Only require 2.5% tolerance, because this is a little fuzzier.
- BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 0.5, 2.5);
- BOOST_REQUIRE_CLOSE(hmm.Transition()(1, 0), 0.5, 2.5);
- BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 1), 0.5, 2.5);
- BOOST_REQUIRE_CLOSE(hmm.Transition()(1, 1), 0.5, 2.5);
-
- BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 0.4, 2.5);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("1"), 0.6, 2.5);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("2"), 2.5);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("3"), 2.5);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("0"), 2.5);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("1"), 2.5);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("2"), 0.2, 2.5);
- BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("3"), 0.8, 2.5);
-}
-
-BOOST_AUTO_TEST_CASE(DiscreteHMMLabeledTrainTest)
-{
- // Generate a random Markov model with 3 hidden states and 6 observations.
- arma::mat transition;
- std::vector<DiscreteDistribution> emission(3);
-
- transition.randu(3, 3);
- emission[0].Probabilities() = arma::randu<arma::vec>(6);
- emission[0].Probabilities() /= accu(emission[0].Probabilities());
- emission[1].Probabilities() = arma::randu<arma::vec>(6);
- emission[1].Probabilities() /= accu(emission[1].Probabilities());
- emission[2].Probabilities() = arma::randu<arma::vec>(6);
- emission[2].Probabilities() /= accu(emission[2].Probabilities());
-
- // Normalize so that we have a correct transition matrix.
- for (size_t col = 0; col < 3; col++)
- transition.col(col) /= accu(transition.col(col));
-
- // Now generate sequences.
- size_t obsNum = 250;
- size_t obsLen = 800;
-
- std::vector<arma::mat> observations(obsNum);
- std::vector<arma::Col<size_t> > states(obsNum);
-
- for (size_t n = 0; n < obsNum; n++)
- {
- observations[n].set_size(1, obsLen);
- states[n].set_size(obsLen);
-
- // Random starting state.
- states[n][0] = math::RandInt(3);
-
- // Random starting observation.
- observations[n].col(0) = emission[states[n][0]].Random();
-
- // Now the rest of the observations.
- for (size_t t = 1; t < obsLen; t++)
- {
- // Choose random number for state transition.
- double state = math::Random();
-
- // Decide next state.
- double sumProb = 0;
- for (size_t st = 0; st < 3; st++)
- {
- sumProb += transition(st, states[n][t - 1]);
- if (sumProb >= state)
- {
- states[n][t] = st;
- break;
- }
- }
-
- // Decide observation.
- observations[n].col(t) = emission[states[n][t]].Random();
- }
- }
-
- // Now that our data is generated, we give the HMM the labeled data to train
- // on.
- HMM<DiscreteDistribution> hmm(3, DiscreteDistribution(6));
-
- hmm.Train(observations, states);
-
- // We can't use % tolerance here because percent error increases as the actual
- // value gets very small. So, instead, we just ensure that every value is no
- // more than 0.009 away from the actual value.
- for (size_t row = 0; row < hmm.Transition().n_rows; row++)
- for (size_t col = 0; col < hmm.Transition().n_cols; col++)
- BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - transition(row, col),
- 0.009);
-
- for (size_t col = 0; col < hmm.Emission().size(); col++)
- {
- for (size_t row = 0; row < hmm.Emission()[col].Probabilities().n_elem;
- row++)
- {
- arma::vec obs(1);
- obs[0] = row;
- BOOST_REQUIRE_SMALL(hmm.Emission()[col].Probability(obs) -
- emission[col].Probability(obs), 0.009);
- }
- }
-}
-
-/**
- * Make sure the Generate() function works for a uniformly distributed HMM;
- * we'll take many samples just to make sure.
- */
-BOOST_AUTO_TEST_CASE(DiscreteHMMSimpleGenerateTest)
-{
- // Very simple HMM. 4 emissions with equal probability and 2 states with
- // equal probability. The default transition and emission matrices satisfy
- // this property.
- HMM<DiscreteDistribution> hmm(2, DiscreteDistribution(4));
-
- // Now generate a really, really long sequence.
- arma::mat dataSeq;
- arma::Col<size_t> stateSeq;
-
- hmm.Generate(100000, dataSeq, stateSeq);
-
- // Now find the empirical probabilities of each state.
- arma::vec emissionProb(4);
- arma::vec stateProb(2);
- emissionProb.zeros();
- stateProb.zeros();
- for (size_t i = 0; i < 100000; i++)
- {
- emissionProb[(size_t) dataSeq.col(i)[0] + 0.5]++;
- stateProb[stateSeq[i]]++;
- }
-
- // Normalize so these are probabilities.
- emissionProb /= accu(emissionProb);
- stateProb /= accu(stateProb);
-
- // Now check that the probabilities are right. 2% tolerance.
- BOOST_REQUIRE_CLOSE(emissionProb[0], 0.25, 2.0);
- BOOST_REQUIRE_CLOSE(emissionProb[1], 0.25, 2.0);
- BOOST_REQUIRE_CLOSE(emissionProb[2], 0.25, 2.0);
- BOOST_REQUIRE_CLOSE(emissionProb[3], 0.25, 2.0);
-
- BOOST_REQUIRE_CLOSE(stateProb[0], 0.50, 2.0);
- BOOST_REQUIRE_CLOSE(stateProb[1], 0.50, 2.0);
-}
-
-/**
- * More complex test for Generate().
- */
-BOOST_AUTO_TEST_CASE(DiscreteHMMGenerateTest)
-{
- // 6 emissions, 4 states. Random transition and emission probability.
- arma::mat transition(4, 4);
- std::vector<DiscreteDistribution> emission(4);
- emission[0].Probabilities() = arma::randu<arma::vec>(6);
- emission[0].Probabilities() /= accu(emission[0].Probabilities());
- emission[1].Probabilities() = arma::randu<arma::vec>(6);
- emission[1].Probabilities() /= accu(emission[1].Probabilities());
- emission[2].Probabilities() = arma::randu<arma::vec>(6);
- emission[2].Probabilities() /= accu(emission[2].Probabilities());
- emission[3].Probabilities() = arma::randu<arma::vec>(6);
- emission[3].Probabilities() /= accu(emission[3].Probabilities());
-
- transition.randu();
-
- // Normalize matrix.
- for (size_t col = 0; col < 4; col++)
- transition.col(col) /= accu(transition.col(col));
-
- // Create HMM object.
- HMM<DiscreteDistribution> hmm(transition, emission);
-
- // We'll create a bunch of sequences.
- int numSeq = 400;
- int numObs = 3000;
- std::vector<arma::mat> sequences(numSeq);
- std::vector<arma::Col<size_t> > states(numSeq);
- for (int i = 0; i < numSeq; i++)
- {
- // Random starting state.
- size_t startState = math::RandInt(4);
-
- hmm.Generate(numObs, sequences[i], states[i], startState);
- }
-
- // Now we will calculate the full probabilities.
- HMM<DiscreteDistribution> hmm2(4, 6);
- hmm2.Train(sequences, states);
-
- // Check that training gives the same result. Exact tolerance of 0.005.
- for (size_t row = 0; row < 4; row++)
- for (size_t col = 0; col < 4; col++)
- BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) -
- hmm2.Transition()(row, col), 0.005);
-
- for (size_t row = 0; row < 6; row++)
- {
- arma::vec obs(1);
- obs[0] = row;
- for (size_t col = 0; col < 4; col++)
- {
- BOOST_REQUIRE_SMALL(hmm.Emission()[col].Probability(obs) -
- hmm2.Emission()[col].Probability(obs), 0.005);
- }
- }
-}
-
-BOOST_AUTO_TEST_CASE(DiscreteHMMLogLikelihoodTest)
-{
- // Create a simple HMM with three states and four emissions.
- arma::mat transition("0.5 0.0 0.1;"
- "0.2 0.6 0.2;"
- "0.3 0.4 0.7");
- std::vector<DiscreteDistribution> emission(3);
- emission[0].Probabilities() = "0.75 0.25 0.00 0.00";
- emission[1].Probabilities() = "0.00 0.25 0.25 0.50";
- emission[2].Probabilities() = "0.10 0.40 0.40 0.10";
-
- HMM<DiscreteDistribution> hmm(transition, emission);
-
- // Now generate some sequences and check that the log-likelihood is the same
- // as MATLAB gives for this HMM.
- BOOST_REQUIRE_CLOSE(hmm.LogLikelihood("0 1 2 3"), -4.9887223949, 1e-5);
- BOOST_REQUIRE_CLOSE(hmm.LogLikelihood("1 2 0 0"), -6.0288487077, 1e-5);
- BOOST_REQUIRE_CLOSE(hmm.LogLikelihood("3 3 3 3"), -5.5544000018, 1e-5);
- BOOST_REQUIRE_CLOSE(hmm.LogLikelihood("0 2 2 1 2 3 0 0 1 3 1 0 0 3 1 2 2"),
- -24.51556128368, 1e-5);
-}
-
-/**
- * A simple test to make sure HMMs with Gaussian output distributions work.
- */
-BOOST_AUTO_TEST_CASE(GaussianHMMSimpleTest)
-{
- // We'll have two Gaussians, far away from each other, one corresponding to
- // each state.
- // E(0) ~ N([ 5.0 5.0], eye(2)).
- // E(1) ~ N([-5.0 -5.0], eye(2)).
- // The transition matrix is simple:
- // T = [[0.75 0.25]
- // [0.25 0.75]]
- GaussianDistribution g1("5.0 5.0", "1.0 0.0; 0.0 1.0");
- GaussianDistribution g2("-5.0 -5.0", "1.0 0.0; 0.0 1.0");
-
- arma::mat transition("0.75 0.25; 0.25 0.75");
-
- std::vector<GaussianDistribution> emission;
- emission.push_back(g1);
- emission.push_back(g2);
-
- HMM<GaussianDistribution> hmm(transition, emission);
-
- // Now, generate some sequences.
- arma::mat observations(2, 1000);
- arma::Col<size_t> classes(1000);
-
- // 1000-observations sequence.
- classes[0] = 0;
- observations.col(0) = g1.Random();
- for (size_t i = 1; i < 1000; i++)
- {
- double randValue = math::Random();
-
- if (randValue > 0.75) // Then we change state.
- classes[i] = (classes[i - 1] + 1) % 2;
- else
- classes[i] = classes[i - 1];
-
- if (classes[i] == 0)
- observations.col(i) = g1.Random();
- else
- observations.col(i) = g2.Random();
- }
-
- // Now predict the sequence.
- arma::Col<size_t> predictedClasses;
- arma::mat stateProb;
-
- hmm.Predict(observations, predictedClasses);
- hmm.Estimate(observations, stateProb);
-
- // Check that each prediction is right.
- for (size_t i = 0; i < 1000; i++)
- {
- BOOST_REQUIRE_EQUAL(predictedClasses[i], classes[i]);
-
- // The probability of the wrong class should be infinitesimal.
- BOOST_REQUIRE_SMALL(stateProb((classes[i] + 1) % 2, i), 0.001);
- }
-}
-
-/**
- * Ensure that Gaussian HMMs can be trained properly, for the labeled training
- * case and also for the unlabeled training case.
- */
-BOOST_AUTO_TEST_CASE(GaussianHMMTrainTest)
-{
- // Four emission Gaussians and three internal states. The goal is to estimate
- // the transition matrix correctly, and each distribution correctly.
- std::vector<GaussianDistribution> emission;
- emission.push_back(GaussianDistribution("0.0 0.0 0.0", "1.0 0.2 0.2;"
- "0.2 1.5 0.0;"
- "0.2 0.0 1.1"));
- emission.push_back(GaussianDistribution("2.0 1.0 5.0", "0.7 0.3 0.0;"
- "0.3 2.6 0.0;"
- "0.0 0.0 1.0"));
- emission.push_back(GaussianDistribution("5.0 0.0 0.5", "1.0 0.0 0.0;"
- "0.0 1.0 0.0;"
- "0.0 0.0 1.0"));
-
- arma::mat transition("0.3 0.5 0.7;"
- "0.3 0.4 0.1;"
- "0.4 0.1 0.2");
-
- // Now generate observations.
- std::vector<arma::mat> observations(100);
- std::vector<arma::Col<size_t> > states(100);
-
- for (size_t obs = 0; obs < 100; obs++)
- {
- observations[obs].set_size(3, 1000);
- states[obs].set_size(1000);
-
- // Always start in state zero.
- states[obs][0] = 0;
- observations[obs].col(0) = emission[0].Random();
-
- for (size_t t = 1; t < 1000; t++)
- {
- // Choose the state.
- double randValue = math::Random();
- double probSum = 0;
- for (size_t state = 0; state < 3; state++)
- {
- probSum += transition(state, states[obs][t - 1]);
- if (probSum >= randValue)
- {
- states[obs][t] = state;
- break;
- }
- }
-
- // Now choose the emission.
- observations[obs].col(t) = emission[states[obs][t]].Random();
- }
- }
-
- // Now that the data is generated, train the HMM.
- HMM<GaussianDistribution> hmm(3, GaussianDistribution(3));
-
- hmm.Train(observations, states);
-
- // We use an absolute tolerance of 0.01 for the transition matrices.
- // Check that the transition matrix is correct.
- for (size_t row = 0; row < 3; row++)
- for (size_t col = 0; col < 3; col++)
- BOOST_REQUIRE_SMALL(transition(row, col) - hmm.Transition()(row, col),
- 0.01);
-
- // Check that each distribution is correct.
- for (size_t dist = 0; dist < 3; dist++)
- {
- // Check that the mean is correct. Absolute tolerance of 0.04.
- for (size_t dim = 0; dim < 3; dim++)
- BOOST_REQUIRE_SMALL(hmm.Emission()[dist].Mean()(dim) -
- emission[dist].Mean()(dim), 0.04);
-
- // Check that the covariance is correct. Absolute tolerance of 0.075.
- for (size_t row = 0; row < 3; row++)
- for (size_t col = 0; col < 3; col++)
- BOOST_REQUIRE_SMALL(hmm.Emission()[dist].Covariance()(row, col) -
- emission[dist].Covariance()(row, col), 0.075);
- }
-
- // Now let's try it all again, but this time, unlabeled. Everything will fail
- // if we don't have a decent guess at the Gaussians, so we'll take a "poor"
- // guess at it ourselves. I won't use K-Means because we can't afford to add
- // the instability of that to our test. We'll leave the covariances as the
- // identity.
- HMM<GaussianDistribution> hmm2(3, GaussianDistribution(3));
- hmm2.Emission()[0].Mean() = "0.3 -0.2 0.1"; // Actual: [0 0 0].
- hmm2.Emission()[1].Mean() = "1.0 1.4 3.2"; // Actual: [2 1 5].
- hmm2.Emission()[2].Mean() = "3.1 -0.2 6.1"; // Actual: [5 0 0.5].
-
- // We'll only use 20 observation sequences to try and keep training time
- // shorter.
- observations.resize(20);
-
- hmm.Train(observations);
-
- // The tolerances are increased because there is more error in unlabeled
- // training; we use an absolute tolerance of 0.03 for the transition matrices.
- // Check that the transition matrix is correct.
- for (size_t row = 0; row < 3; row++)
- for (size_t col = 0; col < 3; col++)
- BOOST_REQUIRE_SMALL(transition(row, col) - hmm.Transition()(row, col),
- 0.03);
-
- // Check that each distribution is correct.
- for (size_t dist = 0; dist < 3; dist++)
- {
- // Check that the mean is correct. Absolute tolerance of 0.09.
- for (size_t dim = 0; dim < 3; dim++)
- BOOST_REQUIRE_SMALL(hmm.Emission()[dist].Mean()(dim) -
- emission[dist].Mean()(dim), 0.09);
-
- // Check that the covariance is correct. Absolute tolerance of 0.12.
- for (size_t row = 0; row < 3; row++)
- for (size_t col = 0; col < 3; col++)
- BOOST_REQUIRE_SMALL(hmm.Emission()[dist].Covariance()(row, col) -
- emission[dist].Covariance()(row, col), 0.12);
- }
-}
-
-/**
- * Make sure that a random sequence generated by a Gaussian HMM fits the
- * distribution correctly.
- */
-BOOST_AUTO_TEST_CASE(GaussianHMMGenerateTest)
-{
- // Our distribution will have three two-dimensional output Gaussians.
- HMM<GaussianDistribution> hmm(3, GaussianDistribution(2));
- hmm.Transition() = arma::mat("0.4 0.6 0.8; 0.2 0.2 0.1; 0.4 0.2 0.1");
- hmm.Emission()[0] = GaussianDistribution("0.0 0.0", "1.0 0.0; 0.0 1.0");
- hmm.Emission()[1] = GaussianDistribution("2.0 2.0", "1.0 0.5; 0.5 1.2");
- hmm.Emission()[2] = GaussianDistribution("-2.0 1.0", "2.0 0.1; 0.1 1.0");
-
- // Now we will generate a long sequence.
- std::vector<arma::mat> observations(1);
- std::vector<arma::Col<size_t> > states(1);
-
- // Start in state 1 (no reason).
- hmm.Generate(10000, observations[0], states[0], 1);
-
- HMM<GaussianDistribution> hmm2(3, GaussianDistribution(2));
-
- // Now estimate the HMM from the generated sequence.
- hmm2.Train(observations, states);
-
- // Check that the estimated matrices are the same.
- for (size_t row = 0; row < 3; row++)
- for (size_t col = 0; col < 3; col++)
- BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - hmm2.Transition()(row,
- col), 0.03);
-
- // Check that each Gaussian is the same.
- for (size_t em = 0; em < 3; em++)
- {
- // Check that the mean is the same.
- BOOST_REQUIRE_SMALL(hmm.Emission()[em].Mean()(0) -
- hmm2.Emission()[em].Mean()(0), 0.09);
- BOOST_REQUIRE_SMALL(hmm.Emission()[em].Mean()(1) -
- hmm2.Emission()[em].Mean()(1), 0.09);
-
- // Check that the covariances are the same.
- BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(0, 0) -
- hmm2.Emission()[em].Covariance()(0, 0), 0.2);
- BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(0, 1) -
- hmm2.Emission()[em].Covariance()(0, 1), 0.2);
- BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(1, 0) -
- hmm2.Emission()[em].Covariance()(1, 0), 0.2);
- BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(1, 1) -
- hmm2.Emission()[em].Covariance()(1, 1), 0.2);
- }
-}
-
-/**
- * Test that HMMs work with Gaussian mixture models. We'll try putting in a
- * simple model by hand and making sure that prediction of observation sequences
- * works correctly.
- */
-BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
-{
- // We will use two GMMs; one with two components and one with three.
- std::vector<GMM<> > gmms(2);
- gmms[0] = GMM<>(2, 2);
- gmms[0].Weights() = arma::vec("0.75 0.25");
-
- // N([4.25 3.10], [1.00 0.20; 0.20 0.89])
- gmms[0].Means()[0] = arma::vec("4.25 3.10");
- gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
-
- // N([7.10 5.01], [1.00 0.00; 0.00 1.01])
- gmms[0].Means()[1] = arma::vec("7.10 5.01");
- gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
-
- gmms[1] = GMM<>(3, 2);
- gmms[1].Weights() = arma::vec("0.4 0.2 0.4");
-
- gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
- gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
-
- gmms[1].Means()[1] = arma::vec("-4.25 -7.12");
- gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
-
- gmms[1].Means()[2] = arma::vec("-6.15 -2.00");
- gmms[1].Covariances()[2] = arma::mat("1.00 0.80; 0.80 1.00");
-
- // Transition matrix.
- arma::mat trans("0.30 0.50;"
- "0.70 0.50");
-
- // Now build the model.
- HMM<GMM<> > hmm(trans, gmms);
-
- // Make a sequence of observations.
- arma::mat observations(2, 1000);
- arma::Col<size_t> states(1000);
- states[0] = 0;
- observations.col(0) = gmms[0].Random();
-
- for (size_t i = 1; i < 1000; i++)
- {
- double randValue = math::Random();
-
- if (randValue <= trans(0, states[i - 1]))
- states[i] = 0;
- else
- states[i] = 1;
-
- observations.col(i) = gmms[states[i]].Random();
- }
-
- // Run the prediction.
- arma::Col<size_t> predictions;
- hmm.Predict(observations, predictions);
-
- // Check that the predictions were correct.
- for (size_t i = 0; i < 1000; i++)
- BOOST_REQUIRE_EQUAL(predictions[i], states[i]);
-}
-
-/**
- * Test that GMM-based HMMs can train on models correctly using labeled training
- * data.
- */
-BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
-{
- srand(time(NULL));
-
- // We will use two GMMs; one with two components and one with three.
- std::vector<GMM<> > gmms(2, GMM<>(2, 2));
- gmms[0].Weights() = arma::vec("0.3 0.7");
-
- // N([4.25 3.10], [1.00 0.20; 0.20 0.89])
- gmms[0].Means()[0] = arma::vec("4.25 3.10");
- gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
-
- // N([7.10 5.01], [1.00 0.00; 0.00 1.01])
- gmms[0].Means()[1] = arma::vec("7.10 5.01");
- gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
-
- gmms[1].Weights() = arma::vec("0.20 0.80");
-
- gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
- gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
-
- gmms[1].Means()[1] = arma::vec("-4.25 -2.12");
- gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
-
- // Transition matrix.
- arma::mat transMat("0.40 0.60;"
- "0.60 0.40");
-
- // Make a sequence of observations.
- std::vector<arma::mat> observations(5, arma::mat(2, 2500));
- std::vector<arma::Col<size_t> > states(5, arma::Col<size_t>(2500));
- for (size_t obs = 0; obs < 5; obs++)
- {
- states[obs][0] = 0;
- observations[obs].col(0) = gmms[0].Random();
-
- for (size_t i = 1; i < 2500; i++)
- {
- double randValue = (double) rand() / (double) RAND_MAX;
-
- if (randValue <= transMat(0, states[obs][i - 1]))
- states[obs][i] = 0;
- else
- states[obs][i] = 1;
-
- observations[obs].col(i) = gmms[states[obs][i]].Random();
- }
- }
-
- // Set up the GMM for training.
- HMM<GMM<> > hmm(2, GMM<>(2, 2));
-
- // Train the HMM.
- hmm.Train(observations, states);
-
- // Check the results. Use absolute tolerances instead of percentages.
- BOOST_REQUIRE_SMALL(hmm.Transition()(0, 0) - transMat(0, 0), 0.02);
- BOOST_REQUIRE_SMALL(hmm.Transition()(0, 1) - transMat(0, 1), 0.02);
- BOOST_REQUIRE_SMALL(hmm.Transition()(1, 0) - transMat(1, 0), 0.02);
- BOOST_REQUIRE_SMALL(hmm.Transition()(1, 1) - transMat(1, 1), 0.02);
-
- // Now the emission probabilities (the GMMs).
- // We have to sort each GMM for comparison.
- arma::uvec sortedIndices = sort_index(hmm.Emission()[0].Weights());
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Weights()[sortedIndices[0]] -
- gmms[0].Weights()[0], 0.08);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Weights()[sortedIndices[1]] -
- gmms[0].Weights()[1], 0.08);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][0] -
- gmms[0].Means()[0][0], 0.15);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][1] -
- gmms[0].Means()[0][1], 0.15);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][0] -
- gmms[0].Means()[1][0], 0.15);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][1] -
- gmms[0].Means()[1][1], 0.15);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 0) -
- gmms[0].Covariances()[0](0, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 1) -
- gmms[0].Covariances()[0](0, 1), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 0) -
- gmms[0].Covariances()[0](1, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 1) -
- gmms[0].Covariances()[0](1, 1), 0.3);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 0) -
- gmms[0].Covariances()[1](0, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 1) -
- gmms[0].Covariances()[1](0, 1), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 0) -
- gmms[0].Covariances()[1](1, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 1) -
- gmms[0].Covariances()[1](1, 1), 0.3);
-
- // Sort the GMM.
- sortedIndices = sort_index(hmm.Emission()[1].Weights());
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Weights()[sortedIndices[0]] -
- gmms[1].Weights()[0], 0.08);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Weights()[sortedIndices[1]] -
- gmms[1].Weights()[1], 0.08);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][0] -
- gmms[1].Means()[0][0], 0.15);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][1] -
- gmms[1].Means()[0][1], 0.15);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][0] -
- gmms[1].Means()[1][0], 0.15);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][1] -
- gmms[1].Means()[1][1], 0.15);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 0) -
- gmms[1].Covariances()[0](0, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 1) -
- gmms[1].Covariances()[0](0, 1), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 0) -
- gmms[1].Covariances()[0](1, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 1) -
- gmms[1].Covariances()[0](1, 1), 0.3);
-
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 0) -
- gmms[1].Covariances()[1](0, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 1) -
- gmms[1].Covariances()[1](0, 1), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 0) -
- gmms[1].Covariances()[1](1, 0), 0.3);
- BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 1) -
- gmms[1].Covariances()[1](1, 1), 0.3);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
-
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/hmm_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/hmm_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/hmm_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/hmm_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,955 @@
+/**
+ * @file hmm_test.cpp
+ *
+ * Test file for HMMs.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/hmm/hmm.hpp>
+#include <mlpack/methods/gmm/gmm.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::hmm;
+using namespace mlpack::distribution;
+using namespace mlpack::gmm;
+
+BOOST_AUTO_TEST_SUITE(HMMTest);
+
+/**
+ * We will use the simple case proposed by Russell and Norvig in Artificial
+ * Intelligence: A Modern Approach, 2nd Edition, around p.549.
+ */
+BOOST_AUTO_TEST_CASE(SimpleDiscreteHMMTestViterbi)
+{
+ // We have two hidden states: rain/dry. Two emission states: umbrella/no
+ // umbrella.
+ // In this example, the transition matrix is
+ // rain dry
+ // [[0.7 0.3] rain
+ // [0.3 0.7]] dry
+ // and the emission probability is
+ // rain dry
+ // [[0.9 0.2] umbrella
+ // [0.1 0.8]] no umbrella
+ arma::mat transition("0.7 0.3; 0.3 0.7");
+ std::vector<DiscreteDistribution> emission(2);
+ emission[0] = DiscreteDistribution("0.9 0.2");
+ emission[1] = DiscreteDistribution("0.1 0.8");
+
+ HMM<DiscreteDistribution> hmm(transition, emission);
+
+ // Now let's take a sequence and find what the most likely state is.
+ // We'll use the sequence [U U N U U] (U = umbrella, N = no umbrella) like on
+ // p. 547.
+ arma::mat observation = "0 0 1 0 0";
+ arma::Col<size_t> states;
+ hmm.Predict(observation, states);
+
+ // Check each state.
+ BOOST_REQUIRE_EQUAL(states[0], 0); // Rain.
+ BOOST_REQUIRE_EQUAL(states[1], 0); // Rain.
+ BOOST_REQUIRE_EQUAL(states[2], 1); // No rain.
+ BOOST_REQUIRE_EQUAL(states[3], 0); // Rain.
+ BOOST_REQUIRE_EQUAL(states[4], 0); // Rain.
+}
+
+/**
+ * This example is from Borodovsky & Ekisheva, p. 80-81. It is just slightly
+ * more complex.
+ */
+BOOST_AUTO_TEST_CASE(BorodovskyHMMTestViterbi)
+{
+ // Two hidden states: H (high GC content) and L (low GC content), as well as a
+ // start state.
+ arma::mat transition("0.0 0.0 0.0;"
+ "0.5 0.5 0.4;"
+ "0.5 0.5 0.6");
+ // Four emission states: A, C, G, T. Start state doesn't emit...
+ std::vector<DiscreteDistribution> emission(3);
+ emission[0] = DiscreteDistribution("0.25 0.25 0.25 0.25");
+ emission[1] = DiscreteDistribution("0.20 0.30 0.30 0.20");
+ emission[2] = DiscreteDistribution("0.30 0.20 0.20 0.30");
+
+ HMM<DiscreteDistribution> hmm(transition, emission);
+
+ // GGCACTGAA.
+ arma::mat observation("2 2 1 0 1 3 2 0 0");
+ arma::Col<size_t> states;
+ hmm.Predict(observation, states);
+
+ // Most probable path is HHHLLLLLL.
+ BOOST_REQUIRE_EQUAL(states[0], 1);
+ BOOST_REQUIRE_EQUAL(states[1], 1);
+ BOOST_REQUIRE_EQUAL(states[2], 1);
+ BOOST_REQUIRE_EQUAL(states[3], 2);
+ // This could actually be one of two states (equal probability).
+ BOOST_REQUIRE((states[4] == 1) || (states[4] == 2));
+ BOOST_REQUIRE_EQUAL(states[5], 2);
+ // This could also be one of two states.
+ BOOST_REQUIRE((states[6] == 1) || (states[6] == 2));
+ BOOST_REQUIRE_EQUAL(states[7], 2);
+ BOOST_REQUIRE_EQUAL(states[8], 2);
+}
+
+/**
+ * Ensure that the forward-backward algorithm is correct.
+ */
+BOOST_AUTO_TEST_CASE(ForwardBackwardTwoState)
+{
+ arma::mat obs("3 3 2 1 1 1 1 3 3 1");
+
+ arma::mat transition("0.1 0.9; 0.4 0.6");
+ std::vector<DiscreteDistribution> emis(2);
+ emis[0] = DiscreteDistribution("0.85 0.15 0.00 0.00");
+ emis[1] = DiscreteDistribution("0.00 0.00 0.50 0.50");
+
+ HMM<DiscreteDistribution> hmm(transition, emis);
+
+ // Now check we are getting the same results as MATLAB for this sequence.
+ arma::mat stateProb;
+ arma::mat forwardProb;
+ arma::mat backwardProb;
+ arma::vec scales;
+
+ double log = hmm.Estimate(obs, stateProb, forwardProb, backwardProb, scales);
+
+ // All values obtained from MATLAB hmmdecode().
+ BOOST_REQUIRE_CLOSE(log, -23.4349, 1e-3);
+
+ BOOST_REQUIRE_SMALL(stateProb(0, 0), 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(1, 0), 1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(stateProb(0, 1), 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(1, 1), 1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(stateProb(0, 2), 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(1, 2), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(0, 3), 1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(stateProb(1, 3), 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(0, 4), 1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(stateProb(1, 4), 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(0, 5), 1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(stateProb(1, 5), 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(0, 6), 1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(stateProb(1, 6), 1e-5);
+ BOOST_REQUIRE_SMALL(stateProb(0, 7), 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(1, 7), 1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(stateProb(0, 8), 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(1, 8), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(stateProb(0, 9), 1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(stateProb(1, 9), 1e-5);
+}
+
+/**
+ * In this example we try to estimate the transmission and emission matrices
+ * based on some observations. We use the simplest possible model.
+ */
+BOOST_AUTO_TEST_CASE(SimplestBaumWelchDiscreteHMM)
+{
+ // Don't yet require a useful distribution. 1 state, 1 emission.
+ HMM<DiscreteDistribution> hmm(1, DiscreteDistribution(1));
+
+ std::vector<arma::mat> observations;
+ // Different lengths for each observation sequence.
+ observations.push_back("0 0 0 0 0 0 0 0"); // 8 zeros.
+ observations.push_back("0 0 0 0 0 0 0"); // 7 zeros.
+ observations.push_back("0 0 0 0 0 0 0 0 0 0 0 0"); // 12 zeros.
+ observations.push_back("0 0 0 0 0 0 0 0 0 0"); // 10 zeros.
+
+ hmm.Train(observations);
+
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 1.0, 1e-5);
+}
+
+/**
+ * A slightly more complex model to estimate.
+ */
+BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM)
+{
+ HMM<DiscreteDistribution> hmm(1, 2); // 1 state, 2 emissions.
+ // Randomize the emission matrix.
+ hmm.Emission()[0].Probabilities() = arma::randu<arma::vec>(2);
+ hmm.Emission()[0].Probabilities() /= accu(hmm.Emission()[0].Probabilities());
+
+ // P(each emission) = 0.5.
+ // I've been careful to make P(first emission = 0) = P(first emission = 1).
+ std::vector<arma::mat> observations;
+ observations.push_back("0 1 0 1 0 1 0 1 0 1 0 1");
+ observations.push_back("0 0 0 0 0 0 1 1 1 1 1 1");
+ observations.push_back("1 1 1 1 1 1 0 0 0 0 0 0");
+ observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
+ observations.push_back("0 0 1 1 0 0 0 0 1 1 1 1");
+ observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
+ observations.push_back("0 1 0 1 0 1 0 1 0 1 0 1");
+ observations.push_back("0 0 0 0 0 0 1 1 1 1 1 1");
+ observations.push_back("1 1 1 1 1 1 0 0 0 0 0 0");
+ observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
+ observations.push_back("0 0 1 1 0 0 0 0 1 1 1 1");
+ observations.push_back("1 1 1 0 0 0 1 1 1 0 0 0");
+
+ hmm.Train(observations);
+
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 0.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("1"), 0.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 1.0, 1e-5);
+}
+
+/**
+ * Increasing complexity, but still simple; 4 emissions, 2 states; the state can
+ * be determined directly by the emission.
+ */
+BOOST_AUTO_TEST_CASE(SimpleBaumWelchDiscreteHMM_2)
+{
+ HMM<DiscreteDistribution> hmm(2, DiscreteDistribution(4));
+
+ // A little bit of obfuscation to the solution.
+ hmm.Transition() = arma::mat("0.1 0.4; 0.9 0.6");
+ hmm.Emission()[0].Probabilities() = "0.85 0.15 0.00 0.00";
+ hmm.Emission()[1].Probabilities() = "0.00 0.00 0.50 0.50";
+
+ // True emission matrix:
+ // [[0.4 0 ]
+ // [0.6 0 ]
+ // [0 0.2]
+ // [0 0.8]]
+
+ // True transmission matrix:
+ // [[0.5 0.5]
+ // [0.5 0.5]]
+
+ // Generate observations randomly by hand. This is kinda ugly, but it works.
+ std::vector<arma::mat> observations;
+ size_t obsNum = 250; // Number of observations.
+ size_t obsLen = 500; // Number of elements in each observation.
+ for (size_t i = 0; i < obsNum; i++)
+ {
+ arma::mat observation(1, obsLen);
+
+ size_t state = 0;
+ size_t emission = 0;
+
+ for (size_t obs = 0; obs < obsLen; obs++)
+ {
+ // See if state changed.
+ double r = math::Random();
+
+ if (r <= 0.5)
+ state = 0;
+ else
+ state = 1;
+
+ // Now set the observation.
+ r = math::Random();
+
+ switch (state)
+ {
+ // case 0 is not possible.
+ case 0:
+ if (r <= 0.4)
+ emission = 0;
+ else
+ emission = 1;
+ break;
+ case 1:
+ if (r <= 0.2)
+ emission = 2;
+ else
+ emission = 3;
+ break;
+ }
+
+ observation(0, obs) = emission;
+ }
+
+ observations.push_back(observation);
+ }
+
+ hmm.Train(observations);
+
+ // Only require 2.5% tolerance, because this is a little fuzzier.
+ BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 0), 0.5, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Transition()(1, 0), 0.5, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Transition()(0, 1), 0.5, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Transition()(1, 1), 0.5, 2.5);
+
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("0"), 0.4, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[0].Probability("1"), 0.6, 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("2"), 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Probability("3"), 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("0"), 2.5);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Probability("1"), 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("2"), 0.2, 2.5);
+ BOOST_REQUIRE_CLOSE(hmm.Emission()[1].Probability("3"), 0.8, 2.5);
+}
+
+BOOST_AUTO_TEST_CASE(DiscreteHMMLabeledTrainTest)
+{
+ // Generate a random Markov model with 3 hidden states and 6 observations.
+ arma::mat transition;
+ std::vector<DiscreteDistribution> emission(3);
+
+ transition.randu(3, 3);
+ emission[0].Probabilities() = arma::randu<arma::vec>(6);
+ emission[0].Probabilities() /= accu(emission[0].Probabilities());
+ emission[1].Probabilities() = arma::randu<arma::vec>(6);
+ emission[1].Probabilities() /= accu(emission[1].Probabilities());
+ emission[2].Probabilities() = arma::randu<arma::vec>(6);
+ emission[2].Probabilities() /= accu(emission[2].Probabilities());
+
+ // Normalize so that we have a correct transition matrix.
+ for (size_t col = 0; col < 3; col++)
+ transition.col(col) /= accu(transition.col(col));
+
+ // Now generate sequences.
+ size_t obsNum = 250;
+ size_t obsLen = 800;
+
+ std::vector<arma::mat> observations(obsNum);
+ std::vector<arma::Col<size_t> > states(obsNum);
+
+ for (size_t n = 0; n < obsNum; n++)
+ {
+ observations[n].set_size(1, obsLen);
+ states[n].set_size(obsLen);
+
+ // Random starting state.
+ states[n][0] = math::RandInt(3);
+
+ // Random starting observation.
+ observations[n].col(0) = emission[states[n][0]].Random();
+
+ // Now the rest of the observations.
+ for (size_t t = 1; t < obsLen; t++)
+ {
+ // Choose random number for state transition.
+ double state = math::Random();
+
+ // Decide next state.
+ double sumProb = 0;
+ for (size_t st = 0; st < 3; st++)
+ {
+ sumProb += transition(st, states[n][t - 1]);
+ if (sumProb >= state)
+ {
+ states[n][t] = st;
+ break;
+ }
+ }
+
+ // Decide observation.
+ observations[n].col(t) = emission[states[n][t]].Random();
+ }
+ }
+
+ // Now that our data is generated, we give the HMM the labeled data to train
+ // on.
+ HMM<DiscreteDistribution> hmm(3, DiscreteDistribution(6));
+
+ hmm.Train(observations, states);
+
+ // We can't use % tolerance here because percent error increases as the actual
+ // value gets very small. So, instead, we just ensure that every value is no
+ // more than 0.009 away from the actual value.
+ for (size_t row = 0; row < hmm.Transition().n_rows; row++)
+ for (size_t col = 0; col < hmm.Transition().n_cols; col++)
+ BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - transition(row, col),
+ 0.009);
+
+ for (size_t col = 0; col < hmm.Emission().size(); col++)
+ {
+ for (size_t row = 0; row < hmm.Emission()[col].Probabilities().n_elem;
+ row++)
+ {
+ arma::vec obs(1);
+ obs[0] = row;
+ BOOST_REQUIRE_SMALL(hmm.Emission()[col].Probability(obs) -
+ emission[col].Probability(obs), 0.009);
+ }
+ }
+}
+
+/**
+ * Make sure the Generate() function works for a uniformly distributed HMM;
+ * we'll take many samples just to make sure.
+ */
+BOOST_AUTO_TEST_CASE(DiscreteHMMSimpleGenerateTest)
+{
+ // Very simple HMM. 4 emissions with equal probability and 2 states with
+ // equal probability. The default transition and emission matrices satisfy
+ // this property.
+ HMM<DiscreteDistribution> hmm(2, DiscreteDistribution(4));
+
+ // Now generate a really, really long sequence.
+ arma::mat dataSeq;
+ arma::Col<size_t> stateSeq;
+
+ hmm.Generate(100000, dataSeq, stateSeq);
+
+ // Now find the empirical probabilities of each state.
+ arma::vec emissionProb(4);
+ arma::vec stateProb(2);
+ emissionProb.zeros();
+ stateProb.zeros();
+ for (size_t i = 0; i < 100000; i++)
+ {
+ emissionProb[(size_t) dataSeq.col(i)[0] + 0.5]++;
+ stateProb[stateSeq[i]]++;
+ }
+
+ // Normalize so these are probabilities.
+ emissionProb /= accu(emissionProb);
+ stateProb /= accu(stateProb);
+
+ // Now check that the probabilities are right. 2% tolerance.
+ BOOST_REQUIRE_CLOSE(emissionProb[0], 0.25, 2.0);
+ BOOST_REQUIRE_CLOSE(emissionProb[1], 0.25, 2.0);
+ BOOST_REQUIRE_CLOSE(emissionProb[2], 0.25, 2.0);
+ BOOST_REQUIRE_CLOSE(emissionProb[3], 0.25, 2.0);
+
+ BOOST_REQUIRE_CLOSE(stateProb[0], 0.50, 2.0);
+ BOOST_REQUIRE_CLOSE(stateProb[1], 0.50, 2.0);
+}
+
+/**
+ * More complex test for Generate().
+ */
+BOOST_AUTO_TEST_CASE(DiscreteHMMGenerateTest)
+{
+ // 6 emissions, 4 states. Random transition and emission probability.
+ arma::mat transition(4, 4);
+ std::vector<DiscreteDistribution> emission(4);
+ emission[0].Probabilities() = arma::randu<arma::vec>(6);
+ emission[0].Probabilities() /= accu(emission[0].Probabilities());
+ emission[1].Probabilities() = arma::randu<arma::vec>(6);
+ emission[1].Probabilities() /= accu(emission[1].Probabilities());
+ emission[2].Probabilities() = arma::randu<arma::vec>(6);
+ emission[2].Probabilities() /= accu(emission[2].Probabilities());
+ emission[3].Probabilities() = arma::randu<arma::vec>(6);
+ emission[3].Probabilities() /= accu(emission[3].Probabilities());
+
+ transition.randu();
+
+ // Normalize matrix.
+ for (size_t col = 0; col < 4; col++)
+ transition.col(col) /= accu(transition.col(col));
+
+ // Create HMM object.
+ HMM<DiscreteDistribution> hmm(transition, emission);
+
+ // We'll create a bunch of sequences.
+ int numSeq = 400;
+ int numObs = 3000;
+ std::vector<arma::mat> sequences(numSeq);
+ std::vector<arma::Col<size_t> > states(numSeq);
+ for (int i = 0; i < numSeq; i++)
+ {
+ // Random starting state.
+ size_t startState = math::RandInt(4);
+
+ hmm.Generate(numObs, sequences[i], states[i], startState);
+ }
+
+ // Now we will calculate the full probabilities.
+ HMM<DiscreteDistribution> hmm2(4, 6);
+ hmm2.Train(sequences, states);
+
+ // Check that training gives the same result. Exact tolerance of 0.005.
+ for (size_t row = 0; row < 4; row++)
+ for (size_t col = 0; col < 4; col++)
+ BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) -
+ hmm2.Transition()(row, col), 0.005);
+
+ for (size_t row = 0; row < 6; row++)
+ {
+ arma::vec obs(1);
+ obs[0] = row;
+ for (size_t col = 0; col < 4; col++)
+ {
+ BOOST_REQUIRE_SMALL(hmm.Emission()[col].Probability(obs) -
+ hmm2.Emission()[col].Probability(obs), 0.005);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_CASE(DiscreteHMMLogLikelihoodTest)
+{
+ // Create a simple HMM with three states and four emissions.
+ arma::mat transition("0.5 0.0 0.1;"
+ "0.2 0.6 0.2;"
+ "0.3 0.4 0.7");
+ std::vector<DiscreteDistribution> emission(3);
+ emission[0].Probabilities() = "0.75 0.25 0.00 0.00";
+ emission[1].Probabilities() = "0.00 0.25 0.25 0.50";
+ emission[2].Probabilities() = "0.10 0.40 0.40 0.10";
+
+ HMM<DiscreteDistribution> hmm(transition, emission);
+
+ // Now generate some sequences and check that the log-likelihood is the same
+ // as MATLAB gives for this HMM.
+ BOOST_REQUIRE_CLOSE(hmm.LogLikelihood("0 1 2 3"), -4.9887223949, 1e-5);
+ BOOST_REQUIRE_CLOSE(hmm.LogLikelihood("1 2 0 0"), -6.0288487077, 1e-5);
+ BOOST_REQUIRE_CLOSE(hmm.LogLikelihood("3 3 3 3"), -5.5544000018, 1e-5);
+ BOOST_REQUIRE_CLOSE(hmm.LogLikelihood("0 2 2 1 2 3 0 0 1 3 1 0 0 3 1 2 2"),
+ -24.51556128368, 1e-5);
+}
+
+/**
+ * A simple test to make sure HMMs with Gaussian output distributions work.
+ */
+BOOST_AUTO_TEST_CASE(GaussianHMMSimpleTest)
+{
+ // We'll have two Gaussians, far away from each other, one corresponding to
+ // each state.
+ // E(0) ~ N([ 5.0 5.0], eye(2)).
+ // E(1) ~ N([-5.0 -5.0], eye(2)).
+ // The transition matrix is simple:
+ // T = [[0.75 0.25]
+ // [0.25 0.75]]
+ GaussianDistribution g1("5.0 5.0", "1.0 0.0; 0.0 1.0");
+ GaussianDistribution g2("-5.0 -5.0", "1.0 0.0; 0.0 1.0");
+
+ arma::mat transition("0.75 0.25; 0.25 0.75");
+
+ std::vector<GaussianDistribution> emission;
+ emission.push_back(g1);
+ emission.push_back(g2);
+
+ HMM<GaussianDistribution> hmm(transition, emission);
+
+ // Now, generate some sequences.
+ arma::mat observations(2, 1000);
+ arma::Col<size_t> classes(1000);
+
+ // 1000-observations sequence.
+ classes[0] = 0;
+ observations.col(0) = g1.Random();
+ for (size_t i = 1; i < 1000; i++)
+ {
+ double randValue = math::Random();
+
+ if (randValue > 0.75) // Then we change state.
+ classes[i] = (classes[i - 1] + 1) % 2;
+ else
+ classes[i] = classes[i - 1];
+
+ if (classes[i] == 0)
+ observations.col(i) = g1.Random();
+ else
+ observations.col(i) = g2.Random();
+ }
+
+ // Now predict the sequence.
+ arma::Col<size_t> predictedClasses;
+ arma::mat stateProb;
+
+ hmm.Predict(observations, predictedClasses);
+ hmm.Estimate(observations, stateProb);
+
+ // Check that each prediction is right.
+ for (size_t i = 0; i < 1000; i++)
+ {
+ BOOST_REQUIRE_EQUAL(predictedClasses[i], classes[i]);
+
+ // The probability of the wrong class should be infinitesimal.
+ BOOST_REQUIRE_SMALL(stateProb((classes[i] + 1) % 2, i), 0.001);
+ }
+}
+
+/**
+ * Ensure that Gaussian HMMs can be trained properly, for the labeled training
+ * case and also for the unlabeled training case.
+ */
+BOOST_AUTO_TEST_CASE(GaussianHMMTrainTest)
+{
+ // Three emission Gaussians and three internal states. The goal is to estimate
+ // the transition matrix correctly, and each distribution correctly.
+ std::vector<GaussianDistribution> emission;
+ emission.push_back(GaussianDistribution("0.0 0.0 0.0", "1.0 0.2 0.2;"
+ "0.2 1.5 0.0;"
+ "0.2 0.0 1.1"));
+ emission.push_back(GaussianDistribution("2.0 1.0 5.0", "0.7 0.3 0.0;"
+ "0.3 2.6 0.0;"
+ "0.0 0.0 1.0"));
+ emission.push_back(GaussianDistribution("5.0 0.0 0.5", "1.0 0.0 0.0;"
+ "0.0 1.0 0.0;"
+ "0.0 0.0 1.0"));
+
+ arma::mat transition("0.3 0.5 0.7;"
+ "0.3 0.4 0.1;"
+ "0.4 0.1 0.2");
+
+ // Now generate observations.
+ std::vector<arma::mat> observations(100);
+ std::vector<arma::Col<size_t> > states(100);
+
+ for (size_t obs = 0; obs < 100; obs++)
+ {
+ observations[obs].set_size(3, 1000);
+ states[obs].set_size(1000);
+
+ // Always start in state zero.
+ states[obs][0] = 0;
+ observations[obs].col(0) = emission[0].Random();
+
+ for (size_t t = 1; t < 1000; t++)
+ {
+ // Choose the state.
+ double randValue = math::Random();
+ double probSum = 0;
+ for (size_t state = 0; state < 3; state++)
+ {
+ probSum += transition(state, states[obs][t - 1]);
+ if (probSum >= randValue)
+ {
+ states[obs][t] = state;
+ break;
+ }
+ }
+
+ // Now choose the emission.
+ observations[obs].col(t) = emission[states[obs][t]].Random();
+ }
+ }
+
+ // Now that the data is generated, train the HMM.
+ HMM<GaussianDistribution> hmm(3, GaussianDistribution(3));
+
+ hmm.Train(observations, states);
+
+ // We use an absolute tolerance of 0.01 for the transition matrices.
+ // Check that the transition matrix is correct.
+ for (size_t row = 0; row < 3; row++)
+ for (size_t col = 0; col < 3; col++)
+ BOOST_REQUIRE_SMALL(transition(row, col) - hmm.Transition()(row, col),
+ 0.01);
+
+ // Check that each distribution is correct.
+ for (size_t dist = 0; dist < 3; dist++)
+ {
+ // Check that the mean is correct. Absolute tolerance of 0.04.
+ for (size_t dim = 0; dim < 3; dim++)
+ BOOST_REQUIRE_SMALL(hmm.Emission()[dist].Mean()(dim) -
+ emission[dist].Mean()(dim), 0.04);
+
+ // Check that the covariance is correct. Absolute tolerance of 0.075.
+ for (size_t row = 0; row < 3; row++)
+ for (size_t col = 0; col < 3; col++)
+ BOOST_REQUIRE_SMALL(hmm.Emission()[dist].Covariance()(row, col) -
+ emission[dist].Covariance()(row, col), 0.075);
+ }
+
+ // Now let's try it all again, but this time, unlabeled. Everything will fail
+ // if we don't have a decent guess at the Gaussians, so we'll take a "poor"
+ // guess at it ourselves. I won't use K-Means because we can't afford to add
+ // the instability of that to our test. We'll leave the covariances as the
+ // identity.
+ HMM<GaussianDistribution> hmm2(3, GaussianDistribution(3));
+ hmm2.Emission()[0].Mean() = "0.3 -0.2 0.1"; // Actual: [0 0 0].
+ hmm2.Emission()[1].Mean() = "1.0 1.4 3.2"; // Actual: [2 1 5].
+ hmm2.Emission()[2].Mean() = "3.1 -0.2 6.1"; // Actual: [5 0 0.5].
+
+ // We'll only use 20 observation sequences to try and keep training time
+ // shorter.
+ observations.resize(20);
+
+ hmm.Train(observations);
+
+ // The tolerances are increased because there is more error in unlabeled
+ // training; we use an absolute tolerance of 0.03 for the transition matrices.
+ // Check that the transition matrix is correct.
+ for (size_t row = 0; row < 3; row++)
+ for (size_t col = 0; col < 3; col++)
+ BOOST_REQUIRE_SMALL(transition(row, col) - hmm.Transition()(row, col),
+ 0.03);
+
+ // Check that each distribution is correct.
+ for (size_t dist = 0; dist < 3; dist++)
+ {
+ // Check that the mean is correct. Absolute tolerance of 0.09.
+ for (size_t dim = 0; dim < 3; dim++)
+ BOOST_REQUIRE_SMALL(hmm.Emission()[dist].Mean()(dim) -
+ emission[dist].Mean()(dim), 0.09);
+
+ // Check that the covariance is correct. Absolute tolerance of 0.12.
+ for (size_t row = 0; row < 3; row++)
+ for (size_t col = 0; col < 3; col++)
+ BOOST_REQUIRE_SMALL(hmm.Emission()[dist].Covariance()(row, col) -
+ emission[dist].Covariance()(row, col), 0.12);
+ }
+}
+
+/**
+ * Make sure that a random sequence generated by a Gaussian HMM fits the
+ * distribution correctly.
+ */
+BOOST_AUTO_TEST_CASE(GaussianHMMGenerateTest)
+{
+ // Our distribution will have three two-dimensional output Gaussians.
+ HMM<GaussianDistribution> hmm(3, GaussianDistribution(2));
+ hmm.Transition() = arma::mat("0.4 0.6 0.8; 0.2 0.2 0.1; 0.4 0.2 0.1");
+ hmm.Emission()[0] = GaussianDistribution("0.0 0.0", "1.0 0.0; 0.0 1.0");
+ hmm.Emission()[1] = GaussianDistribution("2.0 2.0", "1.0 0.5; 0.5 1.2");
+ hmm.Emission()[2] = GaussianDistribution("-2.0 1.0", "2.0 0.1; 0.1 1.0");
+
+ // Now we will generate a long sequence.
+ std::vector<arma::mat> observations(1);
+ std::vector<arma::Col<size_t> > states(1);
+
+ // Start in state 1 (no reason).
+ hmm.Generate(10000, observations[0], states[0], 1);
+
+ HMM<GaussianDistribution> hmm2(3, GaussianDistribution(2));
+
+ // Now estimate the HMM from the generated sequence.
+ hmm2.Train(observations, states);
+
+ // Check that the estimated matrices are the same.
+ for (size_t row = 0; row < 3; row++)
+ for (size_t col = 0; col < 3; col++)
+ BOOST_REQUIRE_SMALL(hmm.Transition()(row, col) - hmm2.Transition()(row,
+ col), 0.03);
+
+ // Check that each Gaussian is the same.
+ for (size_t em = 0; em < 3; em++)
+ {
+ // Check that the mean is the same.
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Mean()(0) -
+ hmm2.Emission()[em].Mean()(0), 0.09);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Mean()(1) -
+ hmm2.Emission()[em].Mean()(1), 0.09);
+
+ // Check that the covariances are the same.
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(0, 0) -
+ hmm2.Emission()[em].Covariance()(0, 0), 0.2);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(0, 1) -
+ hmm2.Emission()[em].Covariance()(0, 1), 0.2);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(1, 0) -
+ hmm2.Emission()[em].Covariance()(1, 0), 0.2);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[em].Covariance()(1, 1) -
+ hmm2.Emission()[em].Covariance()(1, 1), 0.2);
+ }
+}
+
+/**
+ * Test that HMMs work with Gaussian mixture models. We'll try putting in a
+ * simple model by hand and making sure that prediction of observation sequences
+ * works correctly.
+ */
+BOOST_AUTO_TEST_CASE(GMMHMMPredictTest)
+{
+ // We will use two GMMs; one with two components and one with three.
+ std::vector<GMM<> > gmms(2);
+ gmms[0] = GMM<>(2, 2);
+ gmms[0].Weights() = arma::vec("0.75 0.25");
+
+ // N([4.25 3.10], [1.00 0.20; 0.20 0.89])
+ gmms[0].Means()[0] = arma::vec("4.25 3.10");
+ gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
+
+ // N([7.10 5.01], [1.00 0.00; 0.00 1.01])
+ gmms[0].Means()[1] = arma::vec("7.10 5.01");
+ gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
+
+ gmms[1] = GMM<>(3, 2);
+ gmms[1].Weights() = arma::vec("0.4 0.2 0.4");
+
+ gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
+ gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
+
+ gmms[1].Means()[1] = arma::vec("-4.25 -7.12");
+ gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
+
+ gmms[1].Means()[2] = arma::vec("-6.15 -2.00");
+ gmms[1].Covariances()[2] = arma::mat("1.00 0.80; 0.80 1.00");
+
+ // Transition matrix.
+ arma::mat trans("0.30 0.50;"
+ "0.70 0.50");
+
+ // Now build the model.
+ HMM<GMM<> > hmm(trans, gmms);
+
+ // Make a sequence of observations.
+ arma::mat observations(2, 1000);
+ arma::Col<size_t> states(1000);
+ states[0] = 0;
+ observations.col(0) = gmms[0].Random();
+
+ for (size_t i = 1; i < 1000; i++)
+ {
+ double randValue = math::Random();
+
+ if (randValue <= trans(0, states[i - 1]))
+ states[i] = 0;
+ else
+ states[i] = 1;
+
+ observations.col(i) = gmms[states[i]].Random();
+ }
+
+ // Run the prediction.
+ arma::Col<size_t> predictions;
+ hmm.Predict(observations, predictions);
+
+ // Check that the predictions were correct.
+ for (size_t i = 0; i < 1000; i++)
+ BOOST_REQUIRE_EQUAL(predictions[i], states[i]);
+}
+
+/**
+ * Test that GMM-based HMMs can train on models correctly using labeled training
+ * data.
+ */
+BOOST_AUTO_TEST_CASE(GMMHMMLabeledTrainingTest)
+{
+ srand(time(NULL));
+
+ // We will use two GMMs; one with two components and one with three.
+ std::vector<GMM<> > gmms(2, GMM<>(2, 2));
+ gmms[0].Weights() = arma::vec("0.3 0.7");
+
+ // N([4.25 3.10], [1.00 0.20; 0.20 0.89])
+ gmms[0].Means()[0] = arma::vec("4.25 3.10");
+ gmms[0].Covariances()[0] = arma::mat("1.00 0.20; 0.20 0.89");
+
+ // N([7.10 5.01], [1.00 0.00; 0.00 1.01])
+ gmms[0].Means()[1] = arma::vec("7.10 5.01");
+ gmms[0].Covariances()[1] = arma::mat("1.00 0.00; 0.00 1.01");
+
+ gmms[1].Weights() = arma::vec("0.20 0.80");
+
+ gmms[1].Means()[0] = arma::vec("-3.00 -6.12");
+ gmms[1].Covariances()[0] = arma::mat("1.00 0.00; 0.00 1.00");
+
+ gmms[1].Means()[1] = arma::vec("-4.25 -2.12");
+ gmms[1].Covariances()[1] = arma::mat("1.50 0.60; 0.60 1.20");
+
+ // Transition matrix.
+ arma::mat transMat("0.40 0.60;"
+ "0.60 0.40");
+
+ // Make a sequence of observations.
+ std::vector<arma::mat> observations(5, arma::mat(2, 2500));
+ std::vector<arma::Col<size_t> > states(5, arma::Col<size_t>(2500));
+ for (size_t obs = 0; obs < 5; obs++)
+ {
+ states[obs][0] = 0;
+ observations[obs].col(0) = gmms[0].Random();
+
+ for (size_t i = 1; i < 2500; i++)
+ {
+ double randValue = (double) rand() / (double) RAND_MAX;
+
+ if (randValue <= transMat(0, states[obs][i - 1]))
+ states[obs][i] = 0;
+ else
+ states[obs][i] = 1;
+
+ observations[obs].col(i) = gmms[states[obs][i]].Random();
+ }
+ }
+
+ // Set up the GMM for training.
+ HMM<GMM<> > hmm(2, GMM<>(2, 2));
+
+ // Train the HMM.
+ hmm.Train(observations, states);
+
+ // Check the results. Use absolute tolerances instead of percentages.
+ BOOST_REQUIRE_SMALL(hmm.Transition()(0, 0) - transMat(0, 0), 0.02);
+ BOOST_REQUIRE_SMALL(hmm.Transition()(0, 1) - transMat(0, 1), 0.02);
+ BOOST_REQUIRE_SMALL(hmm.Transition()(1, 0) - transMat(1, 0), 0.02);
+ BOOST_REQUIRE_SMALL(hmm.Transition()(1, 1) - transMat(1, 1), 0.02);
+
+ // Now the emission probabilities (the GMMs).
+ // We have to sort each GMM for comparison.
+ arma::uvec sortedIndices = sort_index(hmm.Emission()[0].Weights());
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Weights()[sortedIndices[0]] -
+ gmms[0].Weights()[0], 0.08);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Weights()[sortedIndices[1]] -
+ gmms[0].Weights()[1], 0.08);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][0] -
+ gmms[0].Means()[0][0], 0.15);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[0]][1] -
+ gmms[0].Means()[0][1], 0.15);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][0] -
+ gmms[0].Means()[1][0], 0.15);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Means()[sortedIndices[1]][1] -
+ gmms[0].Means()[1][1], 0.15);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 0) -
+ gmms[0].Covariances()[0](0, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](0, 1) -
+ gmms[0].Covariances()[0](0, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 0) -
+ gmms[0].Covariances()[0](1, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[0]](1, 1) -
+ gmms[0].Covariances()[0](1, 1), 0.3);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 0) -
+ gmms[0].Covariances()[1](0, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](0, 1) -
+ gmms[0].Covariances()[1](0, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 0) -
+ gmms[0].Covariances()[1](1, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[0].Covariances()[sortedIndices[1]](1, 1) -
+ gmms[0].Covariances()[1](1, 1), 0.3);
+
+ // Sort the GMM.
+ sortedIndices = sort_index(hmm.Emission()[1].Weights());
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Weights()[sortedIndices[0]] -
+ gmms[1].Weights()[0], 0.08);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Weights()[sortedIndices[1]] -
+ gmms[1].Weights()[1], 0.08);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][0] -
+ gmms[1].Means()[0][0], 0.15);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[0]][1] -
+ gmms[1].Means()[0][1], 0.15);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][0] -
+ gmms[1].Means()[1][0], 0.15);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Means()[sortedIndices[1]][1] -
+ gmms[1].Means()[1][1], 0.15);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 0) -
+ gmms[1].Covariances()[0](0, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](0, 1) -
+ gmms[1].Covariances()[0](0, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 0) -
+ gmms[1].Covariances()[0](1, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[0]](1, 1) -
+ gmms[1].Covariances()[0](1, 1), 0.3);
+
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 0) -
+ gmms[1].Covariances()[1](0, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](0, 1) -
+ gmms[1].Covariances()[1](0, 1), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 0) -
+ gmms[1].Covariances()[1](1, 0), 0.3);
+ BOOST_REQUIRE_SMALL(hmm.Emission()[1].Covariances()[sortedIndices[1]](1, 1) -
+ gmms[1].Covariances()[1](1, 1), 0.3);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
+
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_pca_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/kernel_pca_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_pca_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,67 +0,0 @@
-/**
- * @file kernel_pca_test.cpp
- * @author Ajinkya Kale <kaleajinkya at gmail.com>
- *
- * Test file for Kernel PCA.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
-#include <mlpack/methods/kernel_pca/kernel_pca.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-BOOST_AUTO_TEST_SUITE(KernelPCATest);
-
-using namespace mlpack;
-using namespace mlpack::kpca;
-using namespace mlpack::kernel;
-using namespace std;
-using namespace arma;
-
-BOOST_AUTO_TEST_CASE(linear_kernel)
-{
-
- mat data("1 0 2 3 9;"
- "5 2 8 4 8;"
- "6 7 3 1 8");
-
- KernelPCA<LinearKernel> p;
- p.Apply(data, 2); // Reduce to 2 dimensions.
-
- // Compare with correct results.
- mat correct("-1.53781086 -3.51358020 -0.16139887 -1.87706634 7.08985628;"
- " 1.29937798 3.45762685 -2.69910005 -3.15620704 1.09830225");
-
- // If the eigenvectors are pointed opposite directions, they will cancel
- // each other out in this summation.
- for(size_t i = 0; i < data.n_rows; i++)
- {
- if (fabs(correct(i, 1) + data(i, 1)) < 0.001 /* arbitrary */)
- {
- // Flip eigenvector for this column (negate output).
- data.row(i) *= -1;
- }
- }
-
- for (size_t row = 0; row < 2; ++row)
- for (size_t col = 0; col < 5; ++col)
- BOOST_REQUIRE_CLOSE(data(row, col), correct(row, col), 1e-3);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_pca_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/kernel_pca_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_pca_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_pca_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,67 @@
+/**
+ * @file kernel_pca_test.cpp
+ * @author Ajinkya Kale <kaleajinkya at gmail.com>
+ *
+ * Test file for Kernel PCA.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/kernels/linear_kernel.hpp>
+#include <mlpack/methods/kernel_pca/kernel_pca.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+BOOST_AUTO_TEST_SUITE(KernelPCATest);
+
+using namespace mlpack;
+using namespace mlpack::kpca;
+using namespace mlpack::kernel;
+using namespace std;
+using namespace arma;
+
+BOOST_AUTO_TEST_CASE(linear_kernel)
+{
+
+ mat data("1 0 2 3 9;"
+ "5 2 8 4 8;"
+ "6 7 3 1 8");
+
+ KernelPCA<LinearKernel> p;
+ p.Apply(data, 2); // Reduce to 2 dimensions.
+
+ // Compare with correct results.
+ mat correct("-1.53781086 -3.51358020 -0.16139887 -1.87706634 7.08985628;"
+ " 1.29937798 3.45762685 -2.69910005 -3.15620704 1.09830225");
+
+ // If the eigenvectors are pointed opposite directions, they will cancel
+ // each other out in this summation.
+ for(size_t i = 0; i < data.n_rows; i++)
+ {
+ if (fabs(correct(i, 1) + data(i, 1)) < 0.001 /* arbitrary */)
+ {
+ // Flip eigenvector for this column (negate output).
+ data.row(i) *= -1;
+ }
+ }
+
+ for (size_t row = 0; row < 2; ++row)
+ for (size_t col = 0; col < 5; ++col)
+ BOOST_REQUIRE_CLOSE(data(row, col), correct(row, col), 1e-3);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/kernel_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,608 +0,0 @@
-/**
- * @file kernel_test.cpp
- * @author Ryan Curtin
- * @author Ajinkya Kale
- *
- * Tests for the various kernel classes.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core/kernels/cosine_distance.hpp>
-#include <mlpack/core/kernels/epanechnikov_kernel.hpp>
-#include <mlpack/core/kernels/gaussian_kernel.hpp>
-#include <mlpack/core/kernels/hyperbolic_tangent_kernel.hpp>
-#include <mlpack/core/kernels/laplacian_kernel.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
-#include <mlpack/core/kernels/linear_kernel.hpp>
-#include <mlpack/core/kernels/polynomial_kernel.hpp>
-#include <mlpack/core/kernels/spherical_kernel.hpp>
-#include <mlpack/core/kernels/pspectrum_string_kernel.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-#include <mlpack/core/metrics/mahalanobis_distance.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::kernel;
-using namespace mlpack::metric;
-
-BOOST_AUTO_TEST_SUITE(KernelTest);
-
-/**
- * Basic test of the Manhattan distance.
- */
-BOOST_AUTO_TEST_CASE(manhattan_distance)
-{
- // A couple quick tests.
- arma::vec a = "1.0 3.0 4.0";
- arma::vec b = "3.0 3.0 5.0";
-
- BOOST_REQUIRE_CLOSE(ManhattanDistance::Evaluate(a, b), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(ManhattanDistance::Evaluate(b, a), 3.0, 1e-5);
-
- // Check also for when the root is taken (should be the same).
- BOOST_REQUIRE_CLOSE((LMetric<1, true>::Evaluate(a, b)), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE((LMetric<1, true>::Evaluate(b, a)), 3.0, 1e-5);
-}
-
-/**
- * Basic test of squared Euclidean distance.
- */
-BOOST_AUTO_TEST_CASE(squared_euclidean_distance)
-{
- // Sample 2-dimensional vectors.
- arma::vec a = "1.0 2.0";
- arma::vec b = "0.0 -2.0";
-
- BOOST_REQUIRE_CLOSE(SquaredEuclideanDistance::Evaluate(a, b), 17.0, 1e-5);
- BOOST_REQUIRE_CLOSE(SquaredEuclideanDistance::Evaluate(b, a), 17.0, 1e-5);
-}
-
-/**
- * Basic test of Euclidean distance.
- */
-BOOST_AUTO_TEST_CASE(euclidean_distance)
-{
- arma::vec a = "1.0 3.0 5.0 7.0";
- arma::vec b = "4.0 0.0 2.0 0.0";
-
- BOOST_REQUIRE_CLOSE(EuclideanDistance::Evaluate(a, b), sqrt(76.0), 1e-5);
- BOOST_REQUIRE_CLOSE(EuclideanDistance::Evaluate(b, a), sqrt(76.0), 1e-5);
-}
-
-/**
- * Arbitrary test case for coverage.
- */
-BOOST_AUTO_TEST_CASE(arbitrary_case)
-{
- arma::vec a = "3.0 5.0 6.0 7.0";
- arma::vec b = "1.0 2.0 1.0 0.0";
-
- BOOST_REQUIRE_CLOSE((LMetric<3, false>::Evaluate(a, b)), 503.0, 1e-5);
- BOOST_REQUIRE_CLOSE((LMetric<3, false>::Evaluate(b, a)), 503.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE((LMetric<3, true>::Evaluate(a, b)), 7.95284762, 1e-5);
- BOOST_REQUIRE_CLOSE((LMetric<3, true>::Evaluate(b, a)), 7.95284762, 1e-5);
-}
-
-/**
- * Make sure two vectors of all zeros return zero distance, for a few different
- * powers.
- */
-BOOST_AUTO_TEST_CASE(lmetric_zeros)
-{
- arma::vec a(250);
- a.fill(0.0);
-
- // We cannot use a loop because compilers seem to be unable to unroll the loop
- // and realize the variable actually is knowable at compile-time.
- BOOST_REQUIRE((LMetric<1, false>::Evaluate(a, a)) == 0);
- BOOST_REQUIRE((LMetric<1, true>::Evaluate(a, a)) == 0);
- BOOST_REQUIRE((LMetric<2, false>::Evaluate(a, a)) == 0);
- BOOST_REQUIRE((LMetric<2, true>::Evaluate(a, a)) == 0);
- BOOST_REQUIRE((LMetric<3, false>::Evaluate(a, a)) == 0);
- BOOST_REQUIRE((LMetric<3, true>::Evaluate(a, a)) == 0);
- BOOST_REQUIRE((LMetric<4, false>::Evaluate(a, a)) == 0);
- BOOST_REQUIRE((LMetric<4, true>::Evaluate(a, a)) == 0);
- BOOST_REQUIRE((LMetric<5, false>::Evaluate(a, a)) == 0);
- BOOST_REQUIRE((LMetric<5, true>::Evaluate(a, a)) == 0);
-}
-
-/**
- * Simple test of Mahalanobis distance with unset covariance matrix in
- * constructor.
- */
-BOOST_AUTO_TEST_CASE(md_unset_covariance)
-{
- MahalanobisDistance<false> md;
- md.Covariance() = arma::eye<arma::mat>(4, 4);
- arma::vec a = "1.0 2.0 2.0 3.0";
- arma::vec b = "0.0 0.0 1.0 3.0";
-
- BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), 6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), 6.0, 1e-5);
-}
-
-/**
- * Simple test of Mahalanobis distance with unset covariance matrix in
- * constructor and t_take_root set to true.
- */
-BOOST_AUTO_TEST_CASE(md_root_unset_covariance)
-{
- MahalanobisDistance<true> md;
- md.Covariance() = arma::eye<arma::mat>(4, 4);
- arma::vec a = "1.0 2.0 2.5 5.0";
- arma::vec b = "0.0 2.0 0.5 8.0";
-
- BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), sqrt(14.0), 1e-5);
- BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), sqrt(14.0), 1e-5);
-}
-
-/**
- * Simple test of Mahalanobis distance setting identity covariance in
- * constructor.
- */
-BOOST_AUTO_TEST_CASE(md_eye_covariance)
-{
- MahalanobisDistance<false> md(4);
- arma::vec a = "1.0 2.0 2.0 3.0";
- arma::vec b = "0.0 0.0 1.0 3.0";
-
- BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), 6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), 6.0, 1e-5);
-}
-
-/**
- * Simple test of Mahalanobis distance setting identity covariance in
- * constructor and t_take_root set to true.
- */
-BOOST_AUTO_TEST_CASE(md_root_eye_covariance)
-{
- MahalanobisDistance<true> md(4);
- arma::vec a = "1.0 2.0 2.5 5.0";
- arma::vec b = "0.0 2.0 0.5 8.0";
-
- BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), sqrt(14.0), 1e-5);
- BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), sqrt(14.0), 1e-5);
-}
-
-/**
- * Simple test with diagonal covariance matrix.
- */
-BOOST_AUTO_TEST_CASE(md_diagonal_covariance)
-{
- arma::mat cov = arma::eye<arma::mat>(5, 5);
- cov(0, 0) = 2.0;
- cov(1, 1) = 0.5;
- cov(2, 2) = 3.0;
- cov(3, 3) = 1.0;
- cov(4, 4) = 1.5;
- MahalanobisDistance<false> md(cov);
-
- arma::vec a = "1.0 2.0 2.0 4.0 5.0";
- arma::vec b = "2.0 3.0 1.0 1.0 0.0";
-
- BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), 52.0, 1e-5);
- BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), 52.0, 1e-5);
-}
-
-/**
- * More specific case with more difficult covariance matrix.
- */
-BOOST_AUTO_TEST_CASE(md_full_covariance)
-{
- arma::mat cov = "1.0 2.0 3.0 4.0;"
- "0.5 0.6 0.7 0.1;"
- "3.4 4.3 5.0 6.1;"
- "1.0 2.0 4.0 1.0;";
- MahalanobisDistance<false> md(cov);
-
- arma::vec a = "1.0 2.0 2.0 4.0";
- arma::vec b = "2.0 3.0 1.0 1.0";
-
- BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), 15.7, 1e-5);
- BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), 15.7, 1e-5);
-}
-
-/**
- * Simple test case for the cosine distance.
- */
-BOOST_AUTO_TEST_CASE(cosine_distance_same_angle)
-{
- arma::vec a = "1.0 2.0 3.0";
- arma::vec b = "2.0 4.0 6.0";
-
- BOOST_REQUIRE_CLOSE(CosineDistance::Evaluate(a, b), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(CosineDistance::Evaluate(b, a), 1.0, 1e-5);
-}
-
-/**
- * Now let's have them be orthogonal.
- */
-BOOST_AUTO_TEST_CASE(cosine_distance_orthogonal)
-{
- arma::vec a = "0.0 1.0";
- arma::vec b = "1.0 0.0";
-
- BOOST_REQUIRE_SMALL(CosineDistance::Evaluate(a, b), 1e-5);
- BOOST_REQUIRE_SMALL(CosineDistance::Evaluate(b, a), 1e-5);
-}
-
-/**
- * Some random angle test.
- */
-BOOST_AUTO_TEST_CASE(cosine_distance_random_test)
-{
- arma::vec a = "0.1 0.2 0.3 0.4 0.5";
- arma::vec b = "1.2 1.0 0.8 -0.3 -0.5";
-
- BOOST_REQUIRE_CLOSE(CosineDistance::Evaluate(a, b), 0.1385349024, 1e-5);
- BOOST_REQUIRE_CLOSE(CosineDistance::Evaluate(b, a), 0.1385349024, 1e-5);
-}
-
-/**
- * Linear Kernel test.
- */
-BOOST_AUTO_TEST_CASE(linear_kernel)
-{
- arma::vec a = ".2 .3 .4 .1";
- arma::vec b = ".56 .21 .623 .82";
-
- LinearKernel lk;
- BOOST_REQUIRE_CLOSE(lk.Evaluate(a,b), .5062, 1e-5);
- BOOST_REQUIRE_CLOSE(lk.Evaluate(b,a), .5062, 1e-5);
-}
-
-/**
- * Linear Kernel test, orthogonal vectors.
- */
-BOOST_AUTO_TEST_CASE(linear_kernel_orthogonal)
-{
- arma::vec a = "1 0 0";
- arma::vec b = "0 0 1";
-
- LinearKernel lk;
- BOOST_REQUIRE_SMALL(lk.Evaluate(a,b), 1e-5);
- BOOST_REQUIRE_SMALL(lk.Evaluate(b,a), 1e-5);
-}
-
-BOOST_AUTO_TEST_CASE(gaussian_kernel)
-{
- arma::vec a = "1 0 0";
- arma::vec b = "0 1 0";
- arma::vec c = "0 0 1";
-
- GaussianKernel gk(.5);
- BOOST_REQUIRE_CLOSE(gk.Evaluate(a,b), .018315638888734, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Evaluate(b,a), .018315638888734, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Evaluate(a,c), .018315638888734, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Evaluate(c,a), .018315638888734, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Evaluate(b,c), .018315638888734, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Evaluate(c,b), .018315638888734, 1e-5);
- /* check the single dimension evaluate function */
- BOOST_REQUIRE_CLOSE(gk.Evaluate(1.0), 0.1353352832366127, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Evaluate(2.0), 0.00033546262790251185, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Evaluate(3.0), 1.5229979744712629e-08, 1e-5);
- /* check the normalization constant */
- BOOST_REQUIRE_CLOSE(gk.Normalizer(1), 1.2533141373155001, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Normalizer(2), 1.5707963267948963, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Normalizer(3), 1.9687012432153019, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.Normalizer(4), 2.4674011002723386, 1e-5);
- /* check the convolution integral */
- BOOST_REQUIRE_CLOSE(gk.ConvolutionIntegral(a,b), 0.024304474038457577, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.ConvolutionIntegral(a,c), 0.024304474038457577, 1e-5);
- BOOST_REQUIRE_CLOSE(gk.ConvolutionIntegral(b,c), 0.024304474038457577, 1e-5);
-
-}
-
-BOOST_AUTO_TEST_CASE(spherical_kernel)
-{
- arma::vec a = "1.0 0.0";
- arma::vec b = "0.0 1.0";
- arma::vec c = "0.2 0.9";
-
- SphericalKernel sk(.5);
- BOOST_REQUIRE_CLOSE(sk.Evaluate(a,b), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.Evaluate(a,c), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.Evaluate(b,c), 1.0, 1e-5);
- /* check the single dimension evaluate function */
- BOOST_REQUIRE_CLOSE(sk.Evaluate(0.10), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.Evaluate(0.25), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.Evaluate(0.50), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.Evaluate(1.00), 0.0, 1e-5);
- /* check the normalization constant */
- BOOST_REQUIRE_CLOSE(sk.Normalizer(1), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.Normalizer(2), 0.78539816339744828, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.Normalizer(3), 0.52359877559829893, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.Normalizer(4), 0.30842513753404244, 1e-5);
- /* check the convolution integral */
- BOOST_REQUIRE_CLOSE(sk.ConvolutionIntegral(a,b), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.ConvolutionIntegral(a,c), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(sk.ConvolutionIntegral(b,c), 1.0021155029652784, 1e-5);
-}
-
-BOOST_AUTO_TEST_CASE(epanechnikov_kernel)
-{
- arma::vec a = "1.0 0.0";
- arma::vec b = "0.0 1.0";
- arma::vec c = "0.1 0.9";
-
- EpanechnikovKernel ek(.5);
- BOOST_REQUIRE_CLOSE(ek.Evaluate(a,b), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.Evaluate(b,c), 0.92, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.Evaluate(a,c), 0.0, 1e-5);
- /* check the single dimension evaluate function */
- BOOST_REQUIRE_CLOSE(ek.Evaluate(0.10), 0.96, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.Evaluate(0.25), 0.75, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.Evaluate(0.50), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.Evaluate(1.00), 0.0, 1e-5);
- /* check the normalization constant */
- BOOST_REQUIRE_CLOSE(ek.Normalizer(1), 0.666666666666666, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.Normalizer(2), 0.39269908169872414, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.Normalizer(3), 0.20943951023931956, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.Normalizer(4), 0.10280837917801415, 1e-5);
- /* check the convolution integral */
- BOOST_REQUIRE_CLOSE(ek.ConvolutionIntegral(a,b), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.ConvolutionIntegral(a,c), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(ek.ConvolutionIntegral(b,c), 1.5263455690698258, 1e-5);
-}
-
-BOOST_AUTO_TEST_CASE(polynomial_kernel)
-{
- arma::vec a = "0 0 1";
- arma::vec b = "0 1 0";
-
- PolynomialKernel pk(5.0, 5.0);
- BOOST_REQUIRE_CLOSE(pk.Evaluate(a, b), 3125.0, 0);
- BOOST_REQUIRE_CLOSE(pk.Evaluate(b, a), 3125.0, 0);
-}
-
-BOOST_AUTO_TEST_CASE(hyperbolic_tangent_kernel)
-{
- arma::vec a = "0 0 1";
- arma::vec b = "0 1 0";
-
- HyperbolicTangentKernel tk(5.0, 5.0);
- BOOST_REQUIRE_CLOSE(tk.Evaluate(a, b), 0.9999092, 1e-5);
- BOOST_REQUIRE_CLOSE(tk.Evaluate(b, a), 0.9999092, 1e-5);
-}
-
-BOOST_AUTO_TEST_CASE(laplacian_kernel)
-{
- arma::vec a = "0 0 1";
- arma::vec b = "0 1 0";
-
- LaplacianKernel lk(1.0);
- BOOST_REQUIRE_CLOSE(lk.Evaluate(a, b), 0.243116734, 5e-5);
- BOOST_REQUIRE_CLOSE(lk.Evaluate(b, a), 0.243116734, 5e-5);
-}
-
-// Ensure that the p-spectrum kernel successfully extracts all length-p
-// substrings from the data.
-BOOST_AUTO_TEST_CASE(PSpectrumSubstringExtractionTest)
-{
- std::vector<std::vector<std::string> > datasets;
-
- datasets.push_back(std::vector<std::string>());
-
- datasets[0].push_back("herpgle");
- datasets[0].push_back("herpagkle");
- datasets[0].push_back("klunktor");
- datasets[0].push_back("flibbynopple");
-
- datasets.push_back(std::vector<std::string>());
-
- datasets[1].push_back("floggy3245");
- datasets[1].push_back("flippydopflip");
- datasets[1].push_back("stupid fricking cat");
- datasets[1].push_back("food time isn't until later");
- datasets[1].push_back("leave me alone until 6:00");
- datasets[1].push_back("only after that do you get any food.");
- datasets[1].push_back("obloblobloblobloblobloblob");
-
- PSpectrumStringKernel p(datasets, 3);
-
- // Ensure the sizes are correct.
- BOOST_REQUIRE_EQUAL(p.Counts().size(), 2);
- BOOST_REQUIRE_EQUAL(p.Counts()[0].size(), 4);
- BOOST_REQUIRE_EQUAL(p.Counts()[1].size(), 7);
-
- // herpgle: her, erp, rpg, pgl, gle
- BOOST_REQUIRE_EQUAL(p.Counts()[0][0].size(), 5);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["her"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["erp"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["rpg"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["pgl"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["gle"], 1);
-
- // herpagkle: her, erp, rpa, pag, agk, gkl, kle
- BOOST_REQUIRE_EQUAL(p.Counts()[0][1].size(), 7);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["her"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["erp"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["rpa"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["pag"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["agk"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["gkl"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["kle"], 1);
-
- // klunktor: klu, lun, unk, nkt, kto, tor
- BOOST_REQUIRE_EQUAL(p.Counts()[0][2].size(), 6);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["klu"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["lun"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["unk"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["nkt"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["kto"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["tor"], 1);
-
- // flibbynopple: fli lib ibb bby byn yno nop opp ppl ple
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3].size(), 10);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["fli"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["lib"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["ibb"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["bby"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["byn"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["yno"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["nop"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["opp"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["ppl"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["ple"], 1);
-
- // floggy3245: flo log ogg ggy gy3 y32 324 245
- BOOST_REQUIRE_EQUAL(p.Counts()[1][0].size(), 8);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["flo"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["log"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["ogg"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["ggy"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["gy3"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["y32"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["324"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["245"], 1);
-
- // flippydopflip: fli lip ipp ppy pyd ydo dop opf pfl fli lip
- // fli(2) lip(2) ipp ppy pyd ydo dop opf pfl
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1].size(), 9);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["fli"], 2);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["lip"], 2);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["ipp"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["ppy"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["pyd"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["ydo"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["dop"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["opf"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["pfl"], 1);
-
- // stupid fricking cat: stu tup upi pid fri ric ick cki kin ing cat
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2].size(), 11);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["stu"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["tup"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["upi"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["pid"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["fri"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["ric"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["ick"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["cki"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["kin"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["ing"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["cat"], 1);
-
- // food time isn't until later: foo ood tim ime isn unt nti til lat ate ter
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3].size(), 11);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["foo"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["ood"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["tim"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["ime"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["isn"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["unt"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["nti"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["til"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["lat"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["ate"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["ter"], 1);
-
- // leave me alone until 6:00: lea eav ave alo lon one unt nti til
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4].size(), 9);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["lea"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["eav"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["ave"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["alo"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["lon"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["one"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["unt"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["nti"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["til"], 1);
-
- // only after that do you get any food.:
- // onl nly aft fte ter tha hat you get any foo ood
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5].size(), 12);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["onl"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["nly"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["aft"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["fte"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["ter"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["tha"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["hat"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["you"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["get"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["any"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["foo"], 1);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["ood"], 1);
-
- // obloblobloblobloblobloblob: obl(8) blo(8) lob(8)
- BOOST_REQUIRE_EQUAL(p.Counts()[1][6].size(), 3);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][6]["obl"], 8);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][6]["blo"], 8);
- BOOST_REQUIRE_EQUAL(p.Counts()[1][6]["lob"], 8);
-}
-
-BOOST_AUTO_TEST_CASE(PSpectrumStringEvaluateTest)
-{
- // Construct simple dataset.
- std::vector<std::vector<std::string> > dataset;
- dataset.push_back(std::vector<std::string>());
- dataset[0].push_back("hello");
- dataset[0].push_back("jello");
- dataset[0].push_back("mellow");
- dataset[0].push_back("mellow jello");
-
- PSpectrumStringKernel p(dataset, 3);
-
- arma::vec a("0 0");
- arma::vec b("0 0");
-
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 3.0, 1e-5);
-
- b = "0 1";
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 2.0, 1e-5);
-
- b = "0 2";
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 2.0, 1e-5);
-
- b = "0 3";
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 4.0, 1e-5);
-
- a = "0 1";
- b = "0 1";
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 3.0, 1e-5);
-
- b = "0 2";
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 2.0, 1e-5);
-
- b = "0 3";
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 5.0, 1e-5);
-
- a = "0 2";
- b = "0 2";
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 4.0, 1e-5);
-
- b = "0 3";
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 6.0, 1e-5);
-
- a = "0 3";
- BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 11.0, 1e-5);
- BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 11.0, 1e-5);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/kernel_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kernel_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,608 @@
+/**
+ * @file kernel_test.cpp
+ * @author Ryan Curtin
+ * @author Ajinkya Kale
+ *
+ * Tests for the various kernel classes.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core/kernels/cosine_distance.hpp>
+#include <mlpack/core/kernels/epanechnikov_kernel.hpp>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <mlpack/core/kernels/hyperbolic_tangent_kernel.hpp>
+#include <mlpack/core/kernels/laplacian_kernel.hpp>
+#include <mlpack/core/kernels/linear_kernel.hpp>
+#include <mlpack/core/kernels/linear_kernel.hpp>
+#include <mlpack/core/kernels/polynomial_kernel.hpp>
+#include <mlpack/core/kernels/spherical_kernel.hpp>
+#include <mlpack/core/kernels/pspectrum_string_kernel.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/metrics/mahalanobis_distance.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::kernel;
+using namespace mlpack::metric;
+
+BOOST_AUTO_TEST_SUITE(KernelTest);
+
+/**
+ * Basic test of the Manhattan distance.
+ */
+BOOST_AUTO_TEST_CASE(manhattan_distance)
+{
+ // A couple quick tests.
+ arma::vec a = "1.0 3.0 4.0";
+ arma::vec b = "3.0 3.0 5.0";
+
+ BOOST_REQUIRE_CLOSE(ManhattanDistance::Evaluate(a, b), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(ManhattanDistance::Evaluate(b, a), 3.0, 1e-5);
+
+ // Check also for when the root is taken (should be the same).
+ BOOST_REQUIRE_CLOSE((LMetric<1, true>::Evaluate(a, b)), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE((LMetric<1, true>::Evaluate(b, a)), 3.0, 1e-5);
+}
+
+/**
+ * Basic test of squared Euclidean distance.
+ */
+BOOST_AUTO_TEST_CASE(squared_euclidean_distance)
+{
+ // Sample 2-dimensional vectors.
+ arma::vec a = "1.0 2.0";
+ arma::vec b = "0.0 -2.0";
+
+ BOOST_REQUIRE_CLOSE(SquaredEuclideanDistance::Evaluate(a, b), 17.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(SquaredEuclideanDistance::Evaluate(b, a), 17.0, 1e-5);
+}
+
+/**
+ * Basic test of Euclidean distance.
+ */
+BOOST_AUTO_TEST_CASE(euclidean_distance)
+{
+ arma::vec a = "1.0 3.0 5.0 7.0";
+ arma::vec b = "4.0 0.0 2.0 0.0";
+
+ BOOST_REQUIRE_CLOSE(EuclideanDistance::Evaluate(a, b), sqrt(76.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(EuclideanDistance::Evaluate(b, a), sqrt(76.0), 1e-5);
+}
+
+/**
+ * Arbitrary test case for coverage.
+ */
+BOOST_AUTO_TEST_CASE(arbitrary_case)
+{
+ arma::vec a = "3.0 5.0 6.0 7.0";
+ arma::vec b = "1.0 2.0 1.0 0.0";
+
+ BOOST_REQUIRE_CLOSE((LMetric<3, false>::Evaluate(a, b)), 503.0, 1e-5);
+ BOOST_REQUIRE_CLOSE((LMetric<3, false>::Evaluate(b, a)), 503.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE((LMetric<3, true>::Evaluate(a, b)), 7.95284762, 1e-5);
+ BOOST_REQUIRE_CLOSE((LMetric<3, true>::Evaluate(b, a)), 7.95284762, 1e-5);
+}
+
+/**
+ * Make sure two vectors of all zeros return zero distance, for a few different
+ * powers.
+ */
+BOOST_AUTO_TEST_CASE(lmetric_zeros)
+{
+ arma::vec a(250);
+ a.fill(0.0);
+
+ // We cannot use a loop because compilers seem to be unable to unroll the loop
+ // and realize the variable actually is knowable at compile-time.
+ BOOST_REQUIRE((LMetric<1, false>::Evaluate(a, a)) == 0);
+ BOOST_REQUIRE((LMetric<1, true>::Evaluate(a, a)) == 0);
+ BOOST_REQUIRE((LMetric<2, false>::Evaluate(a, a)) == 0);
+ BOOST_REQUIRE((LMetric<2, true>::Evaluate(a, a)) == 0);
+ BOOST_REQUIRE((LMetric<3, false>::Evaluate(a, a)) == 0);
+ BOOST_REQUIRE((LMetric<3, true>::Evaluate(a, a)) == 0);
+ BOOST_REQUIRE((LMetric<4, false>::Evaluate(a, a)) == 0);
+ BOOST_REQUIRE((LMetric<4, true>::Evaluate(a, a)) == 0);
+ BOOST_REQUIRE((LMetric<5, false>::Evaluate(a, a)) == 0);
+ BOOST_REQUIRE((LMetric<5, true>::Evaluate(a, a)) == 0);
+}
+
+/**
+ * Simple test of Mahalanobis distance with unset covariance matrix in
+ * constructor.
+ */
+BOOST_AUTO_TEST_CASE(md_unset_covariance)
+{
+ MahalanobisDistance<false> md;
+ md.Covariance() = arma::eye<arma::mat>(4, 4);
+ arma::vec a = "1.0 2.0 2.0 3.0";
+ arma::vec b = "0.0 0.0 1.0 3.0";
+
+ BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), 6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), 6.0, 1e-5);
+}
+
+/**
+ * Simple test of Mahalanobis distance with unset covariance matrix in
+ * constructor and t_take_root set to true.
+ */
+BOOST_AUTO_TEST_CASE(md_root_unset_covariance)
+{
+ MahalanobisDistance<true> md;
+ md.Covariance() = arma::eye<arma::mat>(4, 4);
+ arma::vec a = "1.0 2.0 2.5 5.0";
+ arma::vec b = "0.0 2.0 0.5 8.0";
+
+ BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), sqrt(14.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), sqrt(14.0), 1e-5);
+}
+
+/**
+ * Simple test of Mahalanobis distance setting identity covariance in
+ * constructor.
+ */
+BOOST_AUTO_TEST_CASE(md_eye_covariance)
+{
+ MahalanobisDistance<false> md(4);
+ arma::vec a = "1.0 2.0 2.0 3.0";
+ arma::vec b = "0.0 0.0 1.0 3.0";
+
+ BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), 6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), 6.0, 1e-5);
+}
+
+/**
+ * Simple test of Mahalanobis distance setting identity covariance in
+ * constructor and t_take_root set to true.
+ */
+BOOST_AUTO_TEST_CASE(md_root_eye_covariance)
+{
+ MahalanobisDistance<true> md(4);
+ arma::vec a = "1.0 2.0 2.5 5.0";
+ arma::vec b = "0.0 2.0 0.5 8.0";
+
+ BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), sqrt(14.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), sqrt(14.0), 1e-5);
+}
+
+/**
+ * Simple test with diagonal covariance matrix.
+ */
+BOOST_AUTO_TEST_CASE(md_diagonal_covariance)
+{
+ arma::mat cov = arma::eye<arma::mat>(5, 5);
+ cov(0, 0) = 2.0;
+ cov(1, 1) = 0.5;
+ cov(2, 2) = 3.0;
+ cov(3, 3) = 1.0;
+ cov(4, 4) = 1.5;
+ MahalanobisDistance<false> md(cov);
+
+ arma::vec a = "1.0 2.0 2.0 4.0 5.0";
+ arma::vec b = "2.0 3.0 1.0 1.0 0.0";
+
+ BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), 52.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), 52.0, 1e-5);
+}
+
+/**
+ * More specific case with more difficult covariance matrix.
+ */
+BOOST_AUTO_TEST_CASE(md_full_covariance)
+{
+ arma::mat cov = "1.0 2.0 3.0 4.0;"
+ "0.5 0.6 0.7 0.1;"
+ "3.4 4.3 5.0 6.1;"
+ "1.0 2.0 4.0 1.0;";
+ MahalanobisDistance<false> md(cov);
+
+ arma::vec a = "1.0 2.0 2.0 4.0";
+ arma::vec b = "2.0 3.0 1.0 1.0";
+
+ BOOST_REQUIRE_CLOSE(md.Evaluate(a, b), 15.7, 1e-5);
+ BOOST_REQUIRE_CLOSE(md.Evaluate(b, a), 15.7, 1e-5);
+}
+
+/**
+ * Simple test case for the cosine distance.
+ */
+BOOST_AUTO_TEST_CASE(cosine_distance_same_angle)
+{
+ arma::vec a = "1.0 2.0 3.0";
+ arma::vec b = "2.0 4.0 6.0";
+
+ BOOST_REQUIRE_CLOSE(CosineDistance::Evaluate(a, b), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(CosineDistance::Evaluate(b, a), 1.0, 1e-5);
+}
+
+/**
+ * Now let's have them be orthogonal.
+ */
+BOOST_AUTO_TEST_CASE(cosine_distance_orthogonal)
+{
+ arma::vec a = "0.0 1.0";
+ arma::vec b = "1.0 0.0";
+
+ BOOST_REQUIRE_SMALL(CosineDistance::Evaluate(a, b), 1e-5);
+ BOOST_REQUIRE_SMALL(CosineDistance::Evaluate(b, a), 1e-5);
+}
+
+/**
+ * Some random angle test.
+ */
+BOOST_AUTO_TEST_CASE(cosine_distance_random_test)
+{
+ arma::vec a = "0.1 0.2 0.3 0.4 0.5";
+ arma::vec b = "1.2 1.0 0.8 -0.3 -0.5";
+
+ BOOST_REQUIRE_CLOSE(CosineDistance::Evaluate(a, b), 0.1385349024, 1e-5);
+ BOOST_REQUIRE_CLOSE(CosineDistance::Evaluate(b, a), 0.1385349024, 1e-5);
+}
+
+/**
+ * Linear Kernel test.
+ */
+BOOST_AUTO_TEST_CASE(linear_kernel)
+{
+ arma::vec a = ".2 .3 .4 .1";
+ arma::vec b = ".56 .21 .623 .82";
+
+ LinearKernel lk;
+ BOOST_REQUIRE_CLOSE(lk.Evaluate(a,b), .5062, 1e-5);
+ BOOST_REQUIRE_CLOSE(lk.Evaluate(b,a), .5062, 1e-5);
+}
+
+/**
+ * Linear Kernel test, orthogonal vectors.
+ */
+BOOST_AUTO_TEST_CASE(linear_kernel_orthogonal)
+{
+ arma::vec a = "1 0 0";
+ arma::vec b = "0 0 1";
+
+ LinearKernel lk;
+ BOOST_REQUIRE_SMALL(lk.Evaluate(a,b), 1e-5);
+ BOOST_REQUIRE_SMALL(lk.Evaluate(b,a), 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(gaussian_kernel)
+{
+ arma::vec a = "1 0 0";
+ arma::vec b = "0 1 0";
+ arma::vec c = "0 0 1";
+
+ GaussianKernel gk(.5);
+ BOOST_REQUIRE_CLOSE(gk.Evaluate(a,b), .018315638888734, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Evaluate(b,a), .018315638888734, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Evaluate(a,c), .018315638888734, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Evaluate(c,a), .018315638888734, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Evaluate(b,c), .018315638888734, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Evaluate(c,b), .018315638888734, 1e-5);
+ /* check the single dimension evaluate function */
+ BOOST_REQUIRE_CLOSE(gk.Evaluate(1.0), 0.1353352832366127, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Evaluate(2.0), 0.00033546262790251185, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Evaluate(3.0), 1.5229979744712629e-08, 1e-5);
+ /* check the normalization constant */
+ BOOST_REQUIRE_CLOSE(gk.Normalizer(1), 1.2533141373155001, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Normalizer(2), 1.5707963267948963, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Normalizer(3), 1.9687012432153019, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.Normalizer(4), 2.4674011002723386, 1e-5);
+ /* check the convolution integral */
+ BOOST_REQUIRE_CLOSE(gk.ConvolutionIntegral(a,b), 0.024304474038457577, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.ConvolutionIntegral(a,c), 0.024304474038457577, 1e-5);
+ BOOST_REQUIRE_CLOSE(gk.ConvolutionIntegral(b,c), 0.024304474038457577, 1e-5);
+
+}
+
+BOOST_AUTO_TEST_CASE(spherical_kernel)
+{
+ arma::vec a = "1.0 0.0";
+ arma::vec b = "0.0 1.0";
+ arma::vec c = "0.2 0.9";
+
+ SphericalKernel sk(.5);
+ BOOST_REQUIRE_CLOSE(sk.Evaluate(a,b), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.Evaluate(a,c), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.Evaluate(b,c), 1.0, 1e-5);
+ /* check the single dimension evaluate function */
+ BOOST_REQUIRE_CLOSE(sk.Evaluate(0.10), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.Evaluate(0.25), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.Evaluate(0.50), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.Evaluate(1.00), 0.0, 1e-5);
+ /* check the normalization constant */
+ BOOST_REQUIRE_CLOSE(sk.Normalizer(1), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.Normalizer(2), 0.78539816339744828, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.Normalizer(3), 0.52359877559829893, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.Normalizer(4), 0.30842513753404244, 1e-5);
+ /* check the convolution integral */
+ BOOST_REQUIRE_CLOSE(sk.ConvolutionIntegral(a,b), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.ConvolutionIntegral(a,c), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(sk.ConvolutionIntegral(b,c), 1.0021155029652784, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(epanechnikov_kernel)
+{
+ arma::vec a = "1.0 0.0";
+ arma::vec b = "0.0 1.0";
+ arma::vec c = "0.1 0.9";
+
+ EpanechnikovKernel ek(.5);
+ BOOST_REQUIRE_CLOSE(ek.Evaluate(a,b), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.Evaluate(b,c), 0.92, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.Evaluate(a,c), 0.0, 1e-5);
+ /* check the single dimension evaluate function */
+ BOOST_REQUIRE_CLOSE(ek.Evaluate(0.10), 0.96, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.Evaluate(0.25), 0.75, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.Evaluate(0.50), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.Evaluate(1.00), 0.0, 1e-5);
+ /* check the normalization constant */
+ BOOST_REQUIRE_CLOSE(ek.Normalizer(1), 0.666666666666666, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.Normalizer(2), 0.39269908169872414, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.Normalizer(3), 0.20943951023931956, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.Normalizer(4), 0.10280837917801415, 1e-5);
+ /* check the convolution integral */
+ BOOST_REQUIRE_CLOSE(ek.ConvolutionIntegral(a,b), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.ConvolutionIntegral(a,c), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(ek.ConvolutionIntegral(b,c), 1.5263455690698258, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(polynomial_kernel)
+{
+ arma::vec a = "0 0 1";
+ arma::vec b = "0 1 0";
+
+ PolynomialKernel pk(5.0, 5.0);
+ BOOST_REQUIRE_CLOSE(pk.Evaluate(a, b), 3125.0, 0);
+ BOOST_REQUIRE_CLOSE(pk.Evaluate(b, a), 3125.0, 0);
+}
+
+BOOST_AUTO_TEST_CASE(hyperbolic_tangent_kernel)
+{
+ arma::vec a = "0 0 1";
+ arma::vec b = "0 1 0";
+
+ HyperbolicTangentKernel tk(5.0, 5.0);
+ BOOST_REQUIRE_CLOSE(tk.Evaluate(a, b), 0.9999092, 1e-5);
+ BOOST_REQUIRE_CLOSE(tk.Evaluate(b, a), 0.9999092, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(laplacian_kernel)
+{
+ arma::vec a = "0 0 1";
+ arma::vec b = "0 1 0";
+
+ LaplacianKernel lk(1.0);
+ BOOST_REQUIRE_CLOSE(lk.Evaluate(a, b), 0.243116734, 5e-5);
+ BOOST_REQUIRE_CLOSE(lk.Evaluate(b, a), 0.243116734, 5e-5);
+}
+
+// Ensure that the p-spectrum kernel successfully extracts all length-p
+// substrings from the data.
+BOOST_AUTO_TEST_CASE(PSpectrumSubstringExtractionTest)
+{
+ std::vector<std::vector<std::string> > datasets;
+
+ datasets.push_back(std::vector<std::string>());
+
+ datasets[0].push_back("herpgle");
+ datasets[0].push_back("herpagkle");
+ datasets[0].push_back("klunktor");
+ datasets[0].push_back("flibbynopple");
+
+ datasets.push_back(std::vector<std::string>());
+
+ datasets[1].push_back("floggy3245");
+ datasets[1].push_back("flippydopflip");
+ datasets[1].push_back("stupid fricking cat");
+ datasets[1].push_back("food time isn't until later");
+ datasets[1].push_back("leave me alone until 6:00");
+ datasets[1].push_back("only after that do you get any food.");
+ datasets[1].push_back("obloblobloblobloblobloblob");
+
+ PSpectrumStringKernel p(datasets, 3);
+
+ // Ensure the sizes are correct.
+ BOOST_REQUIRE_EQUAL(p.Counts().size(), 2);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0].size(), 4);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1].size(), 7);
+
+ // herpgle: her, erp, rpg, pgl, gle
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][0].size(), 5);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["her"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["erp"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["rpg"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["pgl"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][0]["gle"], 1);
+
+ // herpagkle: her, erp, rpa, pag, agk, gkl, kle
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][1].size(), 7);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["her"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["erp"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["rpa"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["pag"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["agk"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["gkl"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][1]["kle"], 1);
+
+ // klunktor: klu, lun, unk, nkt, kto, tor
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][2].size(), 6);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["klu"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["lun"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["unk"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["nkt"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["kto"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][2]["tor"], 1);
+
+ // flibbynopple: fli lib ibb bby byn yno nop opp ppl ple
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3].size(), 10);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["fli"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["lib"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["ibb"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["bby"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["byn"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["yno"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["nop"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["opp"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["ppl"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[0][3]["ple"], 1);
+
+ // floggy3245: flo log ogg ggy gy3 y32 324 245
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][0].size(), 8);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["flo"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["log"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["ogg"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["ggy"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["gy3"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["y32"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["324"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][0]["245"], 1);
+
+ // flippydopflip: fli lip ipp ppy pyd ydo dop opf pfl fli lip
+ // fli(2) lip(2) ipp ppy pyd ydo dop opf pfl
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1].size(), 9);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["fli"], 2);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["lip"], 2);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["ipp"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["ppy"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["pyd"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["ydo"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["dop"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["opf"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][1]["pfl"], 1);
+
+ // stupid fricking cat: stu tup upi pid fri ric ick cki kin ing cat
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2].size(), 11);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["stu"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["tup"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["upi"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["pid"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["fri"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["ric"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["ick"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["cki"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["kin"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["ing"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][2]["cat"], 1);
+
+ // food time isn't until later: foo ood tim ime isn unt nti til lat ate ter
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3].size(), 11);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["foo"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["ood"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["tim"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["ime"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["isn"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["unt"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["nti"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["til"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["lat"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["ate"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][3]["ter"], 1);
+
+ // leave me alone until 6:00: lea eav ave alo lon one unt nti til
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4].size(), 9);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["lea"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["eav"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["ave"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["alo"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["lon"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["one"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["unt"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["nti"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][4]["til"], 1);
+
+ // only after that do you get any food.:
+ // onl nly aft fte ter tha hat you get any foo ood
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5].size(), 12);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["onl"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["nly"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["aft"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["fte"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["ter"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["tha"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["hat"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["you"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["get"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["any"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["foo"], 1);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][5]["ood"], 1);
+
+ // obloblobloblobloblobloblob: obl(8) blo(8) lob(8)
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][6].size(), 3);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][6]["obl"], 8);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][6]["blo"], 8);
+ BOOST_REQUIRE_EQUAL(p.Counts()[1][6]["lob"], 8);
+}
+
+BOOST_AUTO_TEST_CASE(PSpectrumStringEvaluateTest)
+{
+ // Construct simple dataset.
+ std::vector<std::vector<std::string> > dataset;
+ dataset.push_back(std::vector<std::string>());
+ dataset[0].push_back("hello");
+ dataset[0].push_back("jello");
+ dataset[0].push_back("mellow");
+ dataset[0].push_back("mellow jello");
+
+ PSpectrumStringKernel p(dataset, 3);
+
+ arma::vec a("0 0");
+ arma::vec b("0 0");
+
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 3.0, 1e-5);
+
+ b = "0 1";
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 2.0, 1e-5);
+
+ b = "0 2";
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 2.0, 1e-5);
+
+ b = "0 3";
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 4.0, 1e-5);
+
+ a = "0 1";
+ b = "0 1";
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 3.0, 1e-5);
+
+ b = "0 2";
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 2.0, 1e-5);
+
+ b = "0 3";
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 5.0, 1e-5);
+
+ a = "0 2";
+ b = "0 2";
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 4.0, 1e-5);
+
+ b = "0 3";
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 6.0, 1e-5);
+
+ a = "0 3";
+ BOOST_REQUIRE_CLOSE(p.Evaluate(a, b), 11.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(p.Evaluate(b, a), 11.0, 1e-5);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kmeans_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/kmeans_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kmeans_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,512 +0,0 @@
-/**
- * @file kmeans_test.cpp
- * @author Ryan Curtin
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-
-#include <mlpack/methods/kmeans/kmeans.hpp>
-#include <mlpack/methods/kmeans/allow_empty_clusters.hpp>
-#include <mlpack/methods/kmeans/refined_start.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::kmeans;
-
-BOOST_AUTO_TEST_SUITE(KMeansTest);
-
-// Generate dataset; written transposed because it's easier to read.
-arma::mat kMeansData(" 0.0 0.0;" // Class 1.
- " 0.3 0.4;"
- " 0.1 0.0;"
- " 0.1 0.3;"
- " -0.2 -0.2;"
- " -0.1 0.3;"
- " -0.4 0.1;"
- " 0.2 -0.1;"
- " 0.3 0.0;"
- " -0.3 -0.3;"
- " 0.1 -0.1;"
- " 0.2 -0.3;"
- " -0.3 0.2;"
- " 10.0 10.0;" // Class 2.
- " 10.1 9.9;"
- " 9.9 10.0;"
- " 10.2 9.7;"
- " 10.2 9.8;"
- " 9.7 10.3;"
- " 9.9 10.1;"
- "-10.0 5.0;" // Class 3.
- " -9.8 5.1;"
- " -9.9 4.9;"
- "-10.0 4.9;"
- "-10.2 5.2;"
- "-10.1 5.1;"
- "-10.3 5.3;"
- "-10.0 4.8;"
- " -9.6 5.0;"
- " -9.8 5.1;");
-
-/**
- * 30-point 3-class test case for K-Means, with no overclustering.
- */
-BOOST_AUTO_TEST_CASE(KMeansNoOverclusteringTest)
-{
- KMeans<> kmeans; // No overclustering.
-
- arma::Col<size_t> assignments;
- kmeans.Cluster((arma::mat) trans(kMeansData), 3, assignments);
-
- // Now make sure we got it all right. There is no restriction on how the
- // clusters are ordered, so we have to be careful about that.
- size_t firstClass = assignments(0);
-
- for (size_t i = 1; i < 13; i++)
- BOOST_REQUIRE_EQUAL(assignments(i), firstClass);
-
- size_t secondClass = assignments(13);
-
- // To ensure that class 1 != class 2.
- BOOST_REQUIRE_NE(firstClass, secondClass);
-
- for (size_t i = 13; i < 20; i++)
- BOOST_REQUIRE_EQUAL(assignments(i), secondClass);
-
- size_t thirdClass = assignments(20);
-
- // To ensure that this is the third class which we haven't seen yet.
- BOOST_REQUIRE_NE(firstClass, thirdClass);
- BOOST_REQUIRE_NE(secondClass, thirdClass);
-
- for (size_t i = 20; i < 30; i++)
- BOOST_REQUIRE_EQUAL(assignments(i), thirdClass);
-}
-
-/**
- * 30-point 3-class test case for K-Means, with overclustering.
- */
-BOOST_AUTO_TEST_CASE(KMeansOverclusteringTest)
-{
- KMeans<> kmeans(1000, 4.0); // Overclustering factor of 4.0.
-
- arma::Col<size_t> assignments;
- kmeans.Cluster((arma::mat) trans(kMeansData), 3, assignments);
-
- // Now make sure we got it all right. There is no restriction on how the
- // clusters are ordered, so we have to be careful about that.
- size_t firstClass = assignments(0);
-
- for (size_t i = 1; i < 13; i++)
- BOOST_REQUIRE_EQUAL(assignments(i), firstClass);
-
- size_t secondClass = assignments(13);
-
- // To ensure that class 1 != class 2.
- BOOST_REQUIRE_NE(firstClass, secondClass);
-
- for (size_t i = 13; i < 20; i++)
- BOOST_REQUIRE_EQUAL(assignments(i), secondClass);
-
- size_t thirdClass = assignments(20);
-
- // To ensure that this is the third class which we haven't seen yet.
- BOOST_REQUIRE_NE(firstClass, thirdClass);
- BOOST_REQUIRE_NE(secondClass, thirdClass);
-
- for (size_t i = 20; i < 30; i++)
- BOOST_REQUIRE_EQUAL(assignments(i), thirdClass);
-}
-
-/**
- * Make sure the empty cluster policy class does nothing.
- */
-BOOST_AUTO_TEST_CASE(AllowEmptyClusterTest)
-{
- arma::Col<size_t> assignments;
- assignments.randu(30);
- arma::Col<size_t> assignmentsOld = assignments;
-
- arma::mat centroids;
- centroids.randu(30, 3); // This doesn't matter.
-
- arma::Col<size_t> counts(3);
- counts[0] = accu(assignments == 0);
- counts[1] = accu(assignments == 1);
- counts[2] = 0;
- arma::Col<size_t> countsOld = counts;
-
- // Make sure the method doesn't modify any points.
- BOOST_REQUIRE_EQUAL(AllowEmptyClusters::EmptyCluster(kMeansData, 2, centroids,
- counts, assignments), 0);
-
- // Make sure no assignments were changed.
- for (size_t i = 0; i < assignments.n_elem; i++)
- BOOST_REQUIRE_EQUAL(assignments[i], assignmentsOld[i]);
-
- // Make sure no counts were changed.
- for (size_t i = 0; i < 3; i++)
- BOOST_REQUIRE_EQUAL(counts[i], countsOld[i]);
-}
-
-/**
- * Make sure the max variance method finds the correct point.
- */
-BOOST_AUTO_TEST_CASE(MaxVarianceNewClusterTest)
-{
- // Five points.
- arma::mat data("0.4 1.0 5.0 -2.0 -2.5;"
- "1.0 0.8 0.7 5.1 5.2;");
-
- // Point 2 is the mis-clustered point we're looking for to be moved.
- arma::Col<size_t> assignments("0 0 0 1 1");
-
- arma::mat centroids(2, 3);
- centroids.col(0) = (1.0 / 3.0) * (data.col(0) + data.col(1) + data.col(2));
- centroids.col(1) = 0.5 * (data.col(3) + data.col(4));
- centroids(0, 2) = 0;
- centroids(1, 2) = 0;
-
- arma::Col<size_t> counts("3 2 0");
-
- // This should only change one point.
- BOOST_REQUIRE_EQUAL(MaxVarianceNewCluster::EmptyCluster(data, 2, centroids,
- counts, assignments), 1);
-
- // Ensure that the cluster assignments are right.
- BOOST_REQUIRE_EQUAL(assignments[0], 0);
- BOOST_REQUIRE_EQUAL(assignments[1], 0);
- BOOST_REQUIRE_EQUAL(assignments[2], 2);
- BOOST_REQUIRE_EQUAL(assignments[3], 1);
- BOOST_REQUIRE_EQUAL(assignments[4], 1);
-
- // Ensure that the counts are right.
- BOOST_REQUIRE_EQUAL(counts[0], 2);
- BOOST_REQUIRE_EQUAL(counts[1], 2);
- BOOST_REQUIRE_EQUAL(counts[2], 1);
-}
-
-/**
- * Make sure the random partitioner seems to return valid results.
- */
-BOOST_AUTO_TEST_CASE(RandomPartitionTest)
-{
- arma::mat data;
- data.randu(2, 1000); // One thousand points.
-
- arma::Col<size_t> assignments;
-
- // We'll ask for 18 clusters (arbitrary).
- RandomPartition::Cluster(data, 18, assignments);
-
- // Ensure that the right number of assignments were given.
- BOOST_REQUIRE_EQUAL(assignments.n_elem, 1000);
-
- // Ensure that no value is greater than 17 (the maximum valid cluster).
- for (size_t i = 0; i < 1000; i++)
- BOOST_REQUIRE_LT(assignments[i], 18);
-}
-
-/**
- * Make sure that random initialization fails for a corner case dataset.
- */
-BOOST_AUTO_TEST_CASE(RandomInitialAssignmentFailureTest)
-{
- // This is a very synthetic dataset. It is one Gaussian with a huge number of
- // points combined with one faraway Gaussian with very few points. Normally,
- // k-means should not get the correct result -- which is one cluster at each
- // Gaussian. This is because the random partitioning scheme has very low
- // (virtually zero) likelihood of separating the two Gaussians properly, and
- // then the algorithm will never converge to that result.
- //
- // So we will set the initial assignments appropriately. Remember, once the
- // initial assignments are done, k-means is deterministic.
- arma::mat dataset(2, 10002);
- dataset.randn();
- // Now move the second Gaussian far away.
- for (size_t i = 0; i < 2; ++i)
- dataset.col(10000 + i) += arma::vec("50 50");
-
- // Ensure that k-means fails when run with random initialization. This isn't
- // strictly a necessary test, but it does help let us know that this is a good
- // test.
- size_t successes = 0;
- for (size_t run = 0; run < 15; ++run)
- {
- arma::mat centroids;
- arma::Col<size_t> assignments;
- KMeans<> kmeans;
- kmeans.Cluster(dataset, 2, assignments, centroids);
-
- // Inspect centroids. See if one is close to the second Gaussian.
- if ((centroids(0, 0) >= 30.0 && centroids(1, 0) >= 30.0) ||
- (centroids(0, 1) >= 30.0 && centroids(1, 1) >= 30.0))
- ++successes;
- }
-
- // Only one success allowed. The probability of two successes should be
- // infinitesimal.
- BOOST_REQUIRE_LT(successes, 2);
-}
-
-/**
- * Make sure that specifying initial assignments is successful for a corner case
- * dataset which doesn't usually converge otherwise.
- */
-BOOST_AUTO_TEST_CASE(InitialAssignmentTest)
-{
- // For a better description of this dataset, see
- // RandomInitialAssignmentFailureTest.
- arma::mat dataset(2, 10002);
- dataset.randn();
- // Now move the second Gaussian far away.
- for (size_t i = 0; i < 2; ++i)
- dataset.col(10000 + i) += arma::vec("50 50");
-
- // Now, if we specify initial assignments, the algorithm should converge (with
- // zero iterations, actually, because this is the solution).
- arma::Col<size_t> assignments(10002);
- assignments.fill(0);
- assignments[10000] = 1;
- assignments[10001] = 1;
-
- KMeans<> kmeans;
- kmeans.Cluster(dataset, 2, assignments, true);
-
- // Check results.
- for (size_t i = 0; i < 10000; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 0);
- for (size_t i = 10000; i < 10002; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 1);
-
- // Now, slightly harder. Give it one incorrect assignment in each cluster.
- // The wrong assignment should be quickly fixed.
- assignments[9999] = 1;
- assignments[10000] = 0;
-
- kmeans.Cluster(dataset, 2, assignments, true);
-
- // Check results.
- for (size_t i = 0; i < 10000; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 0);
- for (size_t i = 10000; i < 10002; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 1);
-}
-
-/**
- * Make sure specifying initial centroids is successful for a corner case which
- * doesn't usually converge otherwise.
- */
-BOOST_AUTO_TEST_CASE(InitialCentroidTest)
-{
- // For a better description of this dataset, see
- // RandomInitialAssignmentFailureTest.
- arma::mat dataset(2, 10002);
- dataset.randn();
- // Now move the second Gaussian far away.
- for (size_t i = 0; i < 2; ++i)
- dataset.col(10000 + i) += arma::vec("50 50");
-
- arma::Col<size_t> assignments;
- arma::mat centroids(2, 2);
-
- centroids.col(0) = arma::vec("0 0");
- centroids.col(1) = arma::vec("50 50");
-
- // This should converge correctly.
- KMeans<> k;
- k.Cluster(dataset, 2, assignments, centroids, false, true);
-
- // Check results.
- for (size_t i = 0; i < 10000; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 0);
- for (size_t i = 10000; i < 10002; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 1);
-
- // Now add a little noise to the initial centroids.
- centroids.col(0) = arma::vec("3 4");
- centroids.col(1) = arma::vec("25 10");
-
- k.Cluster(dataset, 2, assignments, centroids, false, true);
-
- // Check results.
- for (size_t i = 0; i < 10000; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 0);
- for (size_t i = 10000; i < 10002; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 1);
-}
-
-/**
- * Ensure that initial assignments override initial centroids.
- */
-BOOST_AUTO_TEST_CASE(InitialAssignmentOverrideTest)
-{
- // For a better description of this dataset, see
- // RandomInitialAssignmentFailureTest.
- arma::mat dataset(2, 10002);
- dataset.randn();
- // Now move the second Gaussian far away.
- for (size_t i = 0; i < 2; ++i)
- dataset.col(10000 + i) += arma::vec("50 50");
-
- arma::Col<size_t> assignments(10002);
- assignments.fill(0);
- assignments[10000] = 1;
- assignments[10001] = 1;
-
- // Note that this initial centroid guess is the opposite of the assignments
- // guess!
- arma::mat centroids(2, 2);
- centroids.col(0) = arma::vec("50 50");
- centroids.col(1) = arma::vec("0 0");
-
- KMeans<> k;
- k.Cluster(dataset, 2, assignments, centroids, true, true);
-
- // Because the initial assignments guess should take priority, we should get
- // those same results back.
- for (size_t i = 0; i < 10000; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 0);
- for (size_t i = 10000; i < 10002; ++i)
- BOOST_REQUIRE_EQUAL(assignments[i], 1);
-
- // Make sure the centroids are about right too.
- BOOST_REQUIRE_LT(centroids(0, 0), 10.0);
- BOOST_REQUIRE_LT(centroids(1, 0), 10.0);
- BOOST_REQUIRE_GT(centroids(0, 1), 40.0);
- BOOST_REQUIRE_GT(centroids(1, 1), 40.0);
-}
-
-/**
- * Test that the refined starting policy returns decent initial cluster
- * estimates.
- */
-BOOST_AUTO_TEST_CASE(RefinedStartTest)
-{
- // Our dataset will be five Gaussians of largely varying numbers of points and
- // we expect that the refined starting policy should return good guesses at
- // what these Gaussians are.
- math::RandomSeed(std::time(NULL));
- arma::mat data(3, 3000);
- data.randn();
-
- // First Gaussian: 10000 points, centered at (0, 0, 0).
- // Second Gaussian: 2000 points, centered at (5, 0, -2).
- // Third Gaussian: 5000 points, centered at (-2, -2, -2).
- // Fourth Gaussian: 1000 points, centered at (-6, 8, 8).
- // Fifth Gaussian: 12000 points, centered at (1, 6, 1).
- arma::mat centroids(" 0 5 -2 -6 1;"
- " 0 0 -2 8 6;"
- " 0 -2 -2 8 1");
-
- for (size_t i = 1000; i < 1200; ++i)
- data.col(i) += centroids.col(1);
- for (size_t i = 1200; i < 1700; ++i)
- data.col(i) += centroids.col(2);
- for (size_t i = 1700; i < 1800; ++i)
- data.col(i) += centroids.col(3);
- for (size_t i = 1800; i < 3000; ++i)
- data.col(i) += centroids.col(4);
-
- // Now run the RefinedStart algorithm and make sure it doesn't deviate too
- // much from the actual solution.
- RefinedStart rs;
- arma::Col<size_t> assignments;
- arma::mat resultingCentroids;
- rs.Cluster(data, 5, assignments);
-
- // Calculate resulting centroids.
- resultingCentroids.zeros(3, 5);
- arma::Col<size_t> counts(5);
- counts.zeros();
- for (size_t i = 0; i < 3000; ++i)
- {
- resultingCentroids.col(assignments[i]) += data.col(i);
- ++counts[assignments[i]];
- }
-
- // Normalize centroids.
- for (size_t i = 0; i < 5; ++i)
- if (counts[i] != 0)
- resultingCentroids /= counts[i];
-
- // Calculate sum of distances from centroid means.
- double distortion = 0;
- for (size_t i = 0; i < 3000; ++i)
- distortion += metric::EuclideanDistance::Evaluate(data.col(i),
- resultingCentroids.col(assignments[i]));
-
- // Using the refined start, the distance for this dataset is usually around
- // 13500. Regular k-means is between 10000 and 30000 (I think the 10000
- // figure is a corner case which actually does not give good clusters), and
- // random initial starts give distortion around 22000. So we'll require that
- // our distortion is less than 14000.
- BOOST_REQUIRE_LT(distortion, 14000.0);
-}
-
-#ifdef ARMA_HAS_SPMAT
-// Can't do this test on Armadillo 3.4; var(SpBase) is not implemented.
-#if !((ARMA_VERSION_MAJOR == 3) && (ARMA_VERSION_MINOR == 4))
-
-/**
- * Make sure sparse k-means works okay.
- */
-BOOST_AUTO_TEST_CASE(SparseKMeansTest)
-{
- // Huge dimensionality, few points.
- arma::SpMat<double> data(5000, 12);
- data(14, 0) = 6.4;
- data(14, 1) = 6.3;
- data(14, 2) = 6.5;
- data(14, 3) = 6.2;
- data(14, 4) = 6.1;
- data(14, 5) = 6.6;
- data(1402, 6) = -3.2;
- data(1402, 7) = -3.3;
- data(1402, 8) = -3.1;
- data(1402, 9) = -3.4;
- data(1402, 10) = -3.5;
- data(1402, 11) = -3.0;
-
- arma::Col<size_t> assignments;
-
- KMeans<> kmeans; // Default options.
-
- kmeans.Cluster(data, 2, assignments);
-
- size_t clusterOne = assignments[0];
- size_t clusterTwo = assignments[6];
-
- BOOST_REQUIRE_EQUAL(assignments[0], clusterOne);
- BOOST_REQUIRE_EQUAL(assignments[1], clusterOne);
- BOOST_REQUIRE_EQUAL(assignments[2], clusterOne);
- BOOST_REQUIRE_EQUAL(assignments[3], clusterOne);
- BOOST_REQUIRE_EQUAL(assignments[4], clusterOne);
- BOOST_REQUIRE_EQUAL(assignments[5], clusterOne);
- BOOST_REQUIRE_EQUAL(assignments[6], clusterTwo);
- BOOST_REQUIRE_EQUAL(assignments[7], clusterTwo);
- BOOST_REQUIRE_EQUAL(assignments[8], clusterTwo);
- BOOST_REQUIRE_EQUAL(assignments[9], clusterTwo);
- BOOST_REQUIRE_EQUAL(assignments[10], clusterTwo);
- BOOST_REQUIRE_EQUAL(assignments[11], clusterTwo);
-}
-
-#endif // Exclude Armadillo 3.4.
-#endif // ARMA_HAS_SPMAT
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kmeans_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/kmeans_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kmeans_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/kmeans_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,512 @@
+/**
+ * @file kmeans_test.cpp
+ * @author Ryan Curtin
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+
+#include <mlpack/methods/kmeans/kmeans.hpp>
+#include <mlpack/methods/kmeans/allow_empty_clusters.hpp>
+#include <mlpack/methods/kmeans/refined_start.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::kmeans;
+
+BOOST_AUTO_TEST_SUITE(KMeansTest);
+
+// Generate dataset; written transposed because it's easier to read.
+arma::mat kMeansData(" 0.0 0.0;" // Class 1.
+ " 0.3 0.4;"
+ " 0.1 0.0;"
+ " 0.1 0.3;"
+ " -0.2 -0.2;"
+ " -0.1 0.3;"
+ " -0.4 0.1;"
+ " 0.2 -0.1;"
+ " 0.3 0.0;"
+ " -0.3 -0.3;"
+ " 0.1 -0.1;"
+ " 0.2 -0.3;"
+ " -0.3 0.2;"
+ " 10.0 10.0;" // Class 2.
+ " 10.1 9.9;"
+ " 9.9 10.0;"
+ " 10.2 9.7;"
+ " 10.2 9.8;"
+ " 9.7 10.3;"
+ " 9.9 10.1;"
+ "-10.0 5.0;" // Class 3.
+ " -9.8 5.1;"
+ " -9.9 4.9;"
+ "-10.0 4.9;"
+ "-10.2 5.2;"
+ "-10.1 5.1;"
+ "-10.3 5.3;"
+ "-10.0 4.8;"
+ " -9.6 5.0;"
+ " -9.8 5.1;");
+
+/**
+ * 30-point 3-class test case for K-Means, with no overclustering.
+ */
+BOOST_AUTO_TEST_CASE(KMeansNoOverclusteringTest)
+{
+ KMeans<> kmeans; // No overclustering.
+
+ arma::Col<size_t> assignments;
+ kmeans.Cluster((arma::mat) trans(kMeansData), 3, assignments);
+
+ // Now make sure we got it all right. There is no restriction on how the
+ // clusters are ordered, so we have to be careful about that.
+ size_t firstClass = assignments(0);
+
+ for (size_t i = 1; i < 13; i++)
+ BOOST_REQUIRE_EQUAL(assignments(i), firstClass);
+
+ size_t secondClass = assignments(13);
+
+ // To ensure that class 1 != class 2.
+ BOOST_REQUIRE_NE(firstClass, secondClass);
+
+ for (size_t i = 13; i < 20; i++)
+ BOOST_REQUIRE_EQUAL(assignments(i), secondClass);
+
+ size_t thirdClass = assignments(20);
+
+ // To ensure that this is the third class which we haven't seen yet.
+ BOOST_REQUIRE_NE(firstClass, thirdClass);
+ BOOST_REQUIRE_NE(secondClass, thirdClass);
+
+ for (size_t i = 20; i < 30; i++)
+ BOOST_REQUIRE_EQUAL(assignments(i), thirdClass);
+}
+
+/**
+ * 30-point 3-class test case for K-Means, with overclustering.
+ */
+BOOST_AUTO_TEST_CASE(KMeansOverclusteringTest)
+{
+ KMeans<> kmeans(1000, 4.0); // Overclustering factor of 4.0.
+
+ arma::Col<size_t> assignments;
+ kmeans.Cluster((arma::mat) trans(kMeansData), 3, assignments);
+
+ // Now make sure we got it all right. There is no restriction on how the
+ // clusters are ordered, so we have to be careful about that.
+ size_t firstClass = assignments(0);
+
+ for (size_t i = 1; i < 13; i++)
+ BOOST_REQUIRE_EQUAL(assignments(i), firstClass);
+
+ size_t secondClass = assignments(13);
+
+ // To ensure that class 1 != class 2.
+ BOOST_REQUIRE_NE(firstClass, secondClass);
+
+ for (size_t i = 13; i < 20; i++)
+ BOOST_REQUIRE_EQUAL(assignments(i), secondClass);
+
+ size_t thirdClass = assignments(20);
+
+ // To ensure that this is the third class which we haven't seen yet.
+ BOOST_REQUIRE_NE(firstClass, thirdClass);
+ BOOST_REQUIRE_NE(secondClass, thirdClass);
+
+ for (size_t i = 20; i < 30; i++)
+ BOOST_REQUIRE_EQUAL(assignments(i), thirdClass);
+}
+
+/**
+ * Make sure the empty cluster policy class does nothing.
+ */
+BOOST_AUTO_TEST_CASE(AllowEmptyClusterTest)
+{
+ arma::Col<size_t> assignments;
+ assignments.randu(30);
+ arma::Col<size_t> assignmentsOld = assignments;
+
+ arma::mat centroids;
+ centroids.randu(30, 3); // This doesn't matter.
+
+ arma::Col<size_t> counts(3);
+ counts[0] = accu(assignments == 0);
+ counts[1] = accu(assignments == 1);
+ counts[2] = 0;
+ arma::Col<size_t> countsOld = counts;
+
+ // Make sure the method doesn't modify any points.
+ BOOST_REQUIRE_EQUAL(AllowEmptyClusters::EmptyCluster(kMeansData, 2, centroids,
+ counts, assignments), 0);
+
+ // Make sure no assignments were changed.
+ for (size_t i = 0; i < assignments.n_elem; i++)
+ BOOST_REQUIRE_EQUAL(assignments[i], assignmentsOld[i]);
+
+ // Make sure no counts were changed.
+ for (size_t i = 0; i < 3; i++)
+ BOOST_REQUIRE_EQUAL(counts[i], countsOld[i]);
+}
+
+/**
+ * Make sure the max variance method finds the correct point.
+ */
+BOOST_AUTO_TEST_CASE(MaxVarianceNewClusterTest)
+{
+ // Five points.
+ arma::mat data("0.4 1.0 5.0 -2.0 -2.5;"
+ "1.0 0.8 0.7 5.1 5.2;");
+
+ // Point 2 is the mis-clustered point we're looking for to be moved.
+ arma::Col<size_t> assignments("0 0 0 1 1");
+
+ arma::mat centroids(2, 3);
+ centroids.col(0) = (1.0 / 3.0) * (data.col(0) + data.col(1) + data.col(2));
+ centroids.col(1) = 0.5 * (data.col(3) + data.col(4));
+ centroids(0, 2) = 0;
+ centroids(1, 2) = 0;
+
+ arma::Col<size_t> counts("3 2 0");
+
+ // This should only change one point.
+ BOOST_REQUIRE_EQUAL(MaxVarianceNewCluster::EmptyCluster(data, 2, centroids,
+ counts, assignments), 1);
+
+ // Ensure that the cluster assignments are right.
+ BOOST_REQUIRE_EQUAL(assignments[0], 0);
+ BOOST_REQUIRE_EQUAL(assignments[1], 0);
+ BOOST_REQUIRE_EQUAL(assignments[2], 2);
+ BOOST_REQUIRE_EQUAL(assignments[3], 1);
+ BOOST_REQUIRE_EQUAL(assignments[4], 1);
+
+ // Ensure that the counts are right.
+ BOOST_REQUIRE_EQUAL(counts[0], 2);
+ BOOST_REQUIRE_EQUAL(counts[1], 2);
+ BOOST_REQUIRE_EQUAL(counts[2], 1);
+}
+
+/**
+ * Make sure the random partitioner seems to return valid results.
+ */
+BOOST_AUTO_TEST_CASE(RandomPartitionTest)
+{
+ arma::mat data;
+ data.randu(2, 1000); // One thousand points.
+
+ arma::Col<size_t> assignments;
+
+ // We'll ask for 18 clusters (arbitrary).
+ RandomPartition::Cluster(data, 18, assignments);
+
+ // Ensure that the right number of assignments were given.
+ BOOST_REQUIRE_EQUAL(assignments.n_elem, 1000);
+
+ // Ensure that no value is greater than 17 (the maximum valid cluster).
+ for (size_t i = 0; i < 1000; i++)
+ BOOST_REQUIRE_LT(assignments[i], 18);
+}
+
+/**
+ * Make sure that random initialization fails for a corner case dataset.
+ */
+BOOST_AUTO_TEST_CASE(RandomInitialAssignmentFailureTest)
+{
+ // This is a very synthetic dataset. It is one Gaussian with a huge number of
+ // points combined with one faraway Gaussian with very few points. Normally,
+ // k-means should not get the correct result -- which is one cluster at each
+ // Gaussian. This is because the random partitioning scheme has very low
+ // (virtually zero) likelihood of separating the two Gaussians properly, and
+ // then the algorithm will never converge to that result.
+ //
+ // So we will set the initial assignments appropriately. Remember, once the
+ // initial assignments are done, k-means is deterministic.
+ arma::mat dataset(2, 10002);
+ dataset.randn();
+ // Now move the second Gaussian far away.
+ for (size_t i = 0; i < 2; ++i)
+ dataset.col(10000 + i) += arma::vec("50 50");
+
+ // Ensure that k-means fails when run with random initialization. This isn't
+ // strictly a necessary test, but it does help let us know that this is a good
+ // test.
+ size_t successes = 0;
+ for (size_t run = 0; run < 15; ++run)
+ {
+ arma::mat centroids;
+ arma::Col<size_t> assignments;
+ KMeans<> kmeans;
+ kmeans.Cluster(dataset, 2, assignments, centroids);
+
+ // Inspect centroids. See if one is close to the second Gaussian.
+ if ((centroids(0, 0) >= 30.0 && centroids(1, 0) >= 30.0) ||
+ (centroids(0, 1) >= 30.0 && centroids(1, 1) >= 30.0))
+ ++successes;
+ }
+
+ // Only one success allowed. The probability of two successes should be
+ // infinitesimal.
+ BOOST_REQUIRE_LT(successes, 2);
+}
+
+/**
+ * Make sure that specifying initial assignments is successful for a corner case
+ * dataset which doesn't usually converge otherwise.
+ */
+BOOST_AUTO_TEST_CASE(InitialAssignmentTest)
+{
+ // For a better description of this dataset, see
+ // RandomInitialAssignmentFailureTest.
+ arma::mat dataset(2, 10002);
+ dataset.randn();
+ // Now move the second Gaussian far away.
+ for (size_t i = 0; i < 2; ++i)
+ dataset.col(10000 + i) += arma::vec("50 50");
+
+ // Now, if we specify initial assignments, the algorithm should converge (with
+ // zero iterations, actually, because this is the solution).
+ arma::Col<size_t> assignments(10002);
+ assignments.fill(0);
+ assignments[10000] = 1;
+ assignments[10001] = 1;
+
+ KMeans<> kmeans;
+ kmeans.Cluster(dataset, 2, assignments, true);
+
+ // Check results.
+ for (size_t i = 0; i < 10000; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 0);
+ for (size_t i = 10000; i < 10002; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 1);
+
+ // Now, slightly harder. Give it one incorrect assignment in each cluster.
+ // The wrong assignment should be quickly fixed.
+ assignments[9999] = 1;
+ assignments[10000] = 0;
+
+ kmeans.Cluster(dataset, 2, assignments, true);
+
+ // Check results.
+ for (size_t i = 0; i < 10000; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 0);
+ for (size_t i = 10000; i < 10002; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 1);
+}
+
+/**
+ * Make sure specifying initial centroids is successful for a corner case which
+ * doesn't usually converge otherwise.
+ */
+BOOST_AUTO_TEST_CASE(InitialCentroidTest)
+{
+ // For a better description of this dataset, see
+ // RandomInitialAssignmentFailureTest.
+ arma::mat dataset(2, 10002);
+ dataset.randn();
+ // Now move the second Gaussian far away.
+ for (size_t i = 0; i < 2; ++i)
+ dataset.col(10000 + i) += arma::vec("50 50");
+
+ arma::Col<size_t> assignments;
+ arma::mat centroids(2, 2);
+
+ centroids.col(0) = arma::vec("0 0");
+ centroids.col(1) = arma::vec("50 50");
+
+ // This should converge correctly.
+ KMeans<> k;
+ k.Cluster(dataset, 2, assignments, centroids, false, true);
+
+ // Check results.
+ for (size_t i = 0; i < 10000; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 0);
+ for (size_t i = 10000; i < 10002; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 1);
+
+ // Now add a little noise to the initial centroids.
+ centroids.col(0) = arma::vec("3 4");
+ centroids.col(1) = arma::vec("25 10");
+
+ k.Cluster(dataset, 2, assignments, centroids, false, true);
+
+ // Check results.
+ for (size_t i = 0; i < 10000; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 0);
+ for (size_t i = 10000; i < 10002; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 1);
+}
+
+/**
+ * Ensure that initial assignments override initial centroids.
+ */
+BOOST_AUTO_TEST_CASE(InitialAssignmentOverrideTest)
+{
+ // For a better description of this dataset, see
+ // RandomInitialAssignmentFailureTest.
+ arma::mat dataset(2, 10002);
+ dataset.randn();
+ // Now move the second Gaussian far away.
+ for (size_t i = 0; i < 2; ++i)
+ dataset.col(10000 + i) += arma::vec("50 50");
+
+ arma::Col<size_t> assignments(10002);
+ assignments.fill(0);
+ assignments[10000] = 1;
+ assignments[10001] = 1;
+
+ // Note that this initial centroid guess is the opposite of the assignments
+ // guess!
+ arma::mat centroids(2, 2);
+ centroids.col(0) = arma::vec("50 50");
+ centroids.col(1) = arma::vec("0 0");
+
+ KMeans<> k;
+ k.Cluster(dataset, 2, assignments, centroids, true, true);
+
+ // Because the initial assignments guess should take priority, we should get
+ // those same results back.
+ for (size_t i = 0; i < 10000; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 0);
+ for (size_t i = 10000; i < 10002; ++i)
+ BOOST_REQUIRE_EQUAL(assignments[i], 1);
+
+ // Make sure the centroids are about right too.
+ BOOST_REQUIRE_LT(centroids(0, 0), 10.0);
+ BOOST_REQUIRE_LT(centroids(1, 0), 10.0);
+ BOOST_REQUIRE_GT(centroids(0, 1), 40.0);
+ BOOST_REQUIRE_GT(centroids(1, 1), 40.0);
+}
+
+/**
+ * Test that the refined starting policy returns decent initial cluster
+ * estimates.
+ */
+BOOST_AUTO_TEST_CASE(RefinedStartTest)
+{
+ // Our dataset will be five Gaussians of largely varying numbers of points and
+ // we expect that the refined starting policy should return good guesses at
+ // what these Gaussians are.
+ math::RandomSeed(std::time(NULL));
+ arma::mat data(3, 3000);
+ data.randn();
+
+ // First Gaussian: 10000 points, centered at (0, 0, 0).
+ // Second Gaussian: 2000 points, centered at (5, 0, -2).
+ // Third Gaussian: 5000 points, centered at (-2, -2, -2).
+ // Fourth Gaussian: 1000 points, centered at (-6, 8, 8).
+ // Fifth Gaussian: 12000 points, centered at (1, 6, 1).
+ arma::mat centroids(" 0 5 -2 -6 1;"
+ " 0 0 -2 8 6;"
+ " 0 -2 -2 8 1");
+
+ for (size_t i = 1000; i < 1200; ++i)
+ data.col(i) += centroids.col(1);
+ for (size_t i = 1200; i < 1700; ++i)
+ data.col(i) += centroids.col(2);
+ for (size_t i = 1700; i < 1800; ++i)
+ data.col(i) += centroids.col(3);
+ for (size_t i = 1800; i < 3000; ++i)
+ data.col(i) += centroids.col(4);
+
+ // Now run the RefinedStart algorithm and make sure it doesn't deviate too
+ // much from the actual solution.
+ RefinedStart rs;
+ arma::Col<size_t> assignments;
+ arma::mat resultingCentroids;
+ rs.Cluster(data, 5, assignments);
+
+ // Calculate resulting centroids.
+ resultingCentroids.zeros(3, 5);
+ arma::Col<size_t> counts(5);
+ counts.zeros();
+ for (size_t i = 0; i < 3000; ++i)
+ {
+ resultingCentroids.col(assignments[i]) += data.col(i);
+ ++counts[assignments[i]];
+ }
+
+ // Normalize centroids.
+ for (size_t i = 0; i < 5; ++i)
+ if (counts[i] != 0)
+ resultingCentroids /= counts[i];
+
+ // Calculate sum of distances from centroid means.
+ double distortion = 0;
+ for (size_t i = 0; i < 3000; ++i)
+ distortion += metric::EuclideanDistance::Evaluate(data.col(i),
+ resultingCentroids.col(assignments[i]));
+
+ // Using the refined start, the distance for this dataset is usually around
+ // 13500. Regular k-means is between 10000 and 30000 (I think the 10000
+ // figure is a corner case which actually does not give good clusters), and
+ // random initial starts give distortion around 22000. So we'll require that
+ // our distortion is less than 14000.
+ BOOST_REQUIRE_LT(distortion, 14000.0);
+}
+
+#ifdef ARMA_HAS_SPMAT
+// Can't do this test on Armadillo 3.4; var(SpBase) is not implemented.
+#if !((ARMA_VERSION_MAJOR == 3) && (ARMA_VERSION_MINOR == 4))
+
+/**
+ * Make sure sparse k-means works okay.
+ */
+BOOST_AUTO_TEST_CASE(SparseKMeansTest)
+{
+ // Huge dimensionality, few points.
+ arma::SpMat<double> data(5000, 12);
+ data(14, 0) = 6.4;
+ data(14, 1) = 6.3;
+ data(14, 2) = 6.5;
+ data(14, 3) = 6.2;
+ data(14, 4) = 6.1;
+ data(14, 5) = 6.6;
+ data(1402, 6) = -3.2;
+ data(1402, 7) = -3.3;
+ data(1402, 8) = -3.1;
+ data(1402, 9) = -3.4;
+ data(1402, 10) = -3.5;
+ data(1402, 11) = -3.0;
+
+ arma::Col<size_t> assignments;
+
+ KMeans<> kmeans; // Default options.
+
+ kmeans.Cluster(data, 2, assignments);
+
+ size_t clusterOne = assignments[0];
+ size_t clusterTwo = assignments[6];
+
+ BOOST_REQUIRE_EQUAL(assignments[0], clusterOne);
+ BOOST_REQUIRE_EQUAL(assignments[1], clusterOne);
+ BOOST_REQUIRE_EQUAL(assignments[2], clusterOne);
+ BOOST_REQUIRE_EQUAL(assignments[3], clusterOne);
+ BOOST_REQUIRE_EQUAL(assignments[4], clusterOne);
+ BOOST_REQUIRE_EQUAL(assignments[5], clusterOne);
+ BOOST_REQUIRE_EQUAL(assignments[6], clusterTwo);
+ BOOST_REQUIRE_EQUAL(assignments[7], clusterTwo);
+ BOOST_REQUIRE_EQUAL(assignments[8], clusterTwo);
+ BOOST_REQUIRE_EQUAL(assignments[9], clusterTwo);
+ BOOST_REQUIRE_EQUAL(assignments[10], clusterTwo);
+ BOOST_REQUIRE_EQUAL(assignments[11], clusterTwo);
+}
+
+#endif // Exclude Armadillo 3.4.
+#endif // ARMA_HAS_SPMAT
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lars_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/lars_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lars_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,124 +0,0 @@
-/**
- * @file lars_test.cpp
- *
- * Test for LARS
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
-// to use FPC_WEAK, and it's not at all intuitive how to do that.
-
-
-#include <armadillo>
-#include <mlpack/methods/lars/lars.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::regression;
-
-BOOST_AUTO_TEST_SUITE(LARSTest);
-
-void GenerateProblem(arma::mat& X, arma::vec& y, size_t nPoints, size_t nDims)
-{
- X = arma::randn(nDims, nPoints);
- arma::vec beta = arma::randn(nDims, 1);
- y = trans(X) * beta;
-}
-
-
-void LARSVerifyCorrectness(arma::vec beta, arma::vec errCorr, double lambda)
-{
- size_t nDims = beta.n_elem;
- const double tol = 1e-12;
- for(size_t j = 0; j < nDims; j++)
- {
- if (beta(j) == 0)
- {
- // Make sure that |errCorr(j)| <= lambda.
- BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
- }
- else if (beta(j) < 0)
- {
- // Make sure that errCorr(j) == lambda.
- BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
- }
- else // beta(j) > 0
- {
- // Make sure that errCorr(j) == -lambda.
- BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
- }
- }
-}
-
-
-void LassoTest(size_t nPoints, size_t nDims, bool elasticNet, bool useCholesky)
-{
- arma::mat X;
- arma::vec y;
-
- for(size_t i = 0; i < 100; i++)
- {
- GenerateProblem(X, y, nPoints, nDims);
-
- // Armadillo's median is broken, so...
- arma::vec sortedAbsCorr = sort(abs(X * y));
- double lambda1 = sortedAbsCorr(nDims / 2);
- double lambda2;
- if (elasticNet)
- lambda2 = lambda1 / 2;
- else
- lambda2 = 0;
-
-
- LARS lars(useCholesky, lambda1, lambda2);
- arma::vec betaOpt;
- lars.Regress(X, y, betaOpt);
-
- arma::vec errCorr = (X * trans(X) + lambda2 *
- arma::eye(nDims, nDims)) * betaOpt - X * y;
-
- LARSVerifyCorrectness(betaOpt, errCorr, lambda1);
- }
-}
-
-
-BOOST_AUTO_TEST_CASE(LARSTestLassoCholesky)
-{
- LassoTest(100, 10, false, true);
-}
-
-
-BOOST_AUTO_TEST_CASE(LARSTestLassoGram)
-{
- LassoTest(100, 10, false, false);
-}
-
-
-BOOST_AUTO_TEST_CASE(LARSTestElasticNetCholesky)
-{
- LassoTest(100, 10, true, true);
-}
-
-
-BOOST_AUTO_TEST_CASE(LARSTestElasticNetGram)
-{
- LassoTest(100, 10, true, false);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lars_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/lars_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lars_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lars_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,124 @@
+/**
+ * @file lars_test.cpp
+ *
+ * Test for LARS
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
+// to use FPC_WEAK, and it's not at all intuitive how to do that.
+
+
+#include <armadillo>
+#include <mlpack/methods/lars/lars.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::regression;
+
+BOOST_AUTO_TEST_SUITE(LARSTest);
+
+void GenerateProblem(arma::mat& X, arma::vec& y, size_t nPoints, size_t nDims)
+{
+ X = arma::randn(nDims, nPoints);
+ arma::vec beta = arma::randn(nDims, 1);
+ y = trans(X) * beta;
+}
+
+
+void LARSVerifyCorrectness(arma::vec beta, arma::vec errCorr, double lambda)
+{
+ size_t nDims = beta.n_elem;
+ const double tol = 1e-12;
+ for(size_t j = 0; j < nDims; j++)
+ {
+ if (beta(j) == 0)
+ {
+ // Make sure that |errCorr(j)| <= lambda.
+ BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
+ }
+ else if (beta(j) < 0)
+ {
+ // Make sure that errCorr(j) == lambda.
+ BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
+ }
+ else // beta(j) > 0
+ {
+ // Make sure that errCorr(j) == -lambda.
+ BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
+ }
+ }
+}
+
+
+void LassoTest(size_t nPoints, size_t nDims, bool elasticNet, bool useCholesky)
+{
+ arma::mat X;
+ arma::vec y;
+
+ for(size_t i = 0; i < 100; i++)
+ {
+ GenerateProblem(X, y, nPoints, nDims);
+
+ // Armadillo's median is broken, so...
+ arma::vec sortedAbsCorr = sort(abs(X * y));
+ double lambda1 = sortedAbsCorr(nDims / 2);
+ double lambda2;
+ if (elasticNet)
+ lambda2 = lambda1 / 2;
+ else
+ lambda2 = 0;
+
+
+ LARS lars(useCholesky, lambda1, lambda2);
+ arma::vec betaOpt;
+ lars.Regress(X, y, betaOpt);
+
+ arma::vec errCorr = (X * trans(X) + lambda2 *
+ arma::eye(nDims, nDims)) * betaOpt - X * y;
+
+ LARSVerifyCorrectness(betaOpt, errCorr, lambda1);
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE(LARSTestLassoCholesky)
+{
+ LassoTest(100, 10, false, true);
+}
+
+
+BOOST_AUTO_TEST_CASE(LARSTestLassoGram)
+{
+ LassoTest(100, 10, false, false);
+}
+
+
+BOOST_AUTO_TEST_CASE(LARSTestElasticNetCholesky)
+{
+ LassoTest(100, 10, true, true);
+}
+
+
+BOOST_AUTO_TEST_CASE(LARSTestElasticNetGram)
+{
+ LassoTest(100, 10, true, false);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lbfgs_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/lbfgs_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lbfgs_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,130 +0,0 @@
-/**
- * @file lbfgs_test.cpp
- *
- * Tests the L-BFGS optimizer on a couple test functions.
- *
- * @author Ryan Curtin (gth671b at mail.gatech.edu)
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
-#include <mlpack/core/optimizers/lbfgs/test_functions.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack::optimization;
-using namespace mlpack::optimization::test;
-
-BOOST_AUTO_TEST_SUITE(LBFGSTest);
-
-/**
- * Tests the L-BFGS optimizer using the Rosenbrock Function.
- */
-BOOST_AUTO_TEST_CASE(RosenbrockFunction)
-{
- RosenbrockFunction f;
- L_BFGS<RosenbrockFunction> lbfgs(f);
- lbfgs.MaxIterations() = 10000;
-
- arma::vec coords = f.GetInitialPoint();
- if (!lbfgs.Optimize(coords))
- BOOST_FAIL("L-BFGS optimization reported failure.");
-
- double finalValue = f.Evaluate(coords);
-
- BOOST_REQUIRE_SMALL(finalValue, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[0], 1, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[1], 1, 1e-5);
-}
-
-/**
- * Tests the L-BFGS optimizer using the Wood Function.
- */
-BOOST_AUTO_TEST_CASE(WoodFunction)
-{
- WoodFunction f;
- L_BFGS<WoodFunction> lbfgs(f);
- lbfgs.MaxIterations() = 10000;
-
- arma::vec coords = f.GetInitialPoint();
- if (!lbfgs.Optimize(coords))
- BOOST_FAIL("L-BFGS optimization reported failure.");
-
- double finalValue = f.Evaluate(coords);
-
- BOOST_REQUIRE_SMALL(finalValue, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[0], 1, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[1], 1, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[2], 1, 1e-5);
- BOOST_REQUIRE_CLOSE(coords[3], 1, 1e-5);
-}
-
-/**
- * Tests the L-BFGS optimizer using the generalized Rosenbrock function. This
- * is actually multiple tests, increasing the dimension by powers of 2, from 4
- * dimensions to 1024 dimensions.
- */
-BOOST_AUTO_TEST_CASE(GeneralizedRosenbrockFunction)
-{
- for (int i = 2; i < 10; i++)
- {
- // Dimension: powers of 2
- int dim = std::pow(2, i);
-
- GeneralizedRosenbrockFunction f(dim);
- L_BFGS<GeneralizedRosenbrockFunction> lbfgs(f, 20);
- lbfgs.MaxIterations() = 10000;
-
- arma::vec coords = f.GetInitialPoint();
- if (!lbfgs.Optimize(coords))
- BOOST_FAIL("L-BFGS optimization reported failure.");
-
- double finalValue = f.Evaluate(coords);
-
- // Test the output to make sure it is correct.
- BOOST_REQUIRE_SMALL(finalValue, 1e-5);
- for (int j = 0; j < dim; j++)
- BOOST_REQUIRE_CLOSE(coords[j], 1, 1e-5);
- }
-}
-
-/**
- * Tests the L-BFGS optimizer using the Rosenbrock-Wood combined function. This
- * is a test on optimizing a matrix of coordinates.
- */
-BOOST_AUTO_TEST_CASE(RosenbrockWoodFunction)
-{
- RosenbrockWoodFunction f;
- L_BFGS<RosenbrockWoodFunction> lbfgs(f);
- lbfgs.MaxIterations() = 10000;
-
- arma::mat coords = f.GetInitialPoint();
- if (!lbfgs.Optimize(coords))
- BOOST_FAIL("L-BFGS optimization reported failure.");
-
- double finalValue = f.Evaluate(coords);
-
- BOOST_REQUIRE_SMALL(finalValue, 1e-5);
- for (int row = 0; row < 4; row++)
- {
- BOOST_REQUIRE_CLOSE((coords(row, 0)), 1, 1e-5);
- BOOST_REQUIRE_CLOSE((coords(row, 1)), 1, 1e-5);
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lbfgs_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/lbfgs_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lbfgs_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lbfgs_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,130 @@
+/**
+ * @file lbfgs_test.cpp
+ *
+ * Tests the L-BFGS optimizer on a couple test functions.
+ *
+ * @author Ryan Curtin (gth671b at mail.gatech.edu)
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+#include <mlpack/core/optimizers/lbfgs/test_functions.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack::optimization;
+using namespace mlpack::optimization::test;
+
+BOOST_AUTO_TEST_SUITE(LBFGSTest);
+
+/**
+ * Tests the L-BFGS optimizer using the Rosenbrock Function.
+ */
+BOOST_AUTO_TEST_CASE(RosenbrockFunction)
+{
+ RosenbrockFunction f;
+ L_BFGS<RosenbrockFunction> lbfgs(f);
+ lbfgs.MaxIterations() = 10000;
+
+ arma::vec coords = f.GetInitialPoint();
+ if (!lbfgs.Optimize(coords))
+ BOOST_FAIL("L-BFGS optimization reported failure.");
+
+ double finalValue = f.Evaluate(coords);
+
+ BOOST_REQUIRE_SMALL(finalValue, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[0], 1, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[1], 1, 1e-5);
+}
+
+/**
+ * Tests the L-BFGS optimizer using the Wood Function.
+ */
+BOOST_AUTO_TEST_CASE(WoodFunction)
+{
+ WoodFunction f;
+ L_BFGS<WoodFunction> lbfgs(f);
+ lbfgs.MaxIterations() = 10000;
+
+ arma::vec coords = f.GetInitialPoint();
+ if (!lbfgs.Optimize(coords))
+ BOOST_FAIL("L-BFGS optimization reported failure.");
+
+ double finalValue = f.Evaluate(coords);
+
+ BOOST_REQUIRE_SMALL(finalValue, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[0], 1, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[1], 1, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[2], 1, 1e-5);
+ BOOST_REQUIRE_CLOSE(coords[3], 1, 1e-5);
+}
+
+/**
+ * Tests the L-BFGS optimizer using the generalized Rosenbrock function. This
+ * is actually multiple tests, increasing the dimension by powers of 2, from 4
+ * dimensions to 1024 dimensions.
+ */
+BOOST_AUTO_TEST_CASE(GeneralizedRosenbrockFunction)
+{
+ for (int i = 2; i < 10; i++)
+ {
+ // Dimension: powers of 2
+ int dim = std::pow(2, i);
+
+ GeneralizedRosenbrockFunction f(dim);
+ L_BFGS<GeneralizedRosenbrockFunction> lbfgs(f, 20);
+ lbfgs.MaxIterations() = 10000;
+
+ arma::vec coords = f.GetInitialPoint();
+ if (!lbfgs.Optimize(coords))
+ BOOST_FAIL("L-BFGS optimization reported failure.");
+
+ double finalValue = f.Evaluate(coords);
+
+ // Test the output to make sure it is correct.
+ BOOST_REQUIRE_SMALL(finalValue, 1e-5);
+ for (int j = 0; j < dim; j++)
+ BOOST_REQUIRE_CLOSE(coords[j], 1, 1e-5);
+ }
+}
+
+/**
+ * Tests the L-BFGS optimizer using the Rosenbrock-Wood combined function. This
+ * is a test on optimizing a matrix of coordinates.
+ */
+BOOST_AUTO_TEST_CASE(RosenbrockWoodFunction)
+{
+ RosenbrockWoodFunction f;
+ L_BFGS<RosenbrockWoodFunction> lbfgs(f);
+ lbfgs.MaxIterations() = 10000;
+
+ arma::mat coords = f.GetInitialPoint();
+ if (!lbfgs.Optimize(coords))
+ BOOST_FAIL("L-BFGS optimization reported failure.");
+
+ double finalValue = f.Evaluate(coords);
+
+ BOOST_REQUIRE_SMALL(finalValue, 1e-5);
+ for (int row = 0; row < 4; row++)
+ {
+ BOOST_REQUIRE_CLOSE((coords(row, 0)), 1, 1e-5);
+ BOOST_REQUIRE_CLOSE((coords(row, 1)), 1, 1e-5);
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lin_alg_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/lin_alg_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lin_alg_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,203 +0,0 @@
-/**
- * @file lin_alg_test.cpp
- * @author Ryan Curtin
- *
- * Simple tests for things in the linalg__private namespace.
- * Partly so I can be sure that my changes are working.
- * Move to boost unit testing framework at some point.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/math/lin_alg.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace arma;
-using namespace mlpack;
-using namespace mlpack::math;
-
-BOOST_AUTO_TEST_SUITE(LinAlgTest);
-
-/**
- * Test for linalg__private::Center(). There are no edge cases here, so we'll
- * just try it once for now.
- */
-BOOST_AUTO_TEST_CASE(TestCenterA)
-{
- mat tmp(5, 5);
- // [[0 0 0 0 0]
- // [1 2 3 4 5]
- // [2 4 6 8 10]
- // [3 6 9 12 15]
- // [4 8 12 16 20]]
- for (int row = 0; row < 5; row++)
- for (int col = 0; col < 5; col++)
- tmp(row, col) = row * (col + 1);
-
- mat tmp_out;
- Center(tmp, tmp_out);
-
- // average should be
- // [[0 3 6 9 12]]'
- // so result should be
- // [[ 0 0 0 0 0]
- // [-2 -1 0 1 2 ]
- // [-4 -2 0 2 4 ]
- // [-6 -3 0 3 6 ]
- // [-8 -4 0 4 8]]
- for (int row = 0; row < 5; row++)
- for (int col = 0; col < 5; col++)
- BOOST_REQUIRE_CLOSE(tmp_out(row, col), (double) (col - 2) * row, 1e-5);
-}
-
-BOOST_AUTO_TEST_CASE(TestCenterB)
-{
- mat tmp(5, 6);
- for (int row = 0; row < 5; row++)
- for (int col = 0; col < 6; col++)
- tmp(row, col) = row * (col + 1);
-
- mat tmp_out;
- Center(tmp, tmp_out);
-
- // average should be
- // [[0 3.5 7 10.5 14]]'
- // so result should be
- // [[ 0 0 0 0 0 0 ]
- // [-2.5 -1.5 -0.5 0.5 1.5 2.5]
- // [-5 -3 -1 1 3 5 ]
- // [-7.5 -4.5 -1.5 1.5 1.5 4.5]
- // [-10 -6 -2 2 6 10 ]]
- for (int row = 0; row < 5; row++)
- for (int col = 0; col < 6; col++)
- BOOST_REQUIRE_CLOSE(tmp_out(row, col), (double) (col - 2.5) * row, 1e-5);
-}
-
-BOOST_AUTO_TEST_CASE(TestWhitenUsingEig)
-{
- // After whitening using eigendecomposition, the covariance of
- // our matrix will be I (or something very close to that).
- // We are loading a matrix from an external file... bad choice.
- mat tmp, tmp_centered, whitened, whitening_matrix;
-
- data::Load("trainSet.csv", tmp);
- Center(tmp, tmp_centered);
- WhitenUsingEig(tmp_centered, whitened, whitening_matrix);
-
- mat newcov = ccov(whitened);
- for (int row = 0; row < 5; row++)
- {
- for (int col = 0; col < 5; col++)
- {
- if (row == col)
- {
- // diagonal will be 0 in the case of any zero-valued eigenvalues
- // (rank-deficient covariance case)
- if (std::abs(newcov(row, col)) > 1e-10)
- BOOST_REQUIRE_CLOSE(newcov(row, col), 1.0, 1e-10);
- }
- else
- {
- BOOST_REQUIRE_SMALL(newcov(row, col), 1e-10);
- }
- }
- }
-}
-
-BOOST_AUTO_TEST_CASE(TestOrthogonalize)
-{
- // Generate a random matrix; then, orthogonalize it and test if it's
- // orthogonal.
- mat tmp, orth;
- data::Load("fake.csv", tmp);
- Orthogonalize(tmp, orth);
-
- // test orthogonality
- mat test = ccov(orth);
- double ival = test(0, 0);
- for (size_t row = 0; row < test.n_rows; row++)
- {
- for (size_t col = 0; col < test.n_cols; col++)
- {
- if (row == col)
- {
- if (std::abs(test(row, col)) > 1e-10)
- BOOST_REQUIRE_CLOSE(test(row, col), ival, 1e-10);
- }
- else
- {
- BOOST_REQUIRE_SMALL(test(row, col), 1e-10);
- }
- }
- }
-}
-
-// Test RemoveRows().
-BOOST_AUTO_TEST_CASE(TestRemoveRows)
-{
- // Run this test several times.
- for (size_t run = 0; run < 10; ++run)
- {
- arma::mat input;
- input.randu(200, 200);
-
- // Now pick some random numbers.
- std::vector<size_t> rowsToRemove;
- size_t row = 0;
- while (row < 200)
- {
- row += RandInt(1, (2 * (run + 1) + 1));
- if (row < 200)
- {
- rowsToRemove.push_back(row);
- }
- }
-
- // Ensure we're not about to remove every single row.
- if (rowsToRemove.size() == 10)
- {
- rowsToRemove.erase(rowsToRemove.begin() + 4); // Random choice to remove.
- }
-
- arma::mat output;
- RemoveRows(input, rowsToRemove, output);
-
- // Now check that the output is right.
- size_t outputRow = 0;
- size_t skipIndex = 0;
-
- for (row = 0; row < 200; ++row)
- {
- // Was this row supposed to be removed? If so skip it.
- if ((skipIndex < rowsToRemove.size()) && (rowsToRemove[skipIndex] == row))
- {
- ++skipIndex;
- }
- else
- {
- // Compare.
- BOOST_REQUIRE_EQUAL(accu(input.row(row) == output.row(outputRow)), 200);
-
- // Increment output row counter.
- ++outputRow;
- }
- }
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lin_alg_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/lin_alg_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lin_alg_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lin_alg_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,203 @@
+/**
+ * @file lin_alg_test.cpp
+ * @author Ryan Curtin
+ *
+ * Simple tests for things in the linalg__private namespace.
+ * Partly so I can be sure that my changes are working.
+ * Move to boost unit testing framework at some point.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/math/lin_alg.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::math;
+
+BOOST_AUTO_TEST_SUITE(LinAlgTest);
+
+/**
+ * Test for linalg__private::Center(). There are no edge cases here, so we'll
+ * just try it once for now.
+ */
+BOOST_AUTO_TEST_CASE(TestCenterA)
+{
+ mat tmp(5, 5);
+ // [[0 0 0 0 0]
+ // [1 2 3 4 5]
+ // [2 4 6 8 10]
+ // [3 6 9 12 15]
+ // [4 8 12 16 20]]
+ for (int row = 0; row < 5; row++)
+ for (int col = 0; col < 5; col++)
+ tmp(row, col) = row * (col + 1);
+
+ mat tmp_out;
+ Center(tmp, tmp_out);
+
+ // average should be
+ // [[0 3 6 9 12]]'
+ // so result should be
+ // [[ 0 0 0 0 0]
+ // [-2 -1 0 1 2 ]
+ // [-4 -2 0 2 4 ]
+ // [-6 -3 0 3 6 ]
+ // [-8 -4 0 4 8]]
+ for (int row = 0; row < 5; row++)
+ for (int col = 0; col < 5; col++)
+ BOOST_REQUIRE_CLOSE(tmp_out(row, col), (double) (col - 2) * row, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(TestCenterB)
+{
+ mat tmp(5, 6);
+ for (int row = 0; row < 5; row++)
+ for (int col = 0; col < 6; col++)
+ tmp(row, col) = row * (col + 1);
+
+ mat tmp_out;
+ Center(tmp, tmp_out);
+
+ // average should be
+ // [[0 3.5 7 10.5 14]]'
+ // so result should be
+ // [[ 0 0 0 0 0 0 ]
+ // [-2.5 -1.5 -0.5 0.5 1.5 2.5]
+ // [-5 -3 -1 1 3 5 ]
+ // [-7.5 -4.5 -1.5 1.5 1.5 4.5]
+ // [-10 -6 -2 2 6 10 ]]
+ for (int row = 0; row < 5; row++)
+ for (int col = 0; col < 6; col++)
+ BOOST_REQUIRE_CLOSE(tmp_out(row, col), (double) (col - 2.5) * row, 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(TestWhitenUsingEig)
+{
+ // After whitening using eigendecomposition, the covariance of
+ // our matrix will be I (or something very close to that).
+ // We are loading a matrix from an external file... bad choice.
+ mat tmp, tmp_centered, whitened, whitening_matrix;
+
+ data::Load("trainSet.csv", tmp);
+ Center(tmp, tmp_centered);
+ WhitenUsingEig(tmp_centered, whitened, whitening_matrix);
+
+ mat newcov = ccov(whitened);
+ for (int row = 0; row < 5; row++)
+ {
+ for (int col = 0; col < 5; col++)
+ {
+ if (row == col)
+ {
+ // diagonal will be 0 in the case of any zero-valued eigenvalues
+ // (rank-deficient covariance case)
+ if (std::abs(newcov(row, col)) > 1e-10)
+ BOOST_REQUIRE_CLOSE(newcov(row, col), 1.0, 1e-10);
+ }
+ else
+ {
+ BOOST_REQUIRE_SMALL(newcov(row, col), 1e-10);
+ }
+ }
+ }
+}
+
+BOOST_AUTO_TEST_CASE(TestOrthogonalize)
+{
+ // Generate a random matrix; then, orthogonalize it and test if it's
+ // orthogonal.
+ mat tmp, orth;
+ data::Load("fake.csv", tmp);
+ Orthogonalize(tmp, orth);
+
+ // test orthogonality
+ mat test = ccov(orth);
+ double ival = test(0, 0);
+ for (size_t row = 0; row < test.n_rows; row++)
+ {
+ for (size_t col = 0; col < test.n_cols; col++)
+ {
+ if (row == col)
+ {
+ if (std::abs(test(row, col)) > 1e-10)
+ BOOST_REQUIRE_CLOSE(test(row, col), ival, 1e-10);
+ }
+ else
+ {
+ BOOST_REQUIRE_SMALL(test(row, col), 1e-10);
+ }
+ }
+ }
+}
+
+// Test RemoveRows().
+BOOST_AUTO_TEST_CASE(TestRemoveRows)
+{
+ // Run this test several times.
+ for (size_t run = 0; run < 10; ++run)
+ {
+ arma::mat input;
+ input.randu(200, 200);
+
+ // Now pick some random numbers.
+ std::vector<size_t> rowsToRemove;
+ size_t row = 0;
+ while (row < 200)
+ {
+ row += RandInt(1, (2 * (run + 1) + 1));
+ if (row < 200)
+ {
+ rowsToRemove.push_back(row);
+ }
+ }
+
+ // Ensure we're not about to remove every single row.
+ if (rowsToRemove.size() == 10)
+ {
+ rowsToRemove.erase(rowsToRemove.begin() + 4); // Random choice to remove.
+ }
+
+ arma::mat output;
+ RemoveRows(input, rowsToRemove, output);
+
+ // Now check that the output is right.
+ size_t outputRow = 0;
+ size_t skipIndex = 0;
+
+ for (row = 0; row < 200; ++row)
+ {
+ // Was this row supposed to be removed? If so skip it.
+ if ((skipIndex < rowsToRemove.size()) && (rowsToRemove[skipIndex] == row))
+ {
+ ++skipIndex;
+ }
+ else
+ {
+ // Compare.
+ BOOST_REQUIRE_EQUAL(accu(input.row(row) == output.row(outputRow)), 200);
+
+ // Increment output row counter.
+ ++outputRow;
+ }
+ }
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/linear_regression_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/linear_regression_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/linear_regression_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,81 +0,0 @@
-/**
- * @file linear_regression_test.cpp
- *
- * Test for linear regression.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/linear_regression/linear_regression.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::regression;
-
-BOOST_AUTO_TEST_SUITE(LinearRegressionTest);
-
-/**
- * Creates two 10x3 random matrices and one 10x1 "results" matrix.
- * Finds B in y=BX with one matrix, then predicts against the other.
- */
-BOOST_AUTO_TEST_CASE(LinearRegressionTestCase)
-{
- // Predictors and points are 100x3 matrices.
- arma::mat predictors(3, 10);
- arma::mat points(3, 10);
-
- // Responses is the "correct" value for each point in predictors and points.
- arma::vec responses(10);
-
- // The values we get back when we predict for points.
- arma::vec predictions(10);
-
- // We'll randomly select some coefficients for the linear response.
- arma::vec coeffs;
- coeffs.randu(4);
-
- // Now generate each point.
- for (size_t row = 0; row < 3; row++)
- predictors.row(row) = arma::linspace<arma::rowvec>(0, 9, 10);
-
- points = predictors;
-
- // Now add a small amount of noise to each point.
- for (size_t elem = 0; elem < points.n_elem; elem++)
- {
- // Max added noise is 0.02.
- points[elem] += math::Random() / 50.0;
- predictors[elem] += math::Random() / 50.0;
- }
-
- // Generate responses.
- for (size_t elem = 0; elem < responses.n_elem; elem++)
- responses[elem] = coeffs[0] +
- dot(coeffs.rows(1, 3), arma::ones<arma::rowvec>(3) * elem);
-
- // Initialize and predict
- LinearRegression lr(predictors, responses);
- lr.Predict(points, predictions);
-
- // Output result and verify we have less than 5% error from "correct" value
- // for each point
- for(size_t i = 0; i < predictions.n_cols; ++i)
- BOOST_REQUIRE_SMALL(predictions(i) - responses(i), .05);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/linear_regression_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/linear_regression_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/linear_regression_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/linear_regression_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,81 @@
+/**
+ * @file linear_regression_test.cpp
+ *
+ * Test for linear regression.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/linear_regression/linear_regression.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::regression;
+
+BOOST_AUTO_TEST_SUITE(LinearRegressionTest);
+
+/**
+ * Creates two 10x3 random matrices and one 10x1 "results" matrix.
+ * Finds B in y=BX with one matrix, then predicts against the other.
+ */
+BOOST_AUTO_TEST_CASE(LinearRegressionTestCase)
+{
+ // Predictors and points are 100x3 matrices.
+ arma::mat predictors(3, 10);
+ arma::mat points(3, 10);
+
+ // Responses is the "correct" value for each point in predictors and points.
+ arma::vec responses(10);
+
+ // The values we get back when we predict for points.
+ arma::vec predictions(10);
+
+ // We'll randomly select some coefficients for the linear response.
+ arma::vec coeffs;
+ coeffs.randu(4);
+
+ // Now generate each point.
+ for (size_t row = 0; row < 3; row++)
+ predictors.row(row) = arma::linspace<arma::rowvec>(0, 9, 10);
+
+ points = predictors;
+
+ // Now add a small amount of noise to each point.
+ for (size_t elem = 0; elem < points.n_elem; elem++)
+ {
+ // Max added noise is 0.02.
+ points[elem] += math::Random() / 50.0;
+ predictors[elem] += math::Random() / 50.0;
+ }
+
+ // Generate responses.
+ for (size_t elem = 0; elem < responses.n_elem; elem++)
+ responses[elem] = coeffs[0] +
+ dot(coeffs.rows(1, 3), arma::ones<arma::rowvec>(3) * elem);
+
+ // Initialize and predict
+ LinearRegression lr(predictors, responses);
+ lr.Predict(points, predictions);
+
+ // Output result and verify we have less than 5% error from "correct" value
+ // for each point
+ for(size_t i = 0; i < predictions.n_cols; ++i)
+ BOOST_REQUIRE_SMALL(predictions(i) - responses(i), .05);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/load_save_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/load_save_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/load_save_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,526 +0,0 @@
-/**
- * @file load_save_test.cpp
- * @author Ryan Curtin
- *
- * Tests for data::Load() and data::Save().
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <sstream>
-
-#include <mlpack/core/data/load.hpp>
-#include <mlpack/core/data/save.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-
-BOOST_AUTO_TEST_SUITE(LoadSaveTest);
-
-/**
- * Make sure failure occurs when no extension given.
- */
-BOOST_AUTO_TEST_CASE(NoExtensionLoad)
-{
- arma::mat out;
- BOOST_REQUIRE(data::Load("noextension", out) == false);
-}
-
-/**
- * Make sure failure occurs when no extension given.
- */
-BOOST_AUTO_TEST_CASE(NoExtensionSave)
-{
- arma::mat out;
- BOOST_REQUIRE(data::Save("noextension", out) == false);
-}
-
-/**
- * Make sure load fails if the file does not exist.
- */
-BOOST_AUTO_TEST_CASE(NotExistLoad)
-{
- arma::mat out;
- BOOST_REQUIRE(data::Load("nonexistentfile_______________.csv", out) == false);
-}
-
-/**
- * Make sure a CSV is loaded correctly.
- */
-BOOST_AUTO_TEST_CASE(LoadCSVTest)
-{
- std::fstream f;
- f.open("test_file.csv", std::fstream::out);
-
- f << "1, 2, 3, 4" << std::endl;
- f << "5, 6, 7, 8" << std::endl;
-
- f.close();
-
- arma::mat test;
- BOOST_REQUIRE(data::Load("test_file.csv", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.csv");
-}
-
-/**
- * Make sure a CSV is saved correctly.
- */
-BOOST_AUTO_TEST_CASE(SaveCSVTest)
-{
- arma::mat test = "1 5;"
- "2 6;"
- "3 7;"
- "4 8;";
-
- BOOST_REQUIRE(data::Save("test_file.csv", test) == true);
-
- // Load it in and make sure it is the same.
- arma::mat test2;
- BOOST_REQUIRE(data::Load("test_file.csv", test2) == true);
-
- BOOST_REQUIRE_EQUAL(test2.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test2.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test2[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.csv");
-}
-
-/**
- * Make sure CSVs can be loaded in non-transposed form.
- */
-BOOST_AUTO_TEST_CASE(LoadNonTransposedCSVTest)
-{
- std::fstream f;
- f.open("test_file.csv", std::fstream::out);
-
- f << "1, 3, 5, 7" << std::endl;
- f << "2, 4, 6, 8" << std::endl;
-
- f.close();
-
- arma::mat test;
- BOOST_REQUIRE(data::Load("test_file.csv", test, false, false) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_cols, 4);
- BOOST_REQUIRE_EQUAL(test.n_rows, 2);
-
- for (size_t i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.csv");
-}
-
-/**
- * Make sure CSVs can be saved in non-transposed form.
- */
-BOOST_AUTO_TEST_CASE(SaveNonTransposedCSVTest)
-{
- arma::mat test = "1 2;"
- "3 4;"
- "5 6;"
- "7 8;";
-
- BOOST_REQUIRE(data::Save("test_file.csv", test, false, false) == true);
-
- // Load it in and make sure it is in the same.
- arma::mat test2;
- BOOST_REQUIRE(data::Load("test_file.csv", test2, false, false) == true);
-
- BOOST_REQUIRE_EQUAL(test2.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test2.n_cols, 2);
-
- for (size_t i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], test2[i], 1e-5);
-
- // Remove the file.
- remove("test_file.csv");
-}
-
-/**
- * Make sure arma_ascii is loaded correctly.
- */
-BOOST_AUTO_TEST_CASE(LoadArmaASCIITest)
-{
- arma::mat test = "1 5;"
- "2 6;"
- "3 7;"
- "4 8;";
-
- arma::mat testTrans = trans(test);
- BOOST_REQUIRE(testTrans.save("test_file.txt", arma::arma_ascii));
-
- BOOST_REQUIRE(data::Load("test_file.txt", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.txt");
-}
-
-/**
- * Make sure a CSV is saved correctly.
- */
-BOOST_AUTO_TEST_CASE(SaveArmaASCIITest)
-{
- arma::mat test = "1 5;"
- "2 6;"
- "3 7;"
- "4 8;";
-
- BOOST_REQUIRE(data::Save("test_file.txt", test) == true);
-
- // Load it in and make sure it is the same.
- BOOST_REQUIRE(data::Load("test_file.txt", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.txt");
-}
-
-/**
- * Make sure raw_ascii is loaded correctly.
- */
-BOOST_AUTO_TEST_CASE(LoadRawASCIITest)
-{
- std::fstream f;
- f.open("test_file.txt", std::fstream::out);
-
- f << "1 2 3 4" << std::endl;
- f << "5 6 7 8" << std::endl;
-
- f.close();
-
- arma::mat test;
- BOOST_REQUIRE(data::Load("test_file.txt", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.txt");
-}
-
-/**
- * Make sure CSV is loaded correctly as .txt.
- */
-BOOST_AUTO_TEST_CASE(LoadCSVTxtTest)
-{
- std::fstream f;
- f.open("test_file.txt", std::fstream::out);
-
- f << "1, 2, 3, 4" << std::endl;
- f << "5, 6, 7, 8" << std::endl;
-
- f.close();
-
- arma::mat test;
- BOOST_REQUIRE(data::Load("test_file.txt", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.txt");
-}
-
-/**
- * Make sure arma_binary is loaded correctly.
- */
-BOOST_AUTO_TEST_CASE(LoadArmaBinaryTest)
-{
- arma::mat test = "1 5;"
- "2 6;"
- "3 7;"
- "4 8;";
-
- arma::mat testTrans = trans(test);
- BOOST_REQUIRE(testTrans.quiet_save("test_file.bin", arma::arma_binary)
- == true);
-
- // Now reload through our interface.
- BOOST_REQUIRE(data::Load("test_file.bin", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.bin");
-}
-
-/**
- * Make sure arma_binary is saved correctly.
- */
-BOOST_AUTO_TEST_CASE(SaveArmaBinaryTest)
-{
- arma::mat test = "1 5;"
- "2 6;"
- "3 7;"
- "4 8;";
-
- BOOST_REQUIRE(data::Save("test_file.bin", test) == true);
-
- BOOST_REQUIRE(data::Load("test_file.bin", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.bin");
-}
-
-/**
- * Make sure raw_binary is loaded correctly.
- */
-BOOST_AUTO_TEST_CASE(LoadRawBinaryTest)
-{
- arma::mat test = "1 2;"
- "3 4;"
- "5 6;"
- "7 8;";
-
- arma::mat testTrans = trans(test);
- BOOST_REQUIRE(testTrans.quiet_save("test_file.bin", arma::raw_binary)
- == true);
-
- // Now reload through our interface.
- BOOST_REQUIRE(data::Load("test_file.bin", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 1);
- BOOST_REQUIRE_EQUAL(test.n_cols, 8);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.bin");
-}
-
-/**
- * Make sure load as PGM is successful.
- */
-BOOST_AUTO_TEST_CASE(LoadPGMBinaryTest)
-{
- arma::mat test = "1 5;"
- "2 6;"
- "3 7;"
- "4 8;";
-
- arma::mat testTrans = trans(test);
- BOOST_REQUIRE(testTrans.quiet_save("test_file.pgm", arma::pgm_binary)
- == true);
-
- // Now reload through our interface.
- BOOST_REQUIRE(data::Load("test_file.pgm", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.pgm");
-}
-
-/**
- * Make sure save as PGM is successful.
- */
-BOOST_AUTO_TEST_CASE(SavePGMBinaryTest)
-{
- arma::mat test = "1 5;"
- "2 6;"
- "3 7;"
- "4 8;";
-
- BOOST_REQUIRE(data::Save("test_file.pgm", test) == true);
-
- // Now reload through our interface.
- BOOST_REQUIRE(data::Load("test_file.pgm", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; i++)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Remove the file.
- remove("test_file.pgm");
-}
-
-#ifdef ARMA_USE_HDF5
-/**
- * Make sure load as HDF5 is successful.
- */
-BOOST_AUTO_TEST_CASE(LoadHDF5Test)
-{
- arma::mat test = "1 5;"
- "2 6;"
- "3 7;"
- "4 8;";
- arma::mat testTrans = trans(test);
- BOOST_REQUIRE(testTrans.quiet_save("test_file.h5", arma::hdf5_binary)
- == true);
- BOOST_REQUIRE(testTrans.quiet_save("test_file.hdf5", arma::hdf5_binary)
- == true);
- BOOST_REQUIRE(testTrans.quiet_save("test_file.hdf", arma::hdf5_binary)
- == true);
- BOOST_REQUIRE(testTrans.quiet_save("test_file.he5", arma::hdf5_binary)
- == true);
-
- // Now reload through our interface.
- BOOST_REQUIRE(data::Load("test_file.h5", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Make sure the other extensions work too.
- BOOST_REQUIRE(data::Load("test_file.hdf5", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- BOOST_REQUIRE(data::Load("test_file.hdf", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- BOOST_REQUIRE(data::Load("test_file.he5", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- remove("test_file.h5");
- remove("test_file.hdf");
- remove("test_file.hdf5");
- remove("test_file.he5");
-}
-
-/**
- * Make sure save as HDF5 is successful.
- */
-BOOST_AUTO_TEST_CASE(SaveHDF5Test)
-{
- arma::mat test = "1 5;"
- "2 6;"
- "3 7;"
- "4 8;";
- BOOST_REQUIRE(data::Save("test_file.h5", test) == true);
- BOOST_REQUIRE(data::Save("test_file.hdf5", test) == true);
- BOOST_REQUIRE(data::Save("test_file.hdf", test) == true);
- BOOST_REQUIRE(data::Save("test_file.he5", test) == true);
-
- // Now load them all and verify they were saved okay.
- BOOST_REQUIRE(data::Load("test_file.h5", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- // Make sure the other extensions work too.
- BOOST_REQUIRE(data::Load("test_file.hdf5", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- BOOST_REQUIRE(data::Load("test_file.hdf", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- BOOST_REQUIRE(data::Load("test_file.he5", test) == true);
-
- BOOST_REQUIRE_EQUAL(test.n_rows, 4);
- BOOST_REQUIRE_EQUAL(test.n_cols, 2);
-
- for (int i = 0; i < 8; ++i)
- BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
-
- remove("test_file.h5");
- remove("test_file.hdf");
- remove("test_file.hdf5");
- remove("test_file.he5");
-}
-#else
-/**
- * Ensure saving as HDF5 fails.
- */
-BOOST_AUTO_TEST_CASE(NoHDF5Test)
-{
- arma::mat test;
- test.randu(5, 5);
-
- BOOST_REQUIRE(data::Save("test_file.h5", test) == false);
- BOOST_REQUIRE(data::Save("test_file.hdf5", test) == false);
- BOOST_REQUIRE(data::Save("test_file.hdf", test) == false);
- BOOST_REQUIRE(data::Save("test_file.he5", test) == false);
-}
-#endif
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/load_save_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/load_save_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/load_save_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/load_save_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,526 @@
+/**
+ * @file load_save_test.cpp
+ * @author Ryan Curtin
+ *
+ * Tests for data::Load() and data::Save().
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <sstream>
+
+#include <mlpack/core/data/load.hpp>
+#include <mlpack/core/data/save.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+
+BOOST_AUTO_TEST_SUITE(LoadSaveTest);
+
+/**
+ * Make sure failure occurs when no extension given.
+ */
+BOOST_AUTO_TEST_CASE(NoExtensionLoad)
+{
+ arma::mat out;
+ BOOST_REQUIRE(data::Load("noextension", out) == false);
+}
+
+/**
+ * Make sure failure occurs when no extension given.
+ */
+BOOST_AUTO_TEST_CASE(NoExtensionSave)
+{
+ arma::mat out;
+ BOOST_REQUIRE(data::Save("noextension", out) == false);
+}
+
+/**
+ * Make sure load fails if the file does not exist.
+ */
+BOOST_AUTO_TEST_CASE(NotExistLoad)
+{
+ arma::mat out;
+ BOOST_REQUIRE(data::Load("nonexistentfile_______________.csv", out) == false);
+}
+
+/**
+ * Make sure a CSV is loaded correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadCSVTest)
+{
+ std::fstream f;
+ f.open("test_file.csv", std::fstream::out);
+
+ f << "1, 2, 3, 4" << std::endl;
+ f << "5, 6, 7, 8" << std::endl;
+
+ f.close();
+
+ arma::mat test;
+ BOOST_REQUIRE(data::Load("test_file.csv", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.csv");
+}
+
+/**
+ * Make sure a CSV is saved correctly.
+ */
+BOOST_AUTO_TEST_CASE(SaveCSVTest)
+{
+ arma::mat test = "1 5;"
+ "2 6;"
+ "3 7;"
+ "4 8;";
+
+ BOOST_REQUIRE(data::Save("test_file.csv", test) == true);
+
+ // Load it in and make sure it is the same.
+ arma::mat test2;
+ BOOST_REQUIRE(data::Load("test_file.csv", test2) == true);
+
+ BOOST_REQUIRE_EQUAL(test2.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test2.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test2[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.csv");
+}
+
+/**
+ * Make sure CSVs can be loaded in non-transposed form.
+ */
+BOOST_AUTO_TEST_CASE(LoadNonTransposedCSVTest)
+{
+ std::fstream f;
+ f.open("test_file.csv", std::fstream::out);
+
+ f << "1, 3, 5, 7" << std::endl;
+ f << "2, 4, 6, 8" << std::endl;
+
+ f.close();
+
+ arma::mat test;
+ BOOST_REQUIRE(data::Load("test_file.csv", test, false, false) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_cols, 4);
+ BOOST_REQUIRE_EQUAL(test.n_rows, 2);
+
+ for (size_t i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.csv");
+}
+
+/**
+ * Make sure CSVs can be saved in non-transposed form.
+ */
+BOOST_AUTO_TEST_CASE(SaveNonTransposedCSVTest)
+{
+ arma::mat test = "1 2;"
+ "3 4;"
+ "5 6;"
+ "7 8;";
+
+ BOOST_REQUIRE(data::Save("test_file.csv", test, false, false) == true);
+
+ // Load it in and make sure it is in the same.
+ arma::mat test2;
+ BOOST_REQUIRE(data::Load("test_file.csv", test2, false, false) == true);
+
+ BOOST_REQUIRE_EQUAL(test2.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test2.n_cols, 2);
+
+ for (size_t i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], test2[i], 1e-5);
+
+ // Remove the file.
+ remove("test_file.csv");
+}
+
+/**
+ * Make sure arma_ascii is loaded correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadArmaASCIITest)
+{
+ arma::mat test = "1 5;"
+ "2 6;"
+ "3 7;"
+ "4 8;";
+
+ arma::mat testTrans = trans(test);
+ BOOST_REQUIRE(testTrans.save("test_file.txt", arma::arma_ascii));
+
+ BOOST_REQUIRE(data::Load("test_file.txt", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.txt");
+}
+
+/**
+ * Make sure a CSV is saved correctly.
+ */
+BOOST_AUTO_TEST_CASE(SaveArmaASCIITest)
+{
+ arma::mat test = "1 5;"
+ "2 6;"
+ "3 7;"
+ "4 8;";
+
+ BOOST_REQUIRE(data::Save("test_file.txt", test) == true);
+
+ // Load it in and make sure it is the same.
+ BOOST_REQUIRE(data::Load("test_file.txt", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.txt");
+}
+
+/**
+ * Make sure raw_ascii is loaded correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadRawASCIITest)
+{
+ std::fstream f;
+ f.open("test_file.txt", std::fstream::out);
+
+ f << "1 2 3 4" << std::endl;
+ f << "5 6 7 8" << std::endl;
+
+ f.close();
+
+ arma::mat test;
+ BOOST_REQUIRE(data::Load("test_file.txt", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.txt");
+}
+
+/**
+ * Make sure CSV is loaded correctly as .txt.
+ */
+BOOST_AUTO_TEST_CASE(LoadCSVTxtTest)
+{
+ std::fstream f;
+ f.open("test_file.txt", std::fstream::out);
+
+ f << "1, 2, 3, 4" << std::endl;
+ f << "5, 6, 7, 8" << std::endl;
+
+ f.close();
+
+ arma::mat test;
+ BOOST_REQUIRE(data::Load("test_file.txt", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.txt");
+}
+
+/**
+ * Make sure arma_binary is loaded correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadArmaBinaryTest)
+{
+ arma::mat test = "1 5;"
+ "2 6;"
+ "3 7;"
+ "4 8;";
+
+ arma::mat testTrans = trans(test);
+ BOOST_REQUIRE(testTrans.quiet_save("test_file.bin", arma::arma_binary)
+ == true);
+
+ // Now reload through our interface.
+ BOOST_REQUIRE(data::Load("test_file.bin", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.bin");
+}
+
+/**
+ * Make sure arma_binary is saved correctly.
+ */
+BOOST_AUTO_TEST_CASE(SaveArmaBinaryTest)
+{
+ arma::mat test = "1 5;"
+ "2 6;"
+ "3 7;"
+ "4 8;";
+
+ BOOST_REQUIRE(data::Save("test_file.bin", test) == true);
+
+ BOOST_REQUIRE(data::Load("test_file.bin", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.bin");
+}
+
+/**
+ * Make sure raw_binary is loaded correctly.
+ */
+BOOST_AUTO_TEST_CASE(LoadRawBinaryTest)
+{
+ arma::mat test = "1 2;"
+ "3 4;"
+ "5 6;"
+ "7 8;";
+
+ arma::mat testTrans = trans(test);
+ BOOST_REQUIRE(testTrans.quiet_save("test_file.bin", arma::raw_binary)
+ == true);
+
+ // Now reload through our interface.
+ BOOST_REQUIRE(data::Load("test_file.bin", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 1);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 8);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.bin");
+}
+
+/**
+ * Make sure load as PGM is successful.
+ */
+BOOST_AUTO_TEST_CASE(LoadPGMBinaryTest)
+{
+ arma::mat test = "1 5;"
+ "2 6;"
+ "3 7;"
+ "4 8;";
+
+ arma::mat testTrans = trans(test);
+ BOOST_REQUIRE(testTrans.quiet_save("test_file.pgm", arma::pgm_binary)
+ == true);
+
+ // Now reload through our interface.
+ BOOST_REQUIRE(data::Load("test_file.pgm", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.pgm");
+}
+
+/**
+ * Make sure save as PGM is successful.
+ */
+BOOST_AUTO_TEST_CASE(SavePGMBinaryTest)
+{
+ arma::mat test = "1 5;"
+ "2 6;"
+ "3 7;"
+ "4 8;";
+
+ BOOST_REQUIRE(data::Save("test_file.pgm", test) == true);
+
+ // Now reload through our interface.
+ BOOST_REQUIRE(data::Load("test_file.pgm", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; i++)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Remove the file.
+ remove("test_file.pgm");
+}
+
+#ifdef ARMA_USE_HDF5
+/**
+ * Make sure load as HDF5 is successful.
+ */
+BOOST_AUTO_TEST_CASE(LoadHDF5Test)
+{
+ arma::mat test = "1 5;"
+ "2 6;"
+ "3 7;"
+ "4 8;";
+ arma::mat testTrans = trans(test);
+ BOOST_REQUIRE(testTrans.quiet_save("test_file.h5", arma::hdf5_binary)
+ == true);
+ BOOST_REQUIRE(testTrans.quiet_save("test_file.hdf5", arma::hdf5_binary)
+ == true);
+ BOOST_REQUIRE(testTrans.quiet_save("test_file.hdf", arma::hdf5_binary)
+ == true);
+ BOOST_REQUIRE(testTrans.quiet_save("test_file.he5", arma::hdf5_binary)
+ == true);
+
+ // Now reload through our interface.
+ BOOST_REQUIRE(data::Load("test_file.h5", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Make sure the other extensions work too.
+ BOOST_REQUIRE(data::Load("test_file.hdf5", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ BOOST_REQUIRE(data::Load("test_file.hdf", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ BOOST_REQUIRE(data::Load("test_file.he5", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ remove("test_file.h5");
+ remove("test_file.hdf");
+ remove("test_file.hdf5");
+ remove("test_file.he5");
+}
+
+/**
+ * Make sure save as HDF5 is successful.
+ */
+BOOST_AUTO_TEST_CASE(SaveHDF5Test)
+{
+ arma::mat test = "1 5;"
+ "2 6;"
+ "3 7;"
+ "4 8;";
+ BOOST_REQUIRE(data::Save("test_file.h5", test) == true);
+ BOOST_REQUIRE(data::Save("test_file.hdf5", test) == true);
+ BOOST_REQUIRE(data::Save("test_file.hdf", test) == true);
+ BOOST_REQUIRE(data::Save("test_file.he5", test) == true);
+
+ // Now load them all and verify they were saved okay.
+ BOOST_REQUIRE(data::Load("test_file.h5", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ // Make sure the other extensions work too.
+ BOOST_REQUIRE(data::Load("test_file.hdf5", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ BOOST_REQUIRE(data::Load("test_file.hdf", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ BOOST_REQUIRE(data::Load("test_file.he5", test) == true);
+
+ BOOST_REQUIRE_EQUAL(test.n_rows, 4);
+ BOOST_REQUIRE_EQUAL(test.n_cols, 2);
+
+ for (int i = 0; i < 8; ++i)
+ BOOST_REQUIRE_CLOSE(test[i], (double) (i + 1), 1e-5);
+
+ remove("test_file.h5");
+ remove("test_file.hdf");
+ remove("test_file.hdf5");
+ remove("test_file.he5");
+}
+#else
+/**
+ * Ensure saving as HDF5 fails.
+ */
+BOOST_AUTO_TEST_CASE(NoHDF5Test)
+{
+ arma::mat test;
+ test.randu(5, 5);
+
+ BOOST_REQUIRE(data::Save("test_file.h5", test) == false);
+ BOOST_REQUIRE(data::Save("test_file.hdf5", test) == false);
+ BOOST_REQUIRE(data::Save("test_file.hdf", test) == false);
+ BOOST_REQUIRE(data::Save("test_file.he5", test) == false);
+}
+#endif
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/local_coordinate_coding_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/local_coordinate_coding_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/local_coordinate_coding_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,142 +0,0 @@
-/**
- * @file local_coordinate_coding_test.cpp
- *
- * Test for Local Coordinate Coding
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
-// to use FPC_WEAK, and it's not at all intuitive how to do that.
-#include <armadillo>
-#include <mlpack/methods/local_coordinate_coding/lcc.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace arma;
-using namespace mlpack;
-using namespace mlpack::regression;
-using namespace mlpack::lcc;
-
-BOOST_AUTO_TEST_SUITE(LocalCoordinateCodingTest);
-
-void VerifyCorrectness(vec beta, vec errCorr, double lambda)
-{
- const double tol = 1e-12;
- size_t nDims = beta.n_elem;
- for(size_t j = 0; j < nDims; j++)
- {
- if (beta(j) == 0)
- {
- // make sure that errCorr(j) <= lambda
- BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
- }
- else if (beta(j) < 0)
- {
- // make sure that errCorr(j) == lambda
- BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
- }
- else
- { // beta(j) > 0
- // make sure that errCorr(j) == -lambda
- BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
- }
- }
-}
-
-
-BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestCodingStep)
-{
- double lambda1 = 0.1;
- uword nAtoms = 25;
-
- mat X;
- X.load("mnist_first250_training_4s_and_9s.arm");
- uword nPoints = X.n_cols;
-
- // normalize each point since these are images
- for (uword i = 0; i < nPoints; i++)
- {
- X.col(i) /= norm(X.col(i), 2);
- }
-
- LocalCoordinateCoding<> lcc(X, nAtoms, lambda1);
- lcc.OptimizeCode();
-
- mat D = lcc.Dictionary();
- mat Z = lcc.Codes();
-
- for(uword i = 0; i < nPoints; i++) {
- vec sq_dists = vec(nAtoms);
- for(uword j = 0; j < nAtoms; j++) {
- vec diff = D.unsafe_col(j) - X.unsafe_col(i);
- sq_dists[j] = dot(diff, diff);
- }
- mat Dprime = D * diagmat(1.0 / sq_dists);
- mat zPrime = Z.unsafe_col(i) % sq_dists;
-
- vec errCorr = trans(Dprime) * (Dprime * zPrime - X.unsafe_col(i));
- VerifyCorrectness(zPrime, errCorr, 0.5 * lambda1);
- }
-}
-
-BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestDictionaryStep)
-{
- const double tol = 1e-12;
-
- double lambda = 0.1;
- uword nAtoms = 25;
-
- mat X;
- X.load("mnist_first250_training_4s_and_9s.arm");
- uword nPoints = X.n_cols;
-
- // normalize each point since these are images
- for (uword i = 0; i < nPoints; i++)
- {
- X.col(i) /= norm(X.col(i), 2);
- }
-
- LocalCoordinateCoding<> lcc(X, nAtoms, lambda);
- lcc.OptimizeCode();
- mat Z = lcc.Codes();
- uvec adjacencies = find(Z);
- lcc.OptimizeDictionary(adjacencies);
-
- mat D = lcc.Dictionary();
-
- mat grad = zeros(D.n_rows, D.n_cols);
- for (uword i = 0; i < nPoints; i++)
- {
- grad += (D - repmat(X.unsafe_col(i), 1, nAtoms)) *
- diagmat(abs(Z.unsafe_col(i)));
- }
- grad = lambda * grad + (D * Z - X) * trans(Z);
-
- BOOST_REQUIRE_SMALL(norm(grad, "fro"), tol);
-
-}
-
-/*
-BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestWhole)
-{
-
-}
-*/
-
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/local_coordinate_coding_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/local_coordinate_coding_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/local_coordinate_coding_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/local_coordinate_coding_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,142 @@
+/**
+ * @file local_coordinate_coding_test.cpp
+ *
+ * Test for Local Coordinate Coding
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
+// to use FPC_WEAK, and it's not at all intuitive how to do that.
+#include <armadillo>
+#include <mlpack/methods/local_coordinate_coding/lcc.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace mlpack::lcc;
+
+BOOST_AUTO_TEST_SUITE(LocalCoordinateCodingTest);
+
+void VerifyCorrectness(vec beta, vec errCorr, double lambda)
+{
+ const double tol = 1e-12;
+ size_t nDims = beta.n_elem;
+ for(size_t j = 0; j < nDims; j++)
+ {
+ if (beta(j) == 0)
+ {
+ // make sure that errCorr(j) <= lambda
+ BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
+ }
+ else if (beta(j) < 0)
+ {
+ // make sure that errCorr(j) == lambda
+ BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
+ }
+ else
+ { // beta(j) > 0
+ // make sure that errCorr(j) == -lambda
+ BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
+ }
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestCodingStep)
+{
+ double lambda1 = 0.1;
+ uword nAtoms = 25;
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm");
+ uword nPoints = X.n_cols;
+
+ // normalize each point since these are images
+ for (uword i = 0; i < nPoints; i++)
+ {
+ X.col(i) /= norm(X.col(i), 2);
+ }
+
+ LocalCoordinateCoding<> lcc(X, nAtoms, lambda1);
+ lcc.OptimizeCode();
+
+ mat D = lcc.Dictionary();
+ mat Z = lcc.Codes();
+
+ for(uword i = 0; i < nPoints; i++) {
+ vec sq_dists = vec(nAtoms);
+ for(uword j = 0; j < nAtoms; j++) {
+ vec diff = D.unsafe_col(j) - X.unsafe_col(i);
+ sq_dists[j] = dot(diff, diff);
+ }
+ mat Dprime = D * diagmat(1.0 / sq_dists);
+ mat zPrime = Z.unsafe_col(i) % sq_dists;
+
+ vec errCorr = trans(Dprime) * (Dprime * zPrime - X.unsafe_col(i));
+ VerifyCorrectness(zPrime, errCorr, 0.5 * lambda1);
+ }
+}
+
+BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestDictionaryStep)
+{
+ const double tol = 1e-12;
+
+ double lambda = 0.1;
+ uword nAtoms = 25;
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm");
+ uword nPoints = X.n_cols;
+
+ // normalize each point since these are images
+ for (uword i = 0; i < nPoints; i++)
+ {
+ X.col(i) /= norm(X.col(i), 2);
+ }
+
+ LocalCoordinateCoding<> lcc(X, nAtoms, lambda);
+ lcc.OptimizeCode();
+ mat Z = lcc.Codes();
+ uvec adjacencies = find(Z);
+ lcc.OptimizeDictionary(adjacencies);
+
+ mat D = lcc.Dictionary();
+
+ mat grad = zeros(D.n_rows, D.n_cols);
+ for (uword i = 0; i < nPoints; i++)
+ {
+ grad += (D - repmat(X.unsafe_col(i), 1, nAtoms)) *
+ diagmat(abs(Z.unsafe_col(i)));
+ }
+ grad = lambda * grad + (D * Z - X) * trans(Z);
+
+ BOOST_REQUIRE_SMALL(norm(grad, "fro"), tol);
+
+}
+
+/*
+BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestWhole)
+{
+
+}
+*/
+
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lrsdp_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/lrsdp_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lrsdp_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,187 +0,0 @@
-/**
- * @file lrsdp_test.cpp
- * @author Ryan Curtin
- *
- * Tests for LR-SDP (core/optimizers/lrsdp/).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/optimizers/lrsdp/lrsdp.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::optimization;
-
-BOOST_AUTO_TEST_SUITE(LRSDPTest);
-
-/**
- * Create a Lovasz-Theta initial point.
- */
-void createLovaszThetaInitialPoint(const arma::mat& edges,
- arma::mat& coordinates)
-{
- // Get the number of vertices in the problem.
- const size_t vertices = max(max(edges)) + 1;
-
- const size_t m = edges.n_cols + 1;
- float r = 0.5 + sqrt(0.25 + 2 * m);
- if (ceil(r) > vertices)
- r = vertices; // An upper bound on the dimension.
-
- coordinates.set_size(vertices, ceil(r));
-
- // Now we set the entries of the initial matrix according to the formula given
- // in Section 4 of Monteiro and Burer.
- for (size_t i = 0; i < vertices; ++i)
- {
- for (size_t j = 0; j < ceil(r); ++j)
- {
- if (i == j)
- coordinates(i, j) = sqrt(1.0 / r) + sqrt(1.0 / (vertices * m));
- else
- coordinates(i, j) = sqrt(1.0 / (vertices * m));
- }
- }
-}
-
-/**
- * Prepare an LRSDP object to solve the Lovasz-Theta SDP in the manner detailed
- * in Monteiro + Burer 2004. The list of edges in the graph must be given; that
- * is all that is necessary to set up the problem. A matrix which will contain
- * initial point coordinates should be given also.
- */
-void setupLovaszTheta(const arma::mat& edges,
- LRSDP& lovasz)
-{
- // Get the number of vertices in the problem.
- const size_t vertices = max(max(edges)) + 1;
-
- // C = -(e e^T) = -ones().
- lovasz.C().ones(vertices, vertices);
- lovasz.C() *= -1;
-
- // b_0 = 1; else = 0.
- lovasz.B().zeros(edges.n_cols);
- lovasz.B()[0] = 1;
-
- // All of the matrices will just contain coordinates because they are
- // super-sparse (two entries each). Except for A_0, which is I_n.
- lovasz.AModes().ones();
- lovasz.AModes()[0] = 0;
-
- // A_0 = I_n.
- lovasz.A()[0].eye(vertices, vertices);
-
- // A_ij only has ones at (i, j) and (j, i) and 1 elsewhere.
- for (size_t i = 0; i < edges.n_cols; ++i)
- {
- arma::mat a(3, 2);
-
- a(0, 0) = edges(0, i);
- a(1, 0) = edges(1, i);
- a(2, 0) = 1;
-
- a(0, 1) = edges(1, i);
- a(1, 1) = edges(0, i);
- a(2, 1) = 1;
-
- lovasz.A()[i + 1] = a;
- }
-
- // Set the Lagrange multipliers right.
- lovasz.AugLag().Lambda().ones(edges.n_cols);
- lovasz.AugLag().Lambda() *= -1;
- lovasz.AugLag().Lambda()[0] = -double(vertices);
-}
-
-/**
- * johnson8-4-4.co test case for Lovasz-Theta LRSDP.
- * See Monteiro and Burer 2004.
- */
-BOOST_AUTO_TEST_CASE(Johnson844LovaszThetaSDP)
-{
- // Load the edges.
- arma::mat edges;
- data::Load("johnson8-4-4.csv", edges, true);
-
- // The LRSDP itself and the initial point.
- arma::mat coordinates;
-
- createLovaszThetaInitialPoint(edges, coordinates);
-
- LRSDP lovasz(edges.n_cols + 1, coordinates);
-
- setupLovaszTheta(edges, lovasz);
-
- double finalValue = lovasz.Optimize(coordinates);
-
- // Final value taken from Monteiro + Burer 2004.
- BOOST_REQUIRE_CLOSE(finalValue, -14.0, 1e-5);
-
- // Now ensure that all the constraints are satisfied.
- arma::mat rrt = coordinates * trans(coordinates);
- BOOST_REQUIRE_CLOSE(trace(rrt), 1.0, 1e-5);
-
- // All those edge constraints...
- for (size_t i = 0; i < edges.n_cols; ++i)
- {
- BOOST_REQUIRE_SMALL(rrt(edges(0, i), edges(1, i)), 1e-5);
- BOOST_REQUIRE_SMALL(rrt(edges(1, i), edges(0, i)), 1e-5);
- }
-}
-
-/**
- * keller4.co test case for Lovasz-Theta LRSDP.
- * This is commented out because it takes a long time to run.
- * See Monteiro and Burer 2004.
- *
-BOOST_AUTO_TEST_CASE(Keller4LovaszThetaSDP)
-{
- // Load the edges.
- arma::mat edges;
- data::Load("keller4.csv", edges, true);
-
- // The LRSDP itself and the initial point.
- arma::mat coordinates;
-
- createLovaszThetaInitialPoint(edges, coordinates);
-
- LRSDP lovasz(edges.n_cols, coordinates);
-
- setupLovaszTheta(edges, lovasz);
-
- double finalValue = lovasz.Optimize(coordinates);
-
- // Final value taken from Monteiro + Burer 2004.
- BOOST_REQUIRE_CLOSE(finalValue, -14.013, 1e-2); // Not as much precision...
- // The SB method came to -14.013, but M&B's method only came to -14.005.
-
- // Now ensure that all the constraints are satisfied.
- arma::mat rrt = coordinates * trans(coordinates);
- BOOST_REQUIRE_CLOSE(trace(rrt), 1.0, 1e-5);
-
- // All those edge constraints...
- for (size_t i = 0; i < edges.n_cols; ++i)
- {
- BOOST_REQUIRE_SMALL(rrt(edges(0, i), edges(1, i)), 1e-3);
- BOOST_REQUIRE_SMALL(rrt(edges(1, i), edges(0, i)), 1e-3);
- }
-}*/
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lrsdp_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/lrsdp_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lrsdp_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lrsdp_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,187 @@
+/**
+ * @file lrsdp_test.cpp
+ * @author Ryan Curtin
+ *
+ * Tests for LR-SDP (core/optimizers/lrsdp/).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/lrsdp/lrsdp.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::optimization;
+
+BOOST_AUTO_TEST_SUITE(LRSDPTest);
+
+/**
+ * Create a Lovasz-Theta initial point.
+ */
+void createLovaszThetaInitialPoint(const arma::mat& edges,
+ arma::mat& coordinates)
+{
+ // Get the number of vertices in the problem.
+ const size_t vertices = max(max(edges)) + 1;
+
+ const size_t m = edges.n_cols + 1;
+ float r = 0.5 + sqrt(0.25 + 2 * m);
+ if (ceil(r) > vertices)
+ r = vertices; // An upper bound on the dimension.
+
+ coordinates.set_size(vertices, ceil(r));
+
+ // Now we set the entries of the initial matrix according to the formula given
+ // in Section 4 of Monteiro and Burer.
+ for (size_t i = 0; i < vertices; ++i)
+ {
+ for (size_t j = 0; j < ceil(r); ++j)
+ {
+ if (i == j)
+ coordinates(i, j) = sqrt(1.0 / r) + sqrt(1.0 / (vertices * m));
+ else
+ coordinates(i, j) = sqrt(1.0 / (vertices * m));
+ }
+ }
+}
+
+/**
+ * Prepare an LRSDP object to solve the Lovasz-Theta SDP in the manner detailed
+ * in Monteiro + Burer 2004. The list of edges in the graph must be given; that
+ * is all that is necessary to set up the problem. A matrix which will contain
+ * initial point coordinates should be given also.
+ */
+void setupLovaszTheta(const arma::mat& edges,
+ LRSDP& lovasz)
+{
+ // Get the number of vertices in the problem.
+ const size_t vertices = max(max(edges)) + 1;
+
+ // C = -(e e^T) = -ones().
+ lovasz.C().ones(vertices, vertices);
+ lovasz.C() *= -1;
+
+ // b_0 = 1; else = 0.
+ lovasz.B().zeros(edges.n_cols);
+ lovasz.B()[0] = 1;
+
+ // All of the matrices will just contain coordinates because they are
+ // super-sparse (two entries each). Except for A_0, which is I_n.
+ lovasz.AModes().ones();
+ lovasz.AModes()[0] = 0;
+
+ // A_0 = I_n.
+ lovasz.A()[0].eye(vertices, vertices);
+
+ // A_ij only has ones at (i, j) and (j, i) and 1 elsewhere.
+ for (size_t i = 0; i < edges.n_cols; ++i)
+ {
+ arma::mat a(3, 2);
+
+ a(0, 0) = edges(0, i);
+ a(1, 0) = edges(1, i);
+ a(2, 0) = 1;
+
+ a(0, 1) = edges(1, i);
+ a(1, 1) = edges(0, i);
+ a(2, 1) = 1;
+
+ lovasz.A()[i + 1] = a;
+ }
+
+ // Set the Lagrange multipliers right.
+ lovasz.AugLag().Lambda().ones(edges.n_cols);
+ lovasz.AugLag().Lambda() *= -1;
+ lovasz.AugLag().Lambda()[0] = -double(vertices);
+}
+
+/**
+ * johnson8-4-4.co test case for Lovasz-Theta LRSDP.
+ * See Monteiro and Burer 2004.
+ */
+BOOST_AUTO_TEST_CASE(Johnson844LovaszThetaSDP)
+{
+ // Load the edges.
+ arma::mat edges;
+ data::Load("johnson8-4-4.csv", edges, true);
+
+ // The LRSDP itself and the initial point.
+ arma::mat coordinates;
+
+ createLovaszThetaInitialPoint(edges, coordinates);
+
+ LRSDP lovasz(edges.n_cols + 1, coordinates);
+
+ setupLovaszTheta(edges, lovasz);
+
+ double finalValue = lovasz.Optimize(coordinates);
+
+ // Final value taken from Monteiro + Burer 2004.
+ BOOST_REQUIRE_CLOSE(finalValue, -14.0, 1e-5);
+
+ // Now ensure that all the constraints are satisfied.
+ arma::mat rrt = coordinates * trans(coordinates);
+ BOOST_REQUIRE_CLOSE(trace(rrt), 1.0, 1e-5);
+
+ // All those edge constraints...
+ for (size_t i = 0; i < edges.n_cols; ++i)
+ {
+ BOOST_REQUIRE_SMALL(rrt(edges(0, i), edges(1, i)), 1e-5);
+ BOOST_REQUIRE_SMALL(rrt(edges(1, i), edges(0, i)), 1e-5);
+ }
+}
+
+/**
+ * keller4.co test case for Lovasz-Theta LRSDP.
+ * This is commented out because it takes a long time to run.
+ * See Monteiro and Burer 2004.
+ *
+BOOST_AUTO_TEST_CASE(Keller4LovaszThetaSDP)
+{
+ // Load the edges.
+ arma::mat edges;
+ data::Load("keller4.csv", edges, true);
+
+ // The LRSDP itself and the initial point.
+ arma::mat coordinates;
+
+ createLovaszThetaInitialPoint(edges, coordinates);
+
+ LRSDP lovasz(edges.n_cols, coordinates);
+
+ setupLovaszTheta(edges, lovasz);
+
+ double finalValue = lovasz.Optimize(coordinates);
+
+ // Final value taken from Monteiro + Burer 2004.
+ BOOST_REQUIRE_CLOSE(finalValue, -14.013, 1e-2); // Not as much precision...
+ // The SB method came to -14.013, but M&B's method only came to -14.005.
+
+ // Now ensure that all the constraints are satisfied.
+ arma::mat rrt = coordinates * trans(coordinates);
+ BOOST_REQUIRE_CLOSE(trace(rrt), 1.0, 1e-5);
+
+ // All those edge constraints...
+ for (size_t i = 0; i < edges.n_cols; ++i)
+ {
+ BOOST_REQUIRE_SMALL(rrt(edges(0, i), edges(1, i)), 1e-3);
+ BOOST_REQUIRE_SMALL(rrt(edges(1, i), edges(0, i)), 1e-3);
+ }
+}*/
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lsh_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/lsh_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lsh_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,130 +0,0 @@
-/**
- * @file lsh_test.cpp
- *
- * Unit tests for the 'LSHSearch' class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-// So that we can test private members. This is hackish (for now).
-#define private public
-#include <mlpack/methods/lsh/lsh_search.hpp>
-#undef private
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace std;
-using namespace mlpack;
-using namespace mlpack::neighbor;
-
-BOOST_AUTO_TEST_SUITE(LSHTest);
-
-BOOST_AUTO_TEST_CASE(LSHSearchTest)
-{
- // Force to specific random seed for these results.
- math::RandomSeed(0);
-
- // Precomputed hash width value.
- const double hashWidth = 4.24777;
-
- arma::mat rdata(2, 10);
- rdata << 3 << 2 << 4 << 3 << 5 << 6 << 0 << 8 << 3 << 1 << arma::endr <<
- 0 << 3 << 4 << 7 << 8 << 4 << 1 << 0 << 4 << 3 << arma::endr;
-
- arma::mat qdata(2, 3);
- qdata << 3 << 2 << 0 << arma::endr << 5 << 3 << 4 << arma::endr;
-
- // INPUT TO LSH:
- // Number of points: 10
- // Number of dimensions: 2
- // Number of projections per table: 'numProj' = 3
- // Number of hash tables: 'numTables' = 2
- // hashWidth (computed): 'hashWidth' = 4.24777
- // Second hash size: 'secondHashSize' = 11
- // Size of the bucket: 'bucketSize' = 3
-
- // Things obtained by random sampling listed in the sequences
- // as they will be obtained in the 'LSHSearch::BuildHash()' private function
- // in 'LSHSearch' class.
- //
- // 1. The weights of the second hash obtained as:
- // secondHashWeights = arma::floor(arma::randu(3) * 11.0);
- // COR.SOL.: secondHashWeights = [9, 4, 8];
- //
- // 2. The offsets for all the 3 projections in each of the 2 tables:
- // offsets.randu(3, 2)
- // COR.SOL.: [0.7984 0.3352; 0.9116 0.7682; 0.1976 0.2778]
- // offsets *= hashWidth
- // COR.SOL.: [3.3916 1.4240; 3.8725 3.2633; 0.8392 1.1799]
- //
- // 3. The (2 x 3) projection matrices for the 2 tables:
- // projMat.randn(2, 3)
- // COR.SOL.: Proj. Mat 1: [2.7020 0.0187 0.4355; 1.3692 0.6933 0.0416]
- // COR.SOL.: Proj. Mat 2: [-0.3961 -0.2666 1.1001; 0.3895 -1.5118 -1.3964]
- LSHSearch<> lsh_test(rdata, qdata, 3, 2, hashWidth, 11, 3);
-// LSHSearch<> lsh_test(rdata, qdata, 3, 2, 0.0, 11, 3);
-
- // Given this, the 'LSHSearch::bucketRowInHashTable' should be:
- // COR.SOL.: [2 11 4 7 6 3 11 0 5 1 8]
- //
- // The 'LSHSearch::bucketContentSize' should be:
- // COR.SOL.: [2 0 1 1 3 1 0 3 3 3 1]
- //
- // The final hash table 'LSHSearch::secondHashTable' should be
- // of size (3 x 9) with the following content:
- // COR.SOL.:
- // [0 2 4; 1 7 8; 3 9 10; 5 10 10; 6 10 10; 0 5 6; 1 2 8; 3 10 10; 4 10 10]
-
- arma::Mat<size_t> neighbors;
- arma::mat distances;
-
- lsh_test.Search(2, neighbors, distances);
-
- // The private function 'LSHSearch::ReturnIndicesFromTable(0, refInds)'
- // should hash the query 0 into the following buckets:
- // COR.SOL.: Table 1 Bucket 7, Table 2 Bucket 0, refInds = [0 2 3 4 9]
- //
- // The private function 'LSHSearch::ReturnIndicesFromTable(1, refInds)'
- // should hash the query 1 into the following buckets:
- // COR.SOL.: Table 1 Bucket 9, Table 2 Bucket 4, refInds = [1 2 7 8]
- //
- // The private function 'LSHSearch::ReturnIndicesFromTable(2, refInds)'
- // should hash the query 2 into the following buckets:
- // COR.SOL.: Table 1 Bucket 0, Table 2 Bucket 7, refInds = [0 2 3 4 9]
-
- // After search
- // COR.SOL.: 'neighbors' = [2 1 9; 3 8 2]
- // COR.SOL.: 'distances' = [2 0 2; 4 2 16]
-
- arma::Mat<size_t> true_neighbors(2, 3);
- true_neighbors << 2 << 1 << 9 << arma::endr << 3 << 8 << 2 << arma::endr;
- arma::mat true_distances(2, 3);
- true_distances << 2 << 0 << 2 << arma::endr << 4 << 2 << 16 << arma::endr;
-
- for (size_t i = 0; i < 3; i++)
- {
- for (size_t j = 0; j < 2; j++)
- {
-// BOOST_REQUIRE_EQUAL(neighbors(j, i), true_neighbors(j, i));
-// BOOST_REQUIRE_CLOSE(distances(j, i), true_distances(j, i), 1e-5);
- }
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lsh_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/lsh_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lsh_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/lsh_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,130 @@
+/**
+ * @file lsh_test.cpp
+ *
+ * Unit tests for the 'LSHSearch' class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+// So that we can test private members. This is hackish (for now).
+#define private public
+#include <mlpack/methods/lsh/lsh_search.hpp>
+#undef private
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+BOOST_AUTO_TEST_SUITE(LSHTest);
+
+BOOST_AUTO_TEST_CASE(LSHSearchTest)
+{
+ // Force to specific random seed for these results.
+ math::RandomSeed(0);
+
+ // Precomputed hash width value.
+ const double hashWidth = 4.24777;
+
+ arma::mat rdata(2, 10);
+ rdata << 3 << 2 << 4 << 3 << 5 << 6 << 0 << 8 << 3 << 1 << arma::endr <<
+ 0 << 3 << 4 << 7 << 8 << 4 << 1 << 0 << 4 << 3 << arma::endr;
+
+ arma::mat qdata(2, 3);
+ qdata << 3 << 2 << 0 << arma::endr << 5 << 3 << 4 << arma::endr;
+
+ // INPUT TO LSH:
+ // Number of points: 10
+ // Number of dimensions: 2
+ // Number of projections per table: 'numProj' = 3
+ // Number of hash tables: 'numTables' = 2
+ // hashWidth (computed): 'hashWidth' = 4.24777
+ // Second hash size: 'secondHashSize' = 11
+ // Size of the bucket: 'bucketSize' = 3
+
+ // Things obtained by random sampling listed in the sequences
+ // as they will be obtained in the 'LSHSearch::BuildHash()' private function
+ // in 'LSHSearch' class.
+ //
+ // 1. The weights of the second hash obtained as:
+ // secondHashWeights = arma::floor(arma::randu(3) * 11.0);
+ // COR.SOL.: secondHashWeights = [9, 4, 8];
+ //
+ // 2. The offsets for all the 3 projections in each of the 2 tables:
+ // offsets.randu(3, 2)
+ // COR.SOL.: [0.7984 0.3352; 0.9116 0.7682; 0.1976 0.2778]
+ // offsets *= hashWidth
+ // COR.SOL.: [3.3916 1.4240; 3.8725 3.2633; 0.8392 1.1799]
+ //
+ // 3. The (2 x 3) projection matrices for the 2 tables:
+ // projMat.randn(2, 3)
+ // COR.SOL.: Proj. Mat 1: [2.7020 0.0187 0.4355; 1.3692 0.6933 0.0416]
+ // COR.SOL.: Proj. Mat 2: [-0.3961 -0.2666 1.1001; 0.3895 -1.5118 -1.3964]
+ LSHSearch<> lsh_test(rdata, qdata, 3, 2, hashWidth, 11, 3);
+// LSHSearch<> lsh_test(rdata, qdata, 3, 2, 0.0, 11, 3);
+
+ // Given this, the 'LSHSearch::bucketRowInHashTable' should be:
+ // COR.SOL.: [2 11 4 7 6 3 11 0 5 1 8]
+ //
+ // The 'LSHSearch::bucketContentSize' should be:
+ // COR.SOL.: [2 0 1 1 3 1 0 3 3 3 1]
+ //
+ // The final hash table 'LSHSearch::secondHashTable' should be
+ // of size (3 x 9) with the following content:
+ // COR.SOL.:
+ // [0 2 4; 1 7 8; 3 9 10; 5 10 10; 6 10 10; 0 5 6; 1 2 8; 3 10 10; 4 10 10]
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ lsh_test.Search(2, neighbors, distances);
+
+ // The private function 'LSHSearch::ReturnIndicesFromTable(0, refInds)'
+ // should hash the query 0 into the following buckets:
+ // COR.SOL.: Table 1 Bucket 7, Table 2 Bucket 0, refInds = [0 2 3 4 9]
+ //
+ // The private function 'LSHSearch::ReturnIndicesFromTable(1, refInds)'
+ // should hash the query 1 into the following buckets:
+ // COR.SOL.: Table 1 Bucket 9, Table 2 Bucket 4, refInds = [1 2 7 8]
+ //
+ // The private function 'LSHSearch::ReturnIndicesFromTable(2, refInds)'
+ // should hash the query 2 into the following buckets:
+ // COR.SOL.: Table 1 Bucket 0, Table 2 Bucket 7, refInds = [0 2 3 4 9]
+
+ // After search
+ // COR.SOL.: 'neighbors' = [2 1 9; 3 8 2]
+ // COR.SOL.: 'distances' = [2 0 2; 4 2 16]
+
+ arma::Mat<size_t> true_neighbors(2, 3);
+ true_neighbors << 2 << 1 << 9 << arma::endr << 3 << 8 << 2 << arma::endr;
+ arma::mat true_distances(2, 3);
+ true_distances << 2 << 0 << 2 << arma::endr << 4 << 2 << 16 << arma::endr;
+
+ for (size_t i = 0; i < 3; i++)
+ {
+ for (size_t j = 0; j < 2; j++)
+ {
+// BOOST_REQUIRE_EQUAL(neighbors(j, i), true_neighbors(j, i));
+// BOOST_REQUIRE_CLOSE(distances(j, i), true_distances(j, i), 1e-5);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/math_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/math_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/math_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,540 +0,0 @@
-/**
- * @file math_test.cpp
- * @author Ryan Curtin
- *
- * Tests for everything in the math:: namespace.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core/math/clamp.hpp>
-#include <mlpack/core/math/random.hpp>
-#include <mlpack/core/math/range.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace math;
-
-BOOST_AUTO_TEST_SUITE(MathTest);
-
-/**
- * Verify that the empty constructor creates an empty range.
- */
-BOOST_AUTO_TEST_CASE(RangeEmptyConstructor)
-{
- Range x = Range();
-
- // Just verify that it is empty.
- BOOST_REQUIRE_GT(x.Lo(), x.Hi());
-}
-
-/**
- * Verify that the point constructor correctly creates a range that is just a
- * point.
- */
-BOOST_AUTO_TEST_CASE(RangePointConstructor)
-{
- Range x(10.0);
-
- BOOST_REQUIRE_CLOSE(x.Lo(), x.Hi(), 1e-25);
- BOOST_REQUIRE_SMALL(x.Width(), 1e-5);
- BOOST_REQUIRE_CLOSE(x.Lo(), 10.0, 1e-25);
- BOOST_REQUIRE_CLOSE(x.Hi(), 10.0, 1e-25);
-}
-
-/**
- * Verify that the range constructor correctly creates the range.
- */
-BOOST_AUTO_TEST_CASE(RangeConstructor)
-{
- Range x(0.5, 5.5);
-
- BOOST_REQUIRE_CLOSE(x.Lo(), 0.5, 1e-25);
- BOOST_REQUIRE_CLOSE(x.Hi(), 5.5, 1e-25);
-}
-
-/**
- * Test that we get the width correct.
- */
-BOOST_AUTO_TEST_CASE(RangeWidth)
-{
- Range x(0.0, 10.0);
-
- BOOST_REQUIRE_CLOSE(x.Width(), 10.0, 1e-20);
-
- // Make it empty.
- x.Hi() = 0.0;
-
- BOOST_REQUIRE_SMALL(x.Width(), 1e-5);
-
- // Make it negative.
- x.Hi() = -2.0;
-
- BOOST_REQUIRE_SMALL(x.Width(), 1e-5);
-
- // Just one more test.
- x.Lo() = -5.2;
- x.Hi() = 5.2;
-
- BOOST_REQUIRE_CLOSE(x.Width(), 10.4, 1e-5);
-}
-
-/**
- * Test that we get the midpoint correct.
- */
-BOOST_AUTO_TEST_CASE(RangeMidpoint)
-{
- Range x(0.0, 10.0);
-
- BOOST_REQUIRE_CLOSE(x.Mid(), 5.0, 1e-5);
-
- x.Lo() = -5.0;
-
- BOOST_REQUIRE_CLOSE(x.Mid(), 2.5, 1e-5);
-}
-
-/**
- * Test that we can expand to include other ranges correctly.
- */
-BOOST_AUTO_TEST_CASE(RangeIncludeOther)
-{
- // We need to test both |= and |.
- // We have three cases: non-overlapping; overlapping; equivalent, and then a
- // couple permutations (switch left with right and make sure it still works).
- Range x(0.0, 2.0);
- Range y(3.0, 5.0);
-
- Range z(0.0, 2.0); // Used for operator|=().
- Range w;
- z |= y;
- w = x | y;
-
- BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 5.0, 1e-5);
- BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 5.0, 1e-5);
-
- // Switch operator precedence.
- z = y;
- z |= x;
- w = y | x;
-
- BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 5.0, 1e-5);
- BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 5.0, 1e-5);
-
- // Now make them overlapping.
- x = Range(0.0, 3.5);
- y = Range(3.0, 4.0);
-
- z = x;
- z |= y;
- w = x | y;
-
- BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 4.0, 1e-5);
- BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 4.0, 1e-5);
-
- // Switch operator precedence.
- z = y;
- z |= x;
- w = y | x;
-
- BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 4.0, 1e-5);
- BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 4.0, 1e-5);
-
- // Now the equivalent case.
- x = Range(0.0, 2.0);
- y = Range(0.0, 2.0);
-
- z = x;
- z |= y;
- w = x | y;
-
- BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 2.0, 1e-5);
-
- z = y;
- z |= x;
- w = y | x;
-
- BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 2.0, 1e-5);
-}
-
-/**
- * Test that we can 'and' ranges correctly.
- */
-BOOST_AUTO_TEST_CASE(RangeIntersectOther)
-{
- // We need to test both &= and &.
- // We have three cases: non-overlapping, overlapping; equivalent, and then a
- // couple permutations (switch left with right and make sure it still works).
- Range x(0.0, 2.0);
- Range y(3.0, 5.0);
-
- Range z(0.0, 2.0);
- Range w;
- z &= y;
- w = x & y;
-
- BOOST_REQUIRE_SMALL(z.Width(), 1e-5);
- BOOST_REQUIRE_SMALL(w.Width(), 1e-5);
-
- // Reverse operator precedence.
- z = y;
- z &= x;
- w = y & x;
-
- BOOST_REQUIRE_SMALL(z.Width(), 1e-5);
- BOOST_REQUIRE_SMALL(w.Width(), 1e-5);
-
- // Now make them overlapping.
- x = Range(0.0, 3.5);
- y = Range(3.0, 4.0);
-
- z = x;
- z &= y;
- w = x & y;
-
- BOOST_REQUIRE_CLOSE(z.Lo(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 3.5, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Lo(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 3.5, 1e-5);
-
- // Reverse operator precedence.
- z = y;
- z &= x;
- w = y & x;
-
- BOOST_REQUIRE_CLOSE(z.Lo(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 3.5, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Lo(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 3.5, 1e-5);
-
- // Now make them equivalent.
- x = Range(2.0, 4.0);
- y = Range(2.0, 4.0);
-
- z = x;
- z &= y;
- w = x & y;
-
- BOOST_REQUIRE_CLOSE(z.Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 4.0, 1e-5);
-}
-
-/**
- * Test multiplication of a range with a double.
- */
-BOOST_AUTO_TEST_CASE(RangeMultiply)
-{
- // We need to test both * and *=, as well as both cases of *.
- // We'll try with a couple of numbers: -1, 0, 2.
- // And we'll have a couple of cases for bounds: strictly less than zero;
- // including zero; and strictly greater than zero.
- //
- // So, nine total cases.
- Range x(-5.0, -3.0);
- Range y(-5.0, -3.0);
- Range z;
- Range w;
-
- y *= -1.0;
- z = x * -1.0;
- w = -1.0 * x;
-
- BOOST_REQUIRE_CLOSE(y.Lo(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(y.Hi(), 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Lo(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Lo(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 5.0, 1e-5);
-
- y = x;
- y *= 0.0;
- z = x * 0.0;
- w = 0.0 * x;
-
- BOOST_REQUIRE_SMALL(y.Lo(), 1e-5);
- BOOST_REQUIRE_SMALL(y.Hi(), 1e-5);
- BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
- BOOST_REQUIRE_SMALL(z.Hi(), 1e-5);
- BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
- BOOST_REQUIRE_SMALL(w.Hi(), 1e-5);
-
- y = x;
- y *= 2.0;
- z = x * 2.0;
- w = 2.0 * x;
-
- BOOST_REQUIRE_CLOSE(y.Lo(), -10.0, 1e-5);
- BOOST_REQUIRE_CLOSE(y.Hi(), -6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Lo(), -10.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), -6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Lo(), -10.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), -6.0, 1e-5);
-
- x = Range(-2.0, 2.0);
- y = x;
-
- y *= -1.0;
- z = x * -1.0;
- w = -1.0 * x;
-
- BOOST_REQUIRE_CLOSE(y.Lo(), -2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(y.Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Lo(), -2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Lo(), -2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 2.0, 1e-5);
-
- y = x;
- y *= 0.0;
- z = x * 0.0;
- w = 0.0 * x;
-
- BOOST_REQUIRE_SMALL(y.Lo(), 1e-5);
- BOOST_REQUIRE_SMALL(y.Hi(), 1e-5);
- BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
- BOOST_REQUIRE_SMALL(z.Hi(), 1e-5);
- BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
- BOOST_REQUIRE_SMALL(w.Hi(), 1e-5);
-
- y = x;
- y *= 2.0;
- z = x * 2.0;
- w = 2.0 * x;
-
- BOOST_REQUIRE_CLOSE(y.Lo(), -4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(y.Hi(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Lo(), -4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Lo(), -4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 4.0, 1e-5);
-
- x = Range(3.0, 5.0);
-
- y = x;
- y *= -1.0;
- z = x * -1.0;
- w = -1.0 * x;
-
- BOOST_REQUIRE_CLOSE(y.Lo(), -5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(y.Hi(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Lo(), -5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Lo(), -5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), -3.0, 1e-5);
-
- y = x;
- y *= 0.0;
- z = x * 0.0;
- w = 0.0 * x;
-
- BOOST_REQUIRE_SMALL(y.Lo(), 1e-5);
- BOOST_REQUIRE_SMALL(y.Hi(), 1e-5);
- BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
- BOOST_REQUIRE_SMALL(z.Hi(), 1e-5);
- BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
- BOOST_REQUIRE_SMALL(w.Hi(), 1e-5);
-
- y = x;
- y *= 2.0;
- z = x * 2.0;
- w = 2.0 * x;
-
- BOOST_REQUIRE_CLOSE(y.Lo(), 6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(y.Hi(), 10.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Lo(), 6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(z.Hi(), 10.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Lo(), 6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(w.Hi(), 10.0, 1e-5);
-}
-
-/**
- * Test equality operator.
- */
-BOOST_AUTO_TEST_CASE(RangeEquality)
-{
- // Three cases: non-overlapping, overlapping, equivalent. We should also
- // consider empty ranges, which are not necessarily equal...
- Range x(0.0, 2.0);
- Range y(3.0, 5.0);
-
- // These are odd calls, but we don't want to use operator!= here.
- BOOST_REQUIRE_EQUAL((x == y), false);
- BOOST_REQUIRE_EQUAL((y == x), false);
-
- y = Range(1.0, 3.0);
-
- BOOST_REQUIRE_EQUAL((x == y), false);
- BOOST_REQUIRE_EQUAL((y == x), false);
-
- y = Range(0.0, 2.0);
-
- BOOST_REQUIRE_EQUAL((x == y), true);
- BOOST_REQUIRE_EQUAL((y == x), true);
-
- x = Range(1.0, -1.0); // Empty.
- y = Range(1.0, -1.0); // Also empty.
-
- BOOST_REQUIRE_EQUAL((x == y), true);
- BOOST_REQUIRE_EQUAL((y == x), true);
-
- // No need to test what it does if the empty ranges are different "ranges"
- // because we are not forcing behavior for that.
-}
-
-/**
- * Test inequality operator.
- */
-BOOST_AUTO_TEST_CASE(RangeInequality)
-{
- // We will use the same three cases as the RangeEquality test.
- Range x(0.0, 2.0);
- Range y(3.0, 5.0);
-
- // Again, odd calls, but we want to force use of operator!=.
- BOOST_REQUIRE_EQUAL((x != y), true);
- BOOST_REQUIRE_EQUAL((y != x), true);
-
- y = Range(1.0, 3.0);
-
- BOOST_REQUIRE_EQUAL((x != y), true);
- BOOST_REQUIRE_EQUAL((y != x), true);
-
- y = Range(0.0, 2.0);
-
- BOOST_REQUIRE_EQUAL((x != y), false);
- BOOST_REQUIRE_EQUAL((y != x), false);
-
- x = Range(1.0, -1.0); // Empty.
- y = Range(1.0, -1.0); // Also empty.
-
- BOOST_REQUIRE_EQUAL((x != y), false);
- BOOST_REQUIRE_EQUAL((y != x), false);
-}
-
-/**
- * Test strict less-than operator.
- */
-BOOST_AUTO_TEST_CASE(RangeStrictLessThan)
-{
- // Three cases: non-overlapping, overlapping, and equivalent.
- Range x(0.0, 2.0);
- Range y(3.0, 5.0);
-
- BOOST_REQUIRE_EQUAL((x < y), true);
- BOOST_REQUIRE_EQUAL((y < x), false);
-
- y = Range(1.0, 3.0);
-
- BOOST_REQUIRE_EQUAL((x < y), false);
- BOOST_REQUIRE_EQUAL((y < x), false);
-
- y = Range(0.0, 2.0);
-
- BOOST_REQUIRE_EQUAL((x < y), false);
- BOOST_REQUIRE_EQUAL((y < x), false);
-}
-
-/**
- * Test strict greater-than operator.
- */
-BOOST_AUTO_TEST_CASE(RangeStrictGreaterThan)
-{
- // Three cases: non-overlapping, overlapping, and equivalent.
- Range x(0.0, 2.0);
- Range y(3.0, 5.0);
-
- BOOST_REQUIRE_EQUAL((x > y), false);
- BOOST_REQUIRE_EQUAL((y > x), true);
-
- y = Range(1.0, 3.0);
-
- BOOST_REQUIRE_EQUAL((x > y), false);
- BOOST_REQUIRE_EQUAL((y > x), false);
-
- y = Range(0.0, 2.0);
-
- BOOST_REQUIRE_EQUAL((x > y), false);
- BOOST_REQUIRE_EQUAL((y > x), false);
-}
-
-/**
- * Test the Contains() operator.
- */
-BOOST_AUTO_TEST_CASE(RangeContains)
-{
- // We have three Range cases: strictly less than 0; overlapping 0; and
- // strictly greater than 0. Then the numbers we check can be the same three
- // cases, including one greater than and one less than the range. This should
- // be about 15 total cases.
- Range x(-2.0, -1.0);
-
- BOOST_REQUIRE(!x.Contains(-3.0));
- BOOST_REQUIRE(x.Contains(-2.0));
- BOOST_REQUIRE(x.Contains(-1.5));
- BOOST_REQUIRE(x.Contains(-1.0));
- BOOST_REQUIRE(!x.Contains(-0.5));
- BOOST_REQUIRE(!x.Contains(0.0));
- BOOST_REQUIRE(!x.Contains(1.0));
-
- x = Range(-1.0, 1.0);
-
- BOOST_REQUIRE(!x.Contains(-2.0));
- BOOST_REQUIRE(x.Contains(-1.0));
- BOOST_REQUIRE(x.Contains(0.0));
- BOOST_REQUIRE(x.Contains(1.0));
- BOOST_REQUIRE(!x.Contains(2.0));
-
- x = Range(1.0, 2.0);
-
- BOOST_REQUIRE(!x.Contains(-1.0));
- BOOST_REQUIRE(!x.Contains(0.0));
- BOOST_REQUIRE(!x.Contains(0.5));
- BOOST_REQUIRE(x.Contains(1.0));
- BOOST_REQUIRE(x.Contains(1.5));
- BOOST_REQUIRE(x.Contains(2.0));
- BOOST_REQUIRE(!x.Contains(2.5));
-
- // Now let's try it on an empty range.
- x = Range();
-
- BOOST_REQUIRE(!x.Contains(-10.0));
- BOOST_REQUIRE(!x.Contains(0.0));
- BOOST_REQUIRE(!x.Contains(10.0));
-
- // And an infinite range.
- x = Range(-DBL_MAX, DBL_MAX);
-
- BOOST_REQUIRE(x.Contains(-10.0));
- BOOST_REQUIRE(x.Contains(0.0));
- BOOST_REQUIRE(x.Contains(10.0));
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/math_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/math_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/math_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/math_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,540 @@
+/**
+ * @file math_test.cpp
+ * @author Ryan Curtin
+ *
+ * Tests for everything in the math:: namespace.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core/math/clamp.hpp>
+#include <mlpack/core/math/random.hpp>
+#include <mlpack/core/math/range.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace math;
+
+BOOST_AUTO_TEST_SUITE(MathTest);
+
+/**
+ * Verify that the empty constructor creates an empty range.
+ */
+BOOST_AUTO_TEST_CASE(RangeEmptyConstructor)
+{
+ Range x = Range();
+
+ // Just verify that it is empty.
+ BOOST_REQUIRE_GT(x.Lo(), x.Hi());
+}
+
+/**
+ * Verify that the point constructor correctly creates a range that is just a
+ * point.
+ */
+BOOST_AUTO_TEST_CASE(RangePointConstructor)
+{
+ Range x(10.0);
+
+ BOOST_REQUIRE_CLOSE(x.Lo(), x.Hi(), 1e-25);
+ BOOST_REQUIRE_SMALL(x.Width(), 1e-5);
+ BOOST_REQUIRE_CLOSE(x.Lo(), 10.0, 1e-25);
+ BOOST_REQUIRE_CLOSE(x.Hi(), 10.0, 1e-25);
+}
+
+/**
+ * Verify that the range constructor correctly creates the range.
+ */
+BOOST_AUTO_TEST_CASE(RangeConstructor)
+{
+ Range x(0.5, 5.5);
+
+ BOOST_REQUIRE_CLOSE(x.Lo(), 0.5, 1e-25);
+ BOOST_REQUIRE_CLOSE(x.Hi(), 5.5, 1e-25);
+}
+
+/**
+ * Test that we get the width correct.
+ */
+BOOST_AUTO_TEST_CASE(RangeWidth)
+{
+ Range x(0.0, 10.0);
+
+ BOOST_REQUIRE_CLOSE(x.Width(), 10.0, 1e-20);
+
+ // Make it empty.
+ x.Hi() = 0.0;
+
+ BOOST_REQUIRE_SMALL(x.Width(), 1e-5);
+
+ // Make it negative.
+ x.Hi() = -2.0;
+
+ BOOST_REQUIRE_SMALL(x.Width(), 1e-5);
+
+ // Just one more test.
+ x.Lo() = -5.2;
+ x.Hi() = 5.2;
+
+ BOOST_REQUIRE_CLOSE(x.Width(), 10.4, 1e-5);
+}
+
+/**
+ * Test that we get the midpoint correct.
+ */
+BOOST_AUTO_TEST_CASE(RangeMidpoint)
+{
+ Range x(0.0, 10.0);
+
+ BOOST_REQUIRE_CLOSE(x.Mid(), 5.0, 1e-5);
+
+ x.Lo() = -5.0;
+
+ BOOST_REQUIRE_CLOSE(x.Mid(), 2.5, 1e-5);
+}
+
+/**
+ * Test that we can expand to include other ranges correctly.
+ */
+BOOST_AUTO_TEST_CASE(RangeIncludeOther)
+{
+ // We need to test both |= and |.
+ // We have three cases: non-overlapping; overlapping; equivalent, and then a
+ // couple permutations (switch left with right and make sure it still works).
+ Range x(0.0, 2.0);
+ Range y(3.0, 5.0);
+
+ Range z(0.0, 2.0); // Used for operator|=().
+ Range w;
+ z |= y;
+ w = x | y;
+
+ BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 5.0, 1e-5);
+ BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 5.0, 1e-5);
+
+ // Switch operator precedence.
+ z = y;
+ z |= x;
+ w = y | x;
+
+ BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 5.0, 1e-5);
+ BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 5.0, 1e-5);
+
+ // Now make them overlapping.
+ x = Range(0.0, 3.5);
+ y = Range(3.0, 4.0);
+
+ z = x;
+ z |= y;
+ w = x | y;
+
+ BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 4.0, 1e-5);
+ BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 4.0, 1e-5);
+
+ // Switch operator precedence.
+ z = y;
+ z |= x;
+ w = y | x;
+
+ BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 4.0, 1e-5);
+ BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 4.0, 1e-5);
+
+ // Now the equivalent case.
+ x = Range(0.0, 2.0);
+ y = Range(0.0, 2.0);
+
+ z = x;
+ z |= y;
+ w = x | y;
+
+ BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 2.0, 1e-5);
+
+ z = y;
+ z |= x;
+ w = y | x;
+
+ BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 2.0, 1e-5);
+}
+
+/**
+ * Test that we can 'and' ranges correctly.
+ */
+BOOST_AUTO_TEST_CASE(RangeIntersectOther)
+{
+ // We need to test both &= and &.
+ // We have three cases: non-overlapping, overlapping; equivalent, and then a
+ // couple permutations (switch left with right and make sure it still works).
+ Range x(0.0, 2.0);
+ Range y(3.0, 5.0);
+
+ Range z(0.0, 2.0);
+ Range w;
+ z &= y;
+ w = x & y;
+
+ BOOST_REQUIRE_SMALL(z.Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(w.Width(), 1e-5);
+
+ // Reverse operator precedence.
+ z = y;
+ z &= x;
+ w = y & x;
+
+ BOOST_REQUIRE_SMALL(z.Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(w.Width(), 1e-5);
+
+ // Now make them overlapping.
+ x = Range(0.0, 3.5);
+ y = Range(3.0, 4.0);
+
+ z = x;
+ z &= y;
+ w = x & y;
+
+ BOOST_REQUIRE_CLOSE(z.Lo(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 3.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Lo(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 3.5, 1e-5);
+
+ // Reverse operator precedence.
+ z = y;
+ z &= x;
+ w = y & x;
+
+ BOOST_REQUIRE_CLOSE(z.Lo(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 3.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Lo(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 3.5, 1e-5);
+
+ // Now make them equivalent.
+ x = Range(2.0, 4.0);
+ y = Range(2.0, 4.0);
+
+ z = x;
+ z &= y;
+ w = x & y;
+
+ BOOST_REQUIRE_CLOSE(z.Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 4.0, 1e-5);
+}
+
+/**
+ * Test multiplication of a range with a double.
+ */
+BOOST_AUTO_TEST_CASE(RangeMultiply)
+{
+ // We need to test both * and *=, as well as both cases of *.
+ // We'll try with a couple of numbers: -1, 0, 2.
+ // And we'll have a couple of cases for bounds: strictly less than zero;
+ // including zero; and strictly greater than zero.
+ //
+ // So, nine total cases.
+ Range x(-5.0, -3.0);
+ Range y(-5.0, -3.0);
+ Range z;
+ Range w;
+
+ y *= -1.0;
+ z = x * -1.0;
+ w = -1.0 * x;
+
+ BOOST_REQUIRE_CLOSE(y.Lo(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(y.Hi(), 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Lo(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Lo(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 5.0, 1e-5);
+
+ y = x;
+ y *= 0.0;
+ z = x * 0.0;
+ w = 0.0 * x;
+
+ BOOST_REQUIRE_SMALL(y.Lo(), 1e-5);
+ BOOST_REQUIRE_SMALL(y.Hi(), 1e-5);
+ BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
+ BOOST_REQUIRE_SMALL(z.Hi(), 1e-5);
+ BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
+ BOOST_REQUIRE_SMALL(w.Hi(), 1e-5);
+
+ y = x;
+ y *= 2.0;
+ z = x * 2.0;
+ w = 2.0 * x;
+
+ BOOST_REQUIRE_CLOSE(y.Lo(), -10.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(y.Hi(), -6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Lo(), -10.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), -6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Lo(), -10.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), -6.0, 1e-5);
+
+ x = Range(-2.0, 2.0);
+ y = x;
+
+ y *= -1.0;
+ z = x * -1.0;
+ w = -1.0 * x;
+
+ BOOST_REQUIRE_CLOSE(y.Lo(), -2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(y.Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Lo(), -2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Lo(), -2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 2.0, 1e-5);
+
+ y = x;
+ y *= 0.0;
+ z = x * 0.0;
+ w = 0.0 * x;
+
+ BOOST_REQUIRE_SMALL(y.Lo(), 1e-5);
+ BOOST_REQUIRE_SMALL(y.Hi(), 1e-5);
+ BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
+ BOOST_REQUIRE_SMALL(z.Hi(), 1e-5);
+ BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
+ BOOST_REQUIRE_SMALL(w.Hi(), 1e-5);
+
+ y = x;
+ y *= 2.0;
+ z = x * 2.0;
+ w = 2.0 * x;
+
+ BOOST_REQUIRE_CLOSE(y.Lo(), -4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(y.Hi(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Lo(), -4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Lo(), -4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 4.0, 1e-5);
+
+ x = Range(3.0, 5.0);
+
+ y = x;
+ y *= -1.0;
+ z = x * -1.0;
+ w = -1.0 * x;
+
+ BOOST_REQUIRE_CLOSE(y.Lo(), -5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(y.Hi(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Lo(), -5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Lo(), -5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), -3.0, 1e-5);
+
+ y = x;
+ y *= 0.0;
+ z = x * 0.0;
+ w = 0.0 * x;
+
+ BOOST_REQUIRE_SMALL(y.Lo(), 1e-5);
+ BOOST_REQUIRE_SMALL(y.Hi(), 1e-5);
+ BOOST_REQUIRE_SMALL(z.Lo(), 1e-5);
+ BOOST_REQUIRE_SMALL(z.Hi(), 1e-5);
+ BOOST_REQUIRE_SMALL(w.Lo(), 1e-5);
+ BOOST_REQUIRE_SMALL(w.Hi(), 1e-5);
+
+ y = x;
+ y *= 2.0;
+ z = x * 2.0;
+ w = 2.0 * x;
+
+ BOOST_REQUIRE_CLOSE(y.Lo(), 6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(y.Hi(), 10.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Lo(), 6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(z.Hi(), 10.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Lo(), 6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(w.Hi(), 10.0, 1e-5);
+}
+
+/**
+ * Test equality operator.
+ */
+BOOST_AUTO_TEST_CASE(RangeEquality)
+{
+ // Three cases: non-overlapping, overlapping, equivalent. We should also
+ // consider empty ranges, which are not necessarily equal...
+ Range x(0.0, 2.0);
+ Range y(3.0, 5.0);
+
+ // These are odd calls, but we don't want to use operator!= here.
+ BOOST_REQUIRE_EQUAL((x == y), false);
+ BOOST_REQUIRE_EQUAL((y == x), false);
+
+ y = Range(1.0, 3.0);
+
+ BOOST_REQUIRE_EQUAL((x == y), false);
+ BOOST_REQUIRE_EQUAL((y == x), false);
+
+ y = Range(0.0, 2.0);
+
+ BOOST_REQUIRE_EQUAL((x == y), true);
+ BOOST_REQUIRE_EQUAL((y == x), true);
+
+ x = Range(1.0, -1.0); // Empty.
+ y = Range(1.0, -1.0); // Also empty.
+
+ BOOST_REQUIRE_EQUAL((x == y), true);
+ BOOST_REQUIRE_EQUAL((y == x), true);
+
+ // No need to test what it does if the empty ranges are different "ranges"
+ // because we are not forcing behavior for that.
+}
+
+/**
+ * Test inequality operator.
+ */
+BOOST_AUTO_TEST_CASE(RangeInequality)
+{
+ // We will use the same three cases as the RangeEquality test.
+ Range x(0.0, 2.0);
+ Range y(3.0, 5.0);
+
+ // Again, odd calls, but we want to force use of operator!=.
+ BOOST_REQUIRE_EQUAL((x != y), true);
+ BOOST_REQUIRE_EQUAL((y != x), true);
+
+ y = Range(1.0, 3.0);
+
+ BOOST_REQUIRE_EQUAL((x != y), true);
+ BOOST_REQUIRE_EQUAL((y != x), true);
+
+ y = Range(0.0, 2.0);
+
+ BOOST_REQUIRE_EQUAL((x != y), false);
+ BOOST_REQUIRE_EQUAL((y != x), false);
+
+ x = Range(1.0, -1.0); // Empty.
+ y = Range(1.0, -1.0); // Also empty.
+
+ BOOST_REQUIRE_EQUAL((x != y), false);
+ BOOST_REQUIRE_EQUAL((y != x), false);
+}
+
+/**
+ * Test strict less-than operator.
+ */
+BOOST_AUTO_TEST_CASE(RangeStrictLessThan)
+{
+ // Three cases: non-overlapping, overlapping, and equivalent.
+ Range x(0.0, 2.0);
+ Range y(3.0, 5.0);
+
+ BOOST_REQUIRE_EQUAL((x < y), true);
+ BOOST_REQUIRE_EQUAL((y < x), false);
+
+ y = Range(1.0, 3.0);
+
+ BOOST_REQUIRE_EQUAL((x < y), false);
+ BOOST_REQUIRE_EQUAL((y < x), false);
+
+ y = Range(0.0, 2.0);
+
+ BOOST_REQUIRE_EQUAL((x < y), false);
+ BOOST_REQUIRE_EQUAL((y < x), false);
+}
+
+/**
+ * Test strict greater-than operator.
+ */
+BOOST_AUTO_TEST_CASE(RangeStrictGreaterThan)
+{
+ // Three cases: non-overlapping, overlapping, and equivalent.
+ Range x(0.0, 2.0);
+ Range y(3.0, 5.0);
+
+ BOOST_REQUIRE_EQUAL((x > y), false);
+ BOOST_REQUIRE_EQUAL((y > x), true);
+
+ y = Range(1.0, 3.0);
+
+ BOOST_REQUIRE_EQUAL((x > y), false);
+ BOOST_REQUIRE_EQUAL((y > x), false);
+
+ y = Range(0.0, 2.0);
+
+ BOOST_REQUIRE_EQUAL((x > y), false);
+ BOOST_REQUIRE_EQUAL((y > x), false);
+}
+
+/**
+ * Test the Contains() operator.
+ */
+BOOST_AUTO_TEST_CASE(RangeContains)
+{
+ // We have three Range cases: strictly less than 0; overlapping 0; and
+ // strictly greater than 0. Then the numbers we check can be the same three
+ // cases, including one greater than and one less than the range. This should
+ // be about 15 total cases.
+ Range x(-2.0, -1.0);
+
+ BOOST_REQUIRE(!x.Contains(-3.0));
+ BOOST_REQUIRE(x.Contains(-2.0));
+ BOOST_REQUIRE(x.Contains(-1.5));
+ BOOST_REQUIRE(x.Contains(-1.0));
+ BOOST_REQUIRE(!x.Contains(-0.5));
+ BOOST_REQUIRE(!x.Contains(0.0));
+ BOOST_REQUIRE(!x.Contains(1.0));
+
+ x = Range(-1.0, 1.0);
+
+ BOOST_REQUIRE(!x.Contains(-2.0));
+ BOOST_REQUIRE(x.Contains(-1.0));
+ BOOST_REQUIRE(x.Contains(0.0));
+ BOOST_REQUIRE(x.Contains(1.0));
+ BOOST_REQUIRE(!x.Contains(2.0));
+
+ x = Range(1.0, 2.0);
+
+ BOOST_REQUIRE(!x.Contains(-1.0));
+ BOOST_REQUIRE(!x.Contains(0.0));
+ BOOST_REQUIRE(!x.Contains(0.5));
+ BOOST_REQUIRE(x.Contains(1.0));
+ BOOST_REQUIRE(x.Contains(1.5));
+ BOOST_REQUIRE(x.Contains(2.0));
+ BOOST_REQUIRE(!x.Contains(2.5));
+
+ // Now let's try it on an empty range.
+ x = Range();
+
+ BOOST_REQUIRE(!x.Contains(-10.0));
+ BOOST_REQUIRE(!x.Contains(0.0));
+ BOOST_REQUIRE(!x.Contains(10.0));
+
+ // And an infinite range.
+ x = Range(-DBL_MAX, DBL_MAX);
+
+ BOOST_REQUIRE(x.Contains(-10.0));
+ BOOST_REQUIRE(x.Contains(0.0));
+ BOOST_REQUIRE(x.Contains(10.0));
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/mlpack_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/mlpack_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/mlpack_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,32 +0,0 @@
-/**
- * @file mlpack_test.cpp
- *
- * Simple file defining the name of the overall test for MLPACK. Each
- * individual test is contained in its own file.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#define BOOST_TEST_MODULE MLPACKTest
-
-#include <boost/version.hpp>
-
-// We only need to do this for old Boost versions.
-#if BOOST_VERSION < 103600
- #define BOOST_AUTO_TEST_MAIN
-#endif
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/mlpack_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/mlpack_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/mlpack_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/mlpack_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,32 @@
+/**
+ * @file mlpack_test.cpp
+ *
+ * Simple file defining the name of the overall test for MLPACK. Each
+ * individual test is contained in its own file.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#define BOOST_TEST_MODULE MLPACKTest
+
+#include <boost/version.hpp>
+
+// We only need to do this for old Boost versions.
+#if BOOST_VERSION < 103600
+ #define BOOST_AUTO_TEST_MAIN
+#endif
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nbc_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/nbc_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nbc_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,79 +0,0 @@
-/**
- * @file nbc_test.cpp
- *
- * Test for the Naive Bayes classifier.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace naive_bayes;
-
-BOOST_AUTO_TEST_SUITE(NBCTest);
-
-BOOST_AUTO_TEST_CASE(NaiveBayesClassifierTest)
-{
- const char* trainFilename = "trainSet.csv";
- const char* testFilename = "testSet.csv";
- const char* trainResultFilename = "trainRes.csv";
- const char* testResultFilename = "testRes.csv";
- size_t classes = 2;
-
- arma::mat trainData, trainRes, calcMat;
- data::Load(trainFilename, trainData, true);
- data::Load(trainResultFilename, trainRes, true);
-
- NaiveBayesClassifier<> nbcTest(trainData, classes);
-
- size_t dimension = nbcTest.Means().n_rows;
- calcMat.zeros(2 * dimension + 1, classes);
-
- for (size_t i = 0; i < dimension; i++)
- {
- for (size_t j = 0; j < classes; j++)
- {
- calcMat(i, j) = nbcTest.Means()(i, j);
- calcMat(i + dimension, j) = nbcTest.Variances()(i, j);
- }
- }
-
- for (size_t i = 0; i < classes; i++)
- calcMat(2 * dimension, i) = nbcTest.Probabilities()(i);
-
-// for(size_t i = 0; i < calcMat.n_rows; i++)
-// for(size_t j = 0; j < classes; j++)
-// BOOST_REQUIRE_CLOSE(trainRes(i, j) + .00001, calcMat(i, j), 0.01);
-
- arma::mat testData;
- arma::Mat<size_t> testRes;
- arma::Col<size_t> calcVec;
- data::Load(testFilename, testData, true);
- data::Load(testResultFilename, testRes, true);
-
- testData.shed_row(testData.n_rows - 1); // Remove the labels.
-
- nbcTest.Classify(testData, calcVec);
-
-// for(size_t i = 0; i < testData.n_cols; i++)
-// BOOST_REQUIRE_EQUAL(testRes(i), calcVec(i));
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nbc_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/nbc_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nbc_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nbc_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,79 @@
+/**
+ * @file nbc_test.cpp
+ *
+ * Test for the Naive Bayes classifier.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace naive_bayes;
+
+BOOST_AUTO_TEST_SUITE(NBCTest);
+
+BOOST_AUTO_TEST_CASE(NaiveBayesClassifierTest)
+{
+ const char* trainFilename = "trainSet.csv";
+ const char* testFilename = "testSet.csv";
+ const char* trainResultFilename = "trainRes.csv";
+ const char* testResultFilename = "testRes.csv";
+ size_t classes = 2;
+
+ arma::mat trainData, trainRes, calcMat;
+ data::Load(trainFilename, trainData, true);
+ data::Load(trainResultFilename, trainRes, true);
+
+ NaiveBayesClassifier<> nbcTest(trainData, classes);
+
+ size_t dimension = nbcTest.Means().n_rows;
+ calcMat.zeros(2 * dimension + 1, classes);
+
+ for (size_t i = 0; i < dimension; i++)
+ {
+ for (size_t j = 0; j < classes; j++)
+ {
+ calcMat(i, j) = nbcTest.Means()(i, j);
+ calcMat(i + dimension, j) = nbcTest.Variances()(i, j);
+ }
+ }
+
+ for (size_t i = 0; i < classes; i++)
+ calcMat(2 * dimension, i) = nbcTest.Probabilities()(i);
+
+// for(size_t i = 0; i < calcMat.n_rows; i++)
+// for(size_t j = 0; j < classes; j++)
+// BOOST_REQUIRE_CLOSE(trainRes(i, j) + .00001, calcMat(i, j), 0.01);
+
+ arma::mat testData;
+ arma::Mat<size_t> testRes;
+ arma::Col<size_t> calcVec;
+ data::Load(testFilename, testData, true);
+ data::Load(testResultFilename, testRes, true);
+
+ testData.shed_row(testData.n_rows - 1); // Remove the labels.
+
+ nbcTest.Classify(testData, calcVec);
+
+// for(size_t i = 0; i < testData.n_cols; i++)
+// BOOST_REQUIRE_EQUAL(testRes(i), calcVec(i));
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nca_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/nca_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nca_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,336 +0,0 @@
-/**
- * @file nca_test.cpp
- * @author Ryan Curtin
- *
- * Unit tests for Neighborhood Components Analysis and related code (including
- * the softmax error function).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-#include <mlpack/methods/nca/nca.hpp>
-#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::metric;
-using namespace mlpack::nca;
-using namespace mlpack::optimization;
-
-//
-// Tests for the SoftmaxErrorFunction
-//
-
-BOOST_AUTO_TEST_SUITE(NCATest);
-
-/**
- * The Softmax error function should return the identity matrix as its initial
- * point.
- */
-BOOST_AUTO_TEST_CASE(SoftmaxInitialPoint)
-{
- // Cheap fake dataset.
- arma::mat data;
- data.randu(5, 5);
- arma::uvec labels;
- labels.zeros(5);
-
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- // Verify the initial point is the identity matrix.
- arma::mat initialPoint = sef.GetInitialPoint();
- for (int row = 0; row < 5; row++)
- {
- for (int col = 0; col < 5; col++)
- {
- if (row == col)
- BOOST_REQUIRE_CLOSE(initialPoint(row, col), 1.0, 1e-5);
- else
- BOOST_REQUIRE_SMALL(initialPoint(row, col), 1e-5);
- }
- }
-}
-
-/***
- * On a simple fake dataset, ensure that the initial function evaluation is
- * correct.
- */
-BOOST_AUTO_TEST_CASE(SoftmaxInitialEvaluation)
-{
- // Useful but simple dataset with six points and two classes.
- arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
- " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
- arma::uvec labels = " 0 0 0 1 1 1 ";
-
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- double objective = sef.Evaluate(arma::eye<arma::mat>(2, 2));
-
- // Result painstakingly calculated by hand by rcurtin (recorded forever in his
- // notebook). As a result of lack of precision of the by-hand result, the
- // tolerance is fairly high.
- BOOST_REQUIRE_CLOSE(objective, -1.5115, 0.01);
-}
-
-/**
- * On a simple fake dataset, ensure that the initial gradient evaluation is
- * correct.
- */
-BOOST_AUTO_TEST_CASE(SoftmaxInitialGradient)
-{
- // Useful but simple dataset with six points and two classes.
- arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
- " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
- arma::uvec labels = " 0 0 0 1 1 1 ";
-
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- arma::mat gradient;
- arma::mat coordinates = arma::eye<arma::mat>(2, 2);
- sef.Gradient(coordinates, gradient);
-
- // Results painstakingly calculated by hand by rcurtin (recorded forever in
- // his notebook). As a result of lack of precision of the by-hand result, the
- // tolerance is fairly high.
- BOOST_REQUIRE_CLOSE(gradient(0, 0), -0.089766, 0.05);
- BOOST_REQUIRE_SMALL(gradient(1, 0), 1e-5);
- BOOST_REQUIRE_SMALL(gradient(0, 1), 1e-5);
- BOOST_REQUIRE_CLOSE(gradient(1, 1), 1.63823, 0.01);
-}
-
-/**
- * On optimally separated datasets, ensure that the objective function is
- * optimal (equal to the negative number of points).
- */
-BOOST_AUTO_TEST_CASE(SoftmaxOptimalEvaluation)
-{
- // Simple optimal dataset.
- arma::mat data = " 500 500 -500 -500;"
- " 1 0 1 0 ";
- arma::uvec labels = " 0 0 1 1 ";
-
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- double objective = sef.Evaluate(arma::eye<arma::mat>(2, 2));
-
- // Use a very close tolerance for optimality; we need to be sure this function
- // gives optimal results correctly.
- BOOST_REQUIRE_CLOSE(objective, -4.0, 1e-10);
-}
-
-/**
- * On optimally separated datasets, ensure that the gradient is zero.
- */
-BOOST_AUTO_TEST_CASE(SoftmaxOptimalGradient)
-{
- // Simple optimal dataset.
- arma::mat data = " 500 500 -500 -500;"
- " 1 0 1 0 ";
- arma::uvec labels = " 0 0 1 1 ";
-
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- arma::mat gradient;
- sef.Gradient(arma::eye<arma::mat>(2, 2), gradient);
-
- BOOST_REQUIRE_SMALL(gradient(0, 0), 1e-5);
- BOOST_REQUIRE_SMALL(gradient(0, 1), 1e-5);
- BOOST_REQUIRE_SMALL(gradient(1, 0), 1e-5);
- BOOST_REQUIRE_SMALL(gradient(1, 1), 1e-5);
-}
-
-/**
- * Ensure the separable objective function is right.
- */
-BOOST_AUTO_TEST_CASE(SoftmaxSeparableObjective)
-{
- // Useful but simple dataset with six points and two classes.
- arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
- " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
- arma::uvec labels = " 0 0 0 1 1 1 ";
-
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- // Results painstakingly calculated by hand by rcurtin (recorded forever in
- // his notebook). As a result of lack of precision of the by-hand result, the
- // tolerance is fairly high.
- arma::mat coordinates = arma::eye<arma::mat>(2, 2);
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 0), -0.22480, 0.01);
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 1), -0.30613, 0.01);
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 2), -0.22480, 0.01);
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 3), -0.22480, 0.01);
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 4), -0.30613, 0.01);
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 5), -0.22480, 0.01);
-}
-
-/**
- * Ensure the optimal separable objective function is right.
- */
-BOOST_AUTO_TEST_CASE(OptimalSoftmaxSeparableObjective)
-{
- // Simple optimal dataset.
- arma::mat data = " 500 500 -500 -500;"
- " 1 0 1 0 ";
- arma::uvec labels = " 0 0 1 1 ";
-
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- arma::mat coordinates = arma::eye<arma::mat>(2, 2);
-
- // Use a very close tolerance for optimality; we need to be sure this function
- // gives optimal results correctly.
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 0), -1.0, 1e-10);
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 1), -1.0, 1e-10);
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 2), -1.0, 1e-10);
- BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 3), -1.0, 1e-10);
-}
-
-/**
- * Ensure the separable gradient is right.
- */
-BOOST_AUTO_TEST_CASE(SoftmaxSeparableGradient)
-{
- // Useful but simple dataset with six points and two classes.
- arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
- " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
- arma::uvec labels = " 0 0 0 1 1 1 ";
-
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- arma::mat coordinates = arma::eye<arma::mat>(2, 2);
- arma::mat gradient(2, 2);
-
- sef.Gradient(coordinates, 0, gradient);
-
- BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.0069708, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(0, 1), -2.0 * -0.0101707, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(1, 0), -2.0 * -0.0101707, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.14359, 0.01);
-
- sef.Gradient(coordinates, 1, gradient);
-
- BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.008496, 0.01);
- BOOST_REQUIRE_SMALL(gradient(0, 1), 1e-5);
- BOOST_REQUIRE_SMALL(gradient(1, 0), 1e-5);
- BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.12238, 0.01);
-
- sef.Gradient(coordinates, 2, gradient);
-
- BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.0069708, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(0, 1), -2.0 * 0.0101707, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(1, 0), -2.0 * 0.0101707, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.1435886, 0.01);
-
- sef.Gradient(coordinates, 3, gradient);
-
- BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.0069708, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(0, 1), -2.0 * 0.0101707, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(1, 0), -2.0 * 0.0101707, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.1435886, 0.01);
-
- sef.Gradient(coordinates, 4, gradient);
-
- BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.008496, 0.01);
- BOOST_REQUIRE_SMALL(gradient(0, 1), 1e-5);
- BOOST_REQUIRE_SMALL(gradient(1, 0), 1e-5);
- BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.12238, 0.01);
-
- sef.Gradient(coordinates, 5, gradient);
-
- BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.0069708, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(0, 1), -2.0 * -0.0101707, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(1, 0), -2.0 * -0.0101707, 0.01);
- BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.1435886, 0.01);
-}
-
-//
-// Tests for the NCA algorithm.
-//
-
-/**
- * On our simple dataset, ensure that the NCA algorithm fully separates the
- * points.
- */
-BOOST_AUTO_TEST_CASE(NCASGDSimpleDataset)
-{
- // Useful but simple dataset with six points and two classes.
- arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
- " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
- arma::uvec labels = " 0 0 0 1 1 1 ";
-
- // Huge learning rate because this is so simple.
- NCA<SquaredEuclideanDistance> nca(data, labels);
- nca.Optimizer().StepSize() = 1.2;
- nca.Optimizer().MaxIterations() = 300000;
- nca.Optimizer().Tolerance() = 0;
- nca.Optimizer().Shuffle() = true;
-
- arma::mat outputMatrix;
- nca.LearnDistance(outputMatrix);
-
- // Ensure that the objective function is better now.
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- double initObj = sef.Evaluate(arma::eye<arma::mat>(2, 2));
- double finalObj = sef.Evaluate(outputMatrix);
- arma::mat finalGradient;
- sef.Gradient(outputMatrix, finalGradient);
-
- // finalObj must be less than initObj.
- BOOST_REQUIRE_LT(finalObj, initObj);
- // Verify that final objective is optimal.
- BOOST_REQUIRE_CLOSE(finalObj, -6.0, 0.005);
- // The solution is not unique, so the best we can do is ensure the gradient
- // norm is close to 0.
- BOOST_REQUIRE_LT(arma::norm(finalGradient, 2), 1e-4);
-}
-
-BOOST_AUTO_TEST_CASE(NCALBFGSSimpleDataset)
-{
- // Useful but simple dataset with six points and two classes.
- arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
- " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
- arma::uvec labels = " 0 0 0 1 1 1 ";
-
- // Huge learning rate because this is so simple.
- NCA<SquaredEuclideanDistance, L_BFGS> nca(data, labels);
- nca.Optimizer().NumBasis() = 5;
-
- arma::mat outputMatrix;
- nca.LearnDistance(outputMatrix);
-
- // Ensure that the objective function is better now.
- SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
-
- double initObj = sef.Evaluate(arma::eye<arma::mat>(2, 2));
- double finalObj = sef.Evaluate(outputMatrix);
- arma::mat finalGradient;
- sef.Gradient(outputMatrix, finalGradient);
-
- // finalObj must be less than initObj.
- BOOST_REQUIRE_LT(finalObj, initObj);
- // Verify that final objective is optimal.
- BOOST_REQUIRE_CLOSE(finalObj, -6.0, 1e-5);
- // The solution is not unique, so the best we can do is ensure the gradient
- // norm is close to 0.
- BOOST_REQUIRE_LT(arma::norm(finalGradient, 2), 1e-6);
-
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nca_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/nca_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nca_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nca_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,336 @@
+/**
+ * @file nca_test.cpp
+ * @author Ryan Curtin
+ *
+ * Unit tests for Neighborhood Components Analysis and related code (including
+ * the softmax error function).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/methods/nca/nca.hpp>
+#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::metric;
+using namespace mlpack::nca;
+using namespace mlpack::optimization;
+
+//
+// Tests for the SoftmaxErrorFunction
+//
+
+BOOST_AUTO_TEST_SUITE(NCATest);
+
+/**
+ * The Softmax error function should return the identity matrix as its initial
+ * point.
+ */
+BOOST_AUTO_TEST_CASE(SoftmaxInitialPoint)
+{
+ // Cheap fake dataset.
+ arma::mat data;
+ data.randu(5, 5);
+ arma::uvec labels;
+ labels.zeros(5);
+
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ // Verify the initial point is the identity matrix.
+ arma::mat initialPoint = sef.GetInitialPoint();
+ for (int row = 0; row < 5; row++)
+ {
+ for (int col = 0; col < 5; col++)
+ {
+ if (row == col)
+ BOOST_REQUIRE_CLOSE(initialPoint(row, col), 1.0, 1e-5);
+ else
+ BOOST_REQUIRE_SMALL(initialPoint(row, col), 1e-5);
+ }
+ }
+}
+
+/***
+ * On a simple fake dataset, ensure that the initial function evaluation is
+ * correct.
+ */
+BOOST_AUTO_TEST_CASE(SoftmaxInitialEvaluation)
+{
+ // Useful but simple dataset with six points and two classes.
+ arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
+ " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
+ arma::uvec labels = " 0 0 0 1 1 1 ";
+
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ double objective = sef.Evaluate(arma::eye<arma::mat>(2, 2));
+
+ // Result painstakingly calculated by hand by rcurtin (recorded forever in his
+ // notebook). As a result of lack of precision of the by-hand result, the
+ // tolerance is fairly high.
+ BOOST_REQUIRE_CLOSE(objective, -1.5115, 0.01);
+}
+
+/**
+ * On a simple fake dataset, ensure that the initial gradient evaluation is
+ * correct.
+ */
+BOOST_AUTO_TEST_CASE(SoftmaxInitialGradient)
+{
+ // Useful but simple dataset with six points and two classes.
+ arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
+ " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
+ arma::uvec labels = " 0 0 0 1 1 1 ";
+
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ arma::mat gradient;
+ arma::mat coordinates = arma::eye<arma::mat>(2, 2);
+ sef.Gradient(coordinates, gradient);
+
+ // Results painstakingly calculated by hand by rcurtin (recorded forever in
+ // his notebook). As a result of lack of precision of the by-hand result, the
+ // tolerance is fairly high.
+ BOOST_REQUIRE_CLOSE(gradient(0, 0), -0.089766, 0.05);
+ BOOST_REQUIRE_SMALL(gradient(1, 0), 1e-5);
+ BOOST_REQUIRE_SMALL(gradient(0, 1), 1e-5);
+ BOOST_REQUIRE_CLOSE(gradient(1, 1), 1.63823, 0.01);
+}
+
+/**
+ * On optimally separated datasets, ensure that the objective function is
+ * optimal (equal to the negative number of points).
+ */
+BOOST_AUTO_TEST_CASE(SoftmaxOptimalEvaluation)
+{
+ // Simple optimal dataset.
+ arma::mat data = " 500 500 -500 -500;"
+ " 1 0 1 0 ";
+ arma::uvec labels = " 0 0 1 1 ";
+
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ double objective = sef.Evaluate(arma::eye<arma::mat>(2, 2));
+
+ // Use a very close tolerance for optimality; we need to be sure this function
+ // gives optimal results correctly.
+ BOOST_REQUIRE_CLOSE(objective, -4.0, 1e-10);
+}
+
+/**
+ * On optimally separated datasets, ensure that the gradient is zero.
+ */
+BOOST_AUTO_TEST_CASE(SoftmaxOptimalGradient)
+{
+ // Simple optimal dataset.
+ arma::mat data = " 500 500 -500 -500;"
+ " 1 0 1 0 ";
+ arma::uvec labels = " 0 0 1 1 ";
+
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ arma::mat gradient;
+ sef.Gradient(arma::eye<arma::mat>(2, 2), gradient);
+
+ BOOST_REQUIRE_SMALL(gradient(0, 0), 1e-5);
+ BOOST_REQUIRE_SMALL(gradient(0, 1), 1e-5);
+ BOOST_REQUIRE_SMALL(gradient(1, 0), 1e-5);
+ BOOST_REQUIRE_SMALL(gradient(1, 1), 1e-5);
+}
+
+/**
+ * Ensure the separable objective function is right.
+ */
+BOOST_AUTO_TEST_CASE(SoftmaxSeparableObjective)
+{
+ // Useful but simple dataset with six points and two classes.
+ arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
+ " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
+ arma::uvec labels = " 0 0 0 1 1 1 ";
+
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ // Results painstakingly calculated by hand by rcurtin (recorded forever in
+ // his notebook). As a result of lack of precision of the by-hand result, the
+ // tolerance is fairly high.
+ arma::mat coordinates = arma::eye<arma::mat>(2, 2);
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 0), -0.22480, 0.01);
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 1), -0.30613, 0.01);
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 2), -0.22480, 0.01);
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 3), -0.22480, 0.01);
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 4), -0.30613, 0.01);
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 5), -0.22480, 0.01);
+}
+
+/**
+ * Ensure the optimal separable objective function is right.
+ */
+BOOST_AUTO_TEST_CASE(OptimalSoftmaxSeparableObjective)
+{
+ // Simple optimal dataset.
+ arma::mat data = " 500 500 -500 -500;"
+ " 1 0 1 0 ";
+ arma::uvec labels = " 0 0 1 1 ";
+
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ arma::mat coordinates = arma::eye<arma::mat>(2, 2);
+
+ // Use a very close tolerance for optimality; we need to be sure this function
+ // gives optimal results correctly.
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 0), -1.0, 1e-10);
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 1), -1.0, 1e-10);
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 2), -1.0, 1e-10);
+ BOOST_REQUIRE_CLOSE(sef.Evaluate(coordinates, 3), -1.0, 1e-10);
+}
+
+/**
+ * Ensure the separable gradient is right.
+ */
+BOOST_AUTO_TEST_CASE(SoftmaxSeparableGradient)
+{
+ // Useful but simple dataset with six points and two classes.
+ arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
+ " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
+ arma::uvec labels = " 0 0 0 1 1 1 ";
+
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ arma::mat coordinates = arma::eye<arma::mat>(2, 2);
+ arma::mat gradient(2, 2);
+
+ sef.Gradient(coordinates, 0, gradient);
+
+ BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.0069708, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(0, 1), -2.0 * -0.0101707, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(1, 0), -2.0 * -0.0101707, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.14359, 0.01);
+
+ sef.Gradient(coordinates, 1, gradient);
+
+ BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.008496, 0.01);
+ BOOST_REQUIRE_SMALL(gradient(0, 1), 1e-5);
+ BOOST_REQUIRE_SMALL(gradient(1, 0), 1e-5);
+ BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.12238, 0.01);
+
+ sef.Gradient(coordinates, 2, gradient);
+
+ BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.0069708, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(0, 1), -2.0 * 0.0101707, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(1, 0), -2.0 * 0.0101707, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.1435886, 0.01);
+
+ sef.Gradient(coordinates, 3, gradient);
+
+ BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.0069708, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(0, 1), -2.0 * 0.0101707, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(1, 0), -2.0 * 0.0101707, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.1435886, 0.01);
+
+ sef.Gradient(coordinates, 4, gradient);
+
+ BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.008496, 0.01);
+ BOOST_REQUIRE_SMALL(gradient(0, 1), 1e-5);
+ BOOST_REQUIRE_SMALL(gradient(1, 0), 1e-5);
+ BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.12238, 0.01);
+
+ sef.Gradient(coordinates, 5, gradient);
+
+ BOOST_REQUIRE_CLOSE(gradient(0, 0), -2.0 * 0.0069708, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(0, 1), -2.0 * -0.0101707, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(1, 0), -2.0 * -0.0101707, 0.01);
+ BOOST_REQUIRE_CLOSE(gradient(1, 1), -2.0 * -0.1435886, 0.01);
+}
+
+//
+// Tests for the NCA algorithm.
+//
+
+/**
+ * On our simple dataset, ensure that the NCA algorithm fully separates the
+ * points.
+ */
+BOOST_AUTO_TEST_CASE(NCASGDSimpleDataset)
+{
+ // Useful but simple dataset with six points and two classes.
+ arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
+ " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
+ arma::uvec labels = " 0 0 0 1 1 1 ";
+
+ // Huge learning rate because this is so simple.
+ NCA<SquaredEuclideanDistance> nca(data, labels);
+ nca.Optimizer().StepSize() = 1.2;
+ nca.Optimizer().MaxIterations() = 300000;
+ nca.Optimizer().Tolerance() = 0;
+ nca.Optimizer().Shuffle() = true;
+
+ arma::mat outputMatrix;
+ nca.LearnDistance(outputMatrix);
+
+ // Ensure that the objective function is better now.
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ double initObj = sef.Evaluate(arma::eye<arma::mat>(2, 2));
+ double finalObj = sef.Evaluate(outputMatrix);
+ arma::mat finalGradient;
+ sef.Gradient(outputMatrix, finalGradient);
+
+ // finalObj must be less than initObj.
+ BOOST_REQUIRE_LT(finalObj, initObj);
+ // Verify that final objective is optimal.
+ BOOST_REQUIRE_CLOSE(finalObj, -6.0, 0.005);
+ // The solution is not unique, so the best we can do is ensure the gradient
+ // norm is close to 0.
+ BOOST_REQUIRE_LT(arma::norm(finalGradient, 2), 1e-4);
+}
+
+BOOST_AUTO_TEST_CASE(NCALBFGSSimpleDataset)
+{
+ // Useful but simple dataset with six points and two classes.
+ arma::mat data = "-0.1 -0.1 -0.1 0.1 0.1 0.1;"
+ " 1.0 0.0 -1.0 1.0 0.0 -1.0 ";
+ arma::uvec labels = " 0 0 0 1 1 1 ";
+
+ // Huge learning rate because this is so simple.
+ NCA<SquaredEuclideanDistance, L_BFGS> nca(data, labels);
+ nca.Optimizer().NumBasis() = 5;
+
+ arma::mat outputMatrix;
+ nca.LearnDistance(outputMatrix);
+
+ // Ensure that the objective function is better now.
+ SoftmaxErrorFunction<SquaredEuclideanDistance> sef(data, labels);
+
+ double initObj = sef.Evaluate(arma::eye<arma::mat>(2, 2));
+ double finalObj = sef.Evaluate(outputMatrix);
+ arma::mat finalGradient;
+ sef.Gradient(outputMatrix, finalGradient);
+
+ // finalObj must be less than initObj.
+ BOOST_REQUIRE_LT(finalObj, initObj);
+ // Verify that final objective is optimal.
+ BOOST_REQUIRE_CLOSE(finalObj, -6.0, 1e-5);
+ // The solution is not unique, so the best we can do is ensure the gradient
+ // norm is close to 0.
+ BOOST_REQUIRE_LT(arma::norm(finalGradient, 2), 1e-6);
+
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nmf_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/nmf_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nmf_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,127 +0,0 @@
-/**
- * @file nmf_test.cpp
- * @author Mohan Rajendran
- *
- * Test file for NMF class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/nmf/nmf.hpp>
-#include <mlpack/methods/nmf/random_acol_init.hpp>
-#include <mlpack/methods/nmf/mult_div_update_rules.hpp>
-#include <mlpack/methods/nmf/als_update_rules.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-BOOST_AUTO_TEST_SUITE(NMFTest);
-
-using namespace std;
-using namespace arma;
-using namespace mlpack;
-using namespace mlpack::nmf;
-
-/**
- * Check the if the product of the calculated factorization is close to the
- * input matrix. Default case
- */
-BOOST_AUTO_TEST_CASE(NMFDefaultTest)
-{
- mat w = randu<mat>(20, 16);
- mat h = randu<mat>(16, 20);
- mat v = w * h;
- size_t r = 16;
-
- NMF<> nmf;
- nmf.Apply(v, r, w, h);
-
- mat wh = w * h;
-
- for (size_t row = 0; row < 5; row++)
- for (size_t col = 0; col < 5; col++)
- BOOST_REQUIRE_CLOSE(v(row, col), wh(row, col), 10.0);
-}
-
-/**
- * Check the if the product of the calculated factorization is close to the
- * input matrix. Random Acol Initialization Distance Minimization Update
- */
-BOOST_AUTO_TEST_CASE(NMFAcolDistTest)
-{
- mat w = randu<mat>(20, 16);
- mat h = randu<mat>(16, 20);
- mat v = w * h;
- size_t r = 16;
-
- NMF<RandomAcolInitialization<> > nmf;
- nmf.Apply(v, r, w, h);
-
- mat wh = w * h;
-
- for (size_t row = 0; row < 5; row++)
- for (size_t col = 0; col < 5; col++)
- BOOST_REQUIRE_CLOSE(v(row, col), wh(row, col), 10.0);
-}
-
-/**
- * Check the if the product of the calculated factorization is close to the
- * input matrix. Random Initialization Divergence Minimization Update
- */
-BOOST_AUTO_TEST_CASE(NMFRandomDivTest)
-{
- mat w = randu<mat>(20, 16);
- mat h = randu<mat>(16, 20);
- mat v = w * h;
- size_t r = 16;
-
- NMF<RandomInitialization,
- WMultiplicativeDivergenceRule,
- HMultiplicativeDivergenceRule> nmf;
- nmf.Apply(v, r, w, h);
-
- mat wh = w * h;
-
- for (size_t row = 0; row < 5; row++)
- for (size_t col = 0; col < 5; col++)
- BOOST_REQUIRE_CLOSE(v(row, col), wh(row, col), 10.0);
-}
-
-/**
- * Check that the product of the calculated factorization is close to the
- * input matrix. This uses the random initialization and alternating least
- * squares update rule.
- */
-BOOST_AUTO_TEST_CASE(NMFALSTest)
-{
- mat w = randu<mat>(20, 16);
- mat h = randu<mat>(16, 20);
- mat v = w * h;
- size_t r = 16;
-
- NMF<RandomInitialization,
- WAlternatingLeastSquaresRule,
- HAlternatingLeastSquaresRule> nmf;
- nmf.Apply(v, r, w, h);
-
- mat wh = w * h;
-
- for (size_t row = 0; row < 5; row++)
- for (size_t col = 0; col < 5; col++)
- BOOST_REQUIRE_CLOSE(v(row, col), wh(row, col), 13.0);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nmf_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/nmf_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nmf_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/nmf_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,127 @@
+/**
+ * @file nmf_test.cpp
+ * @author Mohan Rajendran
+ *
+ * Test file for NMF class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/nmf/nmf.hpp>
+#include <mlpack/methods/nmf/random_acol_init.hpp>
+#include <mlpack/methods/nmf/mult_div_update_rules.hpp>
+#include <mlpack/methods/nmf/als_update_rules.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+BOOST_AUTO_TEST_SUITE(NMFTest);
+
+using namespace std;
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::nmf;
+
+/**
+ * Check if the product of the calculated factorization is close to the
+ * input matrix. Default case
+ */
+BOOST_AUTO_TEST_CASE(NMFDefaultTest)
+{
+ mat w = randu<mat>(20, 16);
+ mat h = randu<mat>(16, 20);
+ mat v = w * h;
+ size_t r = 16;
+
+ NMF<> nmf;
+ nmf.Apply(v, r, w, h);
+
+ mat wh = w * h;
+
+ for (size_t row = 0; row < 5; row++)
+ for (size_t col = 0; col < 5; col++)
+ BOOST_REQUIRE_CLOSE(v(row, col), wh(row, col), 10.0);
+}
+
+/**
+ * Check if the product of the calculated factorization is close to the
+ * input matrix. Random Acol Initialization Distance Minimization Update
+ */
+BOOST_AUTO_TEST_CASE(NMFAcolDistTest)
+{
+ mat w = randu<mat>(20, 16);
+ mat h = randu<mat>(16, 20);
+ mat v = w * h;
+ size_t r = 16;
+
+ NMF<RandomAcolInitialization<> > nmf;
+ nmf.Apply(v, r, w, h);
+
+ mat wh = w * h;
+
+ for (size_t row = 0; row < 5; row++)
+ for (size_t col = 0; col < 5; col++)
+ BOOST_REQUIRE_CLOSE(v(row, col), wh(row, col), 10.0);
+}
+
+/**
+ * Check if the product of the calculated factorization is close to the
+ * input matrix. Random Initialization Divergence Minimization Update
+ */
+BOOST_AUTO_TEST_CASE(NMFRandomDivTest)
+{
+ mat w = randu<mat>(20, 16);
+ mat h = randu<mat>(16, 20);
+ mat v = w * h;
+ size_t r = 16;
+
+ NMF<RandomInitialization,
+ WMultiplicativeDivergenceRule,
+ HMultiplicativeDivergenceRule> nmf;
+ nmf.Apply(v, r, w, h);
+
+ mat wh = w * h;
+
+ for (size_t row = 0; row < 5; row++)
+ for (size_t col = 0; col < 5; col++)
+ BOOST_REQUIRE_CLOSE(v(row, col), wh(row, col), 10.0);
+}
+
+/**
+ * Check that the product of the calculated factorization is close to the
+ * input matrix. This uses the random initialization and alternating least
+ * squares update rule.
+ */
+BOOST_AUTO_TEST_CASE(NMFALSTest)
+{
+ mat w = randu<mat>(20, 16);
+ mat h = randu<mat>(16, 20);
+ mat v = w * h;
+ size_t r = 16;
+
+ NMF<RandomInitialization,
+ WAlternatingLeastSquaresRule,
+ HAlternatingLeastSquaresRule> nmf;
+ nmf.Apply(v, r, w, h);
+
+ mat wh = w * h;
+
+ for (size_t row = 0; row < 5; row++)
+ for (size_t col = 0; col < 5; col++)
+ BOOST_REQUIRE_CLOSE(v(row, col), wh(row, col), 13.0);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/old_boost_test_definitions.hpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/old_boost_test_definitions.hpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/old_boost_test_definitions.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,53 +0,0 @@
-/**
- * @file old_boost_test_definitions.hpp
- * @author Ryan Curtin
- *
- * Ancient Boost.Test versions don't act how we expect. This file includes the
- * things we need to fix that.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#ifndef __MLPACK_TESTS_OLD_BOOST_TEST_DEFINITIONS_HPP
-#define __MLPACK_TESTS_OLD_BOOST_TEST_DEFINITIONS_HPP
-
-#include <boost/version.hpp>
-
-// This is only necessary for pre-1.36 Boost.Test.
-#if BOOST_VERSION < 103600
-
-#include <boost/test/floating_point_comparison.hpp>
-#include <boost/test/auto_unit_test.hpp>
-
-// This depends on other macros. Probably not a great idea... but it works, and
-// we only need it for ancient Boost versions.
-#define BOOST_REQUIRE_GE( L, R ) \
- BOOST_REQUIRE_EQUAL( (L >= R), true )
-
-#define BOOST_REQUIRE_NE( L, R ) \
- BOOST_REQUIRE_EQUAL( (L != R), true )
-
-#define BOOST_REQUIRE_LE( L, R ) \
- BOOST_REQUIRE_EQUAL( (L <= R), true )
-
-#define BOOST_REQUIRE_LT( L, R ) \
- BOOST_REQUIRE_EQUAL( (L < R), true )
-
-#define BOOST_REQUIRE_GT( L, R ) \
- BOOST_REQUIRE_EQUAL( (L > R), true )
-
-#endif
-
-#endif
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/old_boost_test_definitions.hpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/old_boost_test_definitions.hpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/old_boost_test_definitions.hpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/old_boost_test_definitions.hpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,53 @@
+/**
+ * @file old_boost_test_definitions.hpp
+ * @author Ryan Curtin
+ *
+ * Ancient Boost.Test versions don't act how we expect. This file includes the
+ * things we need to fix that.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef __MLPACK_TESTS_OLD_BOOST_TEST_DEFINITIONS_HPP
+#define __MLPACK_TESTS_OLD_BOOST_TEST_DEFINITIONS_HPP
+
+#include <boost/version.hpp>
+
+// This is only necessary for pre-1.36 Boost.Test.
+#if BOOST_VERSION < 103600
+
+#include <boost/test/floating_point_comparison.hpp>
+#include <boost/test/auto_unit_test.hpp>
+
+// This depends on other macros. Probably not a great idea... but it works, and
+// we only need it for ancient Boost versions.
+#define BOOST_REQUIRE_GE( L, R ) \
+ BOOST_REQUIRE_EQUAL( (L >= R), true )
+
+#define BOOST_REQUIRE_NE( L, R ) \
+ BOOST_REQUIRE_EQUAL( (L != R), true )
+
+#define BOOST_REQUIRE_LE( L, R ) \
+ BOOST_REQUIRE_EQUAL( (L <= R), true )
+
+#define BOOST_REQUIRE_LT( L, R ) \
+ BOOST_REQUIRE_EQUAL( (L < R), true )
+
+#define BOOST_REQUIRE_GT( L, R ) \
+ BOOST_REQUIRE_EQUAL( (L > R), true )
+
+#endif
+
+#endif
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/pca_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/pca_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/pca_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,93 +0,0 @@
-/**
- * @file pca_test.cpp
- * @author Ajinkya Kale
- *
- * Test file for PCA class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/pca/pca.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-BOOST_AUTO_TEST_SUITE(PCATest);
-
-using namespace std;
-using namespace arma;
-using namespace mlpack;
-using namespace mlpack::pca;
-
-/**
- * Compare the output of our PCA implementation with Armadillo's.
- */
-BOOST_AUTO_TEST_CASE(ArmaComparisonPCATest)
-{
- mat coeff, coeff1;
- vec eigVal, eigVal1;
- mat score, score1;
-
- mat data = randu<mat>(100,100);
-
- PCA p;
-
- p.Apply(data, score1, eigVal1, coeff1);
- princomp(coeff, score, eigVal, trans(data));
-
- // Verify the PCA results based on the eigenvalues.
- for(size_t i = 0; i < eigVal.n_rows; i++)
- for(size_t j = 0; j < eigVal.n_cols; j++)
- BOOST_REQUIRE_SMALL(eigVal(i, j) - eigVal1(i, j), 0.0001);
-}
-
-/**
- * Test that dimensionality reduction with PCA works the same way MATLAB does
- * (which should be correct!).
- */
-BOOST_AUTO_TEST_CASE(PCADimensionalityReductionTest)
-{
- // Fake, simple dataset. The results we will compare against are from MATLAB.
- mat data("1 0 2 3 9;"
- "5 2 8 4 8;"
- "6 7 3 1 8");
-
- // Now run PCA to reduce the dimensionality.
- PCA p;
- p.Apply(data, 2); // Reduce to 2 dimensions.
-
- // Compare with correct results.
- mat correct("-1.53781086 -3.51358020 -0.16139887 -1.87706634 7.08985628;"
- " 1.29937798 3.45762685 -2.69910005 -3.15620704 1.09830225");
-
- // If the eigenvectors are pointed opposite directions, they will cancel
-// each other out in this summation.
- for(size_t i = 0; i < data.n_rows; i++)
- {
- if (fabs(correct(i, 1) + data(i,1)) < 0.001 /* arbitrary */)
- {
- // Flip Armadillo coefficients for this column.
- data.row(i) *= -1;
- }
- }
-
- for (size_t row = 0; row < 2; row++)
- for (size_t col = 0; col < 5; col++)
- BOOST_REQUIRE_CLOSE(data(row, col), correct(row, col), 1e-3);
-}
-
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/pca_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/pca_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/pca_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/pca_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,93 @@
+/**
+ * @file pca_test.cpp
+ * @author Ajinkya Kale
+ *
+ * Test file for PCA class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/pca/pca.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+BOOST_AUTO_TEST_SUITE(PCATest);
+
+using namespace std;
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::pca;
+
+/**
+ * Compare the output of our PCA implementation with Armadillo's.
+ */
+BOOST_AUTO_TEST_CASE(ArmaComparisonPCATest)
+{
+ mat coeff, coeff1;
+ vec eigVal, eigVal1;
+ mat score, score1;
+
+ mat data = randu<mat>(100,100);
+
+ PCA p;
+
+ p.Apply(data, score1, eigVal1, coeff1);
+ princomp(coeff, score, eigVal, trans(data));
+
+ // Verify the PCA results based on the eigenvalues.
+ for(size_t i = 0; i < eigVal.n_rows; i++)
+ for(size_t j = 0; j < eigVal.n_cols; j++)
+ BOOST_REQUIRE_SMALL(eigVal(i, j) - eigVal1(i, j), 0.0001);
+}
+
+/**
+ * Test that dimensionality reduction with PCA works the same way MATLAB does
+ * (which should be correct!).
+ */
+BOOST_AUTO_TEST_CASE(PCADimensionalityReductionTest)
+{
+ // Fake, simple dataset. The results we will compare against are from MATLAB.
+ mat data("1 0 2 3 9;"
+ "5 2 8 4 8;"
+ "6 7 3 1 8");
+
+ // Now run PCA to reduce the dimensionality.
+ PCA p;
+ p.Apply(data, 2); // Reduce to 2 dimensions.
+
+ // Compare with correct results.
+ mat correct("-1.53781086 -3.51358020 -0.16139887 -1.87706634 7.08985628;"
+ " 1.29937798 3.45762685 -2.69910005 -3.15620704 1.09830225");
+
+ // If the eigenvectors are pointed in opposite directions, they will cancel
+ // each other out in this summation.
+ for(size_t i = 0; i < data.n_rows; i++)
+ {
+ if (fabs(correct(i, 1) + data(i,1)) < 0.001 /* arbitrary */)
+ {
+ // Flip Armadillo coefficients for this column.
+ data.row(i) *= -1;
+ }
+ }
+
+ for (size_t row = 0; row < 2; row++)
+ for (size_t col = 0; col < 5; col++)
+ BOOST_REQUIRE_CLOSE(data(row, col), correct(row, col), 1e-3);
+}
+
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/radical_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/radical_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/radical_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,71 +0,0 @@
-/**
- * @file radical_main.cpp
- * @author Nishant Mehta
- *
- * Executable for RADICAL
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <armadillo>
-#include <mlpack/core.hpp>
-#include <mlpack/methods/radical/radical.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-BOOST_AUTO_TEST_SUITE(RadicalTest);
-
-using namespace mlpack;
-using namespace mlpack::radical;
-using namespace std;
-using namespace arma;
-
-BOOST_AUTO_TEST_CASE(Radical_Test_Radical3D)
-{
- mat matX;
- data::Load("data_3d_mixed.txt", matX);
-
- Radical rad(0.175, 5, 100, matX.n_rows - 1);
-
- mat matY;
- mat matW;
- rad.DoRadical(matX, matY, matW);
-
- mat matYT = trans(matY);
- double valEst = 0;
-
- for (uword i = 0; i < matYT.n_cols; i++)
- {
- vec y = vec(matYT.col(i));
- valEst += rad.Vasicek(y);
- }
-
- mat matS;
- data::Load("data_3d_ind.txt", matS);
- rad.DoRadical(matS, matY, matW);
-
- matYT = trans(matY);
- double valBest = 0;
-
- for (uword i = 0; i < matYT.n_cols; i++)
- {
- vec y = vec(matYT.col(i));
- valBest += rad.Vasicek(y);
- }
-
- BOOST_REQUIRE_CLOSE(valBest, valEst, 0.2);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/radical_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/radical_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/radical_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/radical_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,71 @@
+/**
+ * @file radical_test.cpp
+ * @author Nishant Mehta
+ *
+ * Test file for RADICAL.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <armadillo>
+#include <mlpack/core.hpp>
+#include <mlpack/methods/radical/radical.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+BOOST_AUTO_TEST_SUITE(RadicalTest);
+
+using namespace mlpack;
+using namespace mlpack::radical;
+using namespace std;
+using namespace arma;
+
+BOOST_AUTO_TEST_CASE(Radical_Test_Radical3D)
+{
+ mat matX;
+ data::Load("data_3d_mixed.txt", matX);
+
+ Radical rad(0.175, 5, 100, matX.n_rows - 1);
+
+ mat matY;
+ mat matW;
+ rad.DoRadical(matX, matY, matW);
+
+ mat matYT = trans(matY);
+ double valEst = 0;
+
+ for (uword i = 0; i < matYT.n_cols; i++)
+ {
+ vec y = vec(matYT.col(i));
+ valEst += rad.Vasicek(y);
+ }
+
+ mat matS;
+ data::Load("data_3d_ind.txt", matS);
+ rad.DoRadical(matS, matY, matW);
+
+ matYT = trans(matY);
+ double valBest = 0;
+
+ for (uword i = 0; i < matYT.n_cols; i++)
+ {
+ vec y = vec(matYT.col(i));
+ valBest += rad.Vasicek(y);
+ }
+
+ BOOST_REQUIRE_CLOSE(valBest, valEst, 0.2);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/range_search_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/range_search_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/range_search_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,587 +0,0 @@
-/**
- * @file range_search_test.cpp
- * @author Ryan Curtin
- *
- * Test file for RangeSearch<> class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/methods/range_search/range_search.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::range;
-using namespace mlpack::math;
-using namespace std;
-
-BOOST_AUTO_TEST_SUITE(RangeSearchTest);
-
-// Get our results into a sorted format, so we can actually then test for
-// correctness.
-void SortResults(const vector<vector<size_t> >& neighbors,
- const vector<vector<double> >& distances,
- vector<vector<pair<double, size_t> > >& output)
-{
- output.resize(neighbors.size());
- for (size_t i = 0; i < neighbors.size(); i++)
- {
- output[i].resize(neighbors[i].size());
- for (size_t j = 0; j < neighbors[i].size(); j++)
- output[i][j] = make_pair(distances[i][j], neighbors[i][j]);
-
- // Now that it's constructed, sort it.
- sort(output[i].begin(), output[i].end());
- }
-}
-
-/**
- * Simple range-search test with small, synthetic dataset. This is an
- * exhaustive test, which checks that each method for performing the calculation
- * (dual-tree, single-tree, naive) produces the correct results. An
- * eleven-point dataset and the points within three ranges are taken. The
- * dataset is in one dimension for simplicity -- the correct functionality of
- * distance functions is not tested here.
- */
-BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
-{
- // Set up our data.
- arma::mat data(1, 11);
- data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
- data[1] = 0.35;
- data[2] = 0.15;
- data[3] = 1.25;
- data[4] = 5.05;
- data[5] = -0.22;
- data[6] = -2.00;
- data[7] = -1.30;
- data[8] = 0.45;
- data[9] = 0.90;
- data[10] = 1.00;
-
- // We will loop through three times, one for each method of performing the
- // calculation.
- for (int i = 0; i < 3; i++)
- {
- RangeSearch<>* rs;
- arma::mat dataMutable = data;
- switch (i)
- {
- case 0: // Use the dual-tree method.
- rs = new RangeSearch<>(dataMutable, false, false, 1);
- break;
- case 1: // Use the single-tree method.
- rs = new RangeSearch<>(dataMutable, false, true, 1);
- break;
- case 2: // Use the naive method.
- rs = new RangeSearch<>(dataMutable, true);
- break;
- }
-
- // Now perform the first calculation. Points within 0.50.
- vector<vector<size_t> > neighbors;
- vector<vector<double> > distances;
- rs->Search(Range(0.0, 0.50), neighbors, distances);
-
- // Now the exhaustive check for correctness. This will be long. We must
- // also remember that the distances returned are squared distances. As a
- // result, distance comparisons are written out as (distance * distance) for
- // readability.
- vector<vector<pair<double, size_t> > > sortedOutput;
- SortResults(neighbors, distances, sortedOutput);
-
- // Neighbors of point 0.
- BOOST_REQUIRE(sortedOutput[0].size() == 4);
- BOOST_REQUIRE(sortedOutput[0][0].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(sortedOutput[0][1].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (0.27 * 0.27), 1e-5);
- BOOST_REQUIRE(sortedOutput[0][2].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, (0.30 * 0.30), 1e-5);
- BOOST_REQUIRE(sortedOutput[0][3].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, (0.40 * 0.40), 1e-5);
-
- // Neighbors of point 1.
- BOOST_REQUIRE(sortedOutput[1].size() == 6);
- BOOST_REQUIRE(sortedOutput[1][0].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(sortedOutput[1][1].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, (0.20 * 0.20), 1e-5);
- BOOST_REQUIRE(sortedOutput[1][2].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, (0.30 * 0.30), 1e-5);
- BOOST_REQUIRE(sortedOutput[1][3].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][3].first, (0.55 * 0.55), 1e-5);
- BOOST_REQUIRE(sortedOutput[1][4].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][4].first, (0.57 * 0.57), 1e-5);
- BOOST_REQUIRE(sortedOutput[1][5].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][5].first, (0.65 * 0.65), 1e-5);
-
- // Neighbors of point 2.
- BOOST_REQUIRE(sortedOutput[2].size() == 4);
- BOOST_REQUIRE(sortedOutput[2][0].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(sortedOutput[2][1].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (0.20 * 0.20), 1e-5);
- BOOST_REQUIRE(sortedOutput[2][2].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, (0.30 * 0.30), 1e-5);
- BOOST_REQUIRE(sortedOutput[2][3].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, (0.37 * 0.37), 1e-5);
-
- // Neighbors of point 3.
- BOOST_REQUIRE(sortedOutput[3].size() == 2);
- BOOST_REQUIRE(sortedOutput[3][0].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (0.25 * 0.25), 1e-5);
- BOOST_REQUIRE(sortedOutput[3][1].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (0.35 * 0.35), 1e-5);
-
- // Neighbors of point 4.
- BOOST_REQUIRE(sortedOutput[4].size() == 0);
-
- // Neighbors of point 5.
- BOOST_REQUIRE(sortedOutput[5].size() == 4);
- BOOST_REQUIRE(sortedOutput[5][0].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, (0.27 * 0.27), 1e-5);
- BOOST_REQUIRE(sortedOutput[5][1].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, (0.37 * 0.37), 1e-5);
- BOOST_REQUIRE(sortedOutput[5][2].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, (0.57 * 0.57), 1e-5);
- BOOST_REQUIRE(sortedOutput[5][3].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, (0.67 * 0.67), 1e-5);
-
- // Neighbors of point 6.
- BOOST_REQUIRE(sortedOutput[6].size() == 1);
- BOOST_REQUIRE(sortedOutput[6][0].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, (0.70 * 0.70), 1e-5);
-
- // Neighbors of point 7.
- BOOST_REQUIRE(sortedOutput[7].size() == 1);
- BOOST_REQUIRE(sortedOutput[7][0].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, (0.70 * 0.70), 1e-5);
-
- // Neighbors of point 8.
- BOOST_REQUIRE(sortedOutput[8].size() == 6);
- BOOST_REQUIRE(sortedOutput[8][0].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(sortedOutput[8][1].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, (0.30 * 0.30), 1e-5);
- BOOST_REQUIRE(sortedOutput[8][2].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, (0.40 * 0.40), 1e-5);
- BOOST_REQUIRE(sortedOutput[8][3].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][3].first, (0.45 * 0.45), 1e-5);
- BOOST_REQUIRE(sortedOutput[8][4].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][4].first, (0.55 * 0.55), 1e-5);
- BOOST_REQUIRE(sortedOutput[8][5].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][5].first, (0.67 * 0.67), 1e-5);
-
- // Neighbors of point 9.
- BOOST_REQUIRE(sortedOutput[9].size() == 4);
- BOOST_REQUIRE(sortedOutput[9][0].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(sortedOutput[9][1].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (0.35 * 0.35), 1e-5);
- BOOST_REQUIRE(sortedOutput[9][2].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, (0.45 * 0.45), 1e-5);
- BOOST_REQUIRE(sortedOutput[9][3].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, (0.55 * 0.55), 1e-5);
-
- // Neighbors of point 10.
- BOOST_REQUIRE(sortedOutput[10].size() == 4);
- BOOST_REQUIRE(sortedOutput[10][0].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (0.10 * 0.10), 1e-5);
- BOOST_REQUIRE(sortedOutput[10][1].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (0.25 * 0.25), 1e-5);
- BOOST_REQUIRE(sortedOutput[10][2].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, (0.55 * 0.55), 1e-5);
- BOOST_REQUIRE(sortedOutput[10][3].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, (0.65 * 0.65), 1e-5);
-
- // Now do it again with a different range: [0.5 1.0].
- rs->Search(Range(0.5, 1.0), neighbors, distances);
- SortResults(neighbors, distances, sortedOutput);
-
- // Neighbors of point 0.
- BOOST_REQUIRE(sortedOutput[0].size() == 2);
- BOOST_REQUIRE(sortedOutput[0][0].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (0.85 * 0.85), 1e-5);
- BOOST_REQUIRE(sortedOutput[0][1].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (0.95 * 0.95), 1e-5);
-
- // Neighbors of point 1.
- BOOST_REQUIRE(sortedOutput[1].size() == 1);
- BOOST_REQUIRE(sortedOutput[1][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (0.90 * 0.90), 1e-5);
-
- // Neighbors of point 2.
- BOOST_REQUIRE(sortedOutput[2].size() == 2);
- BOOST_REQUIRE(sortedOutput[2][0].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (0.75 * 0.75), 1e-5);
- BOOST_REQUIRE(sortedOutput[2][1].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (0.85 * 0.85), 1e-5);
-
- // Neighbors of point 3.
- BOOST_REQUIRE(sortedOutput[3].size() == 2);
- BOOST_REQUIRE(sortedOutput[3][0].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (0.80 * 0.80), 1e-5);
- BOOST_REQUIRE(sortedOutput[3][1].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (0.90 * 0.90), 1e-5);
-
- // Neighbors of point 4.
- BOOST_REQUIRE(sortedOutput[4].size() == 0);
-
- // Neighbors of point 5.
- BOOST_REQUIRE(sortedOutput[5].size() == 0);
-
- // Neighbors of point 6.
- BOOST_REQUIRE(sortedOutput[6].size() == 0);
-
- // Neighbors of point 7.
- BOOST_REQUIRE(sortedOutput[7].size() == 0);
-
- // Neighbors of point 8.
- BOOST_REQUIRE(sortedOutput[8].size() == 1);
- BOOST_REQUIRE(sortedOutput[8][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (0.80 * 0.80), 1e-5);
-
- // Neighbors of point 9.
- BOOST_REQUIRE(sortedOutput[9].size() == 2);
- BOOST_REQUIRE(sortedOutput[9][0].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (0.75 * 0.75), 1e-5);
- BOOST_REQUIRE(sortedOutput[9][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (0.85 * 0.85), 1e-5);
-
- // Neighbors of point 10.
- BOOST_REQUIRE(sortedOutput[10].size() == 2);
- BOOST_REQUIRE(sortedOutput[10][0].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (0.85 * 0.85), 1e-5);
- BOOST_REQUIRE(sortedOutput[10][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (0.95 * 0.95), 1e-5);
-
- // Now do it again with a different range: [1.0 inf].
- rs->Search(Range(1.0, numeric_limits<double>::infinity()), neighbors,
- distances);
- SortResults(neighbors, distances, sortedOutput);
-
- // Neighbors of point 0.
- BOOST_REQUIRE(sortedOutput[0].size() == 4);
- BOOST_REQUIRE(sortedOutput[0][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (1.20 * 1.20), 1e-5);
- BOOST_REQUIRE(sortedOutput[0][1].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (1.35 * 1.35), 1e-5);
- BOOST_REQUIRE(sortedOutput[0][2].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, (2.05 * 2.05), 1e-5);
- BOOST_REQUIRE(sortedOutput[0][3].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, (5.00 * 5.00), 1e-5);
-
- // Neighbors of point 1.
- BOOST_REQUIRE(sortedOutput[1].size() == 3);
- BOOST_REQUIRE(sortedOutput[1][0].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (1.65 * 1.65), 1e-5);
- BOOST_REQUIRE(sortedOutput[1][1].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, (2.35 * 2.35), 1e-5);
- BOOST_REQUIRE(sortedOutput[1][2].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, (4.70 * 4.70), 1e-5);
-
- // Neighbors of point 2.
- BOOST_REQUIRE(sortedOutput[2].size() == 4);
- BOOST_REQUIRE(sortedOutput[2][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (1.10 * 1.10), 1e-5);
- BOOST_REQUIRE(sortedOutput[2][1].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (1.45 * 1.45), 1e-5);
- BOOST_REQUIRE(sortedOutput[2][2].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, (2.15 * 2.15), 1e-5);
- BOOST_REQUIRE(sortedOutput[2][3].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, (4.90 * 4.90), 1e-5);
-
- // Neighbors of point 3.
- BOOST_REQUIRE(sortedOutput[3].size() == 6);
- BOOST_REQUIRE(sortedOutput[3][0].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (1.10 * 1.10), 1e-5);
- BOOST_REQUIRE(sortedOutput[3][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (1.20 * 1.20), 1e-5);
- BOOST_REQUIRE(sortedOutput[3][2].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][2].first, (1.47 * 1.47), 1e-5);
- BOOST_REQUIRE(sortedOutput[3][3].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][3].first, (2.55 * 2.55), 1e-5);
- BOOST_REQUIRE(sortedOutput[3][4].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][4].first, (3.25 * 3.25), 1e-5);
- BOOST_REQUIRE(sortedOutput[3][5].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][5].first, (3.80 * 3.80), 1e-5);
-
- // Neighbors of point 4.
- BOOST_REQUIRE(sortedOutput[4].size() == 10);
- BOOST_REQUIRE(sortedOutput[4][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][0].first, (3.80 * 3.80), 1e-5);
- BOOST_REQUIRE(sortedOutput[4][1].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][1].first, (4.05 * 4.05), 1e-5);
- BOOST_REQUIRE(sortedOutput[4][2].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][2].first, (4.15 * 4.15), 1e-5);
- BOOST_REQUIRE(sortedOutput[4][3].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][3].first, (4.60 * 4.60), 1e-5);
- BOOST_REQUIRE(sortedOutput[4][4].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][4].first, (4.70 * 4.70), 1e-5);
- BOOST_REQUIRE(sortedOutput[4][5].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][5].first, (4.90 * 4.90), 1e-5);
- BOOST_REQUIRE(sortedOutput[4][6].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][6].first, (5.00 * 5.00), 1e-5);
- BOOST_REQUIRE(sortedOutput[4][7].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][7].first, (5.27 * 5.27), 1e-5);
- BOOST_REQUIRE(sortedOutput[4][8].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][8].first, (6.35 * 6.35), 1e-5);
- BOOST_REQUIRE(sortedOutput[4][9].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][9].first, (7.05 * 7.05), 1e-5);
-
- // Neighbors of point 5.
- BOOST_REQUIRE(sortedOutput[5].size() == 6);
- BOOST_REQUIRE(sortedOutput[5][0].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, (1.08 * 1.08), 1e-5);
- BOOST_REQUIRE(sortedOutput[5][1].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, (1.12 * 1.12), 1e-5);
- BOOST_REQUIRE(sortedOutput[5][2].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, (1.22 * 1.22), 1e-5);
- BOOST_REQUIRE(sortedOutput[5][3].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, (1.47 * 1.47), 1e-5);
- BOOST_REQUIRE(sortedOutput[5][4].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][4].first, (1.78 * 1.78), 1e-5);
- BOOST_REQUIRE(sortedOutput[5][5].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][5].first, (5.27 * 5.27), 1e-5);
-
- // Neighbors of point 6.
- BOOST_REQUIRE(sortedOutput[6].size() == 9);
- BOOST_REQUIRE(sortedOutput[6][0].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, (1.78 * 1.78), 1e-5);
- BOOST_REQUIRE(sortedOutput[6][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][1].first, (2.05 * 2.05), 1e-5);
- BOOST_REQUIRE(sortedOutput[6][2].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][2].first, (2.15 * 2.15), 1e-5);
- BOOST_REQUIRE(sortedOutput[6][3].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][3].first, (2.35 * 2.35), 1e-5);
- BOOST_REQUIRE(sortedOutput[6][4].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][4].first, (2.45 * 2.45), 1e-5);
- BOOST_REQUIRE(sortedOutput[6][5].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][5].first, (2.90 * 2.90), 1e-5);
- BOOST_REQUIRE(sortedOutput[6][6].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][6].first, (3.00 * 3.00), 1e-5);
- BOOST_REQUIRE(sortedOutput[6][7].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][7].first, (3.25 * 3.25), 1e-5);
- BOOST_REQUIRE(sortedOutput[6][8].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][8].first, (7.05 * 7.05), 1e-5);
-
- // Neighbors of point 7.
- BOOST_REQUIRE(sortedOutput[7].size() == 9);
- BOOST_REQUIRE(sortedOutput[7][0].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, (1.08 * 1.08), 1e-5);
- BOOST_REQUIRE(sortedOutput[7][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][1].first, (1.35 * 1.35), 1e-5);
- BOOST_REQUIRE(sortedOutput[7][2].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][2].first, (1.45 * 1.45), 1e-5);
- BOOST_REQUIRE(sortedOutput[7][3].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][3].first, (1.65 * 1.65), 1e-5);
- BOOST_REQUIRE(sortedOutput[7][4].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][4].first, (1.75 * 1.75), 1e-5);
- BOOST_REQUIRE(sortedOutput[7][5].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][5].first, (2.20 * 2.20), 1e-5);
- BOOST_REQUIRE(sortedOutput[7][6].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][6].first, (2.30 * 2.30), 1e-5);
- BOOST_REQUIRE(sortedOutput[7][7].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][7].first, (2.55 * 2.55), 1e-5);
- BOOST_REQUIRE(sortedOutput[7][8].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][8].first, (6.35 * 6.35), 1e-5);
-
- // Neighbors of point 8.
- BOOST_REQUIRE(sortedOutput[8].size() == 3);
- BOOST_REQUIRE(sortedOutput[8][0].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (1.75 * 1.75), 1e-5);
- BOOST_REQUIRE(sortedOutput[8][1].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, (2.45 * 2.45), 1e-5);
- BOOST_REQUIRE(sortedOutput[8][2].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, (4.60 * 4.60), 1e-5);
-
- // Neighbors of point 9.
- BOOST_REQUIRE(sortedOutput[9].size() == 4);
- BOOST_REQUIRE(sortedOutput[9][0].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (1.12 * 1.12), 1e-5);
- BOOST_REQUIRE(sortedOutput[9][1].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (2.20 * 2.20), 1e-5);
- BOOST_REQUIRE(sortedOutput[9][2].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, (2.90 * 2.90), 1e-5);
- BOOST_REQUIRE(sortedOutput[9][3].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, (4.15 * 4.15), 1e-5);
-
- // Neighbors of point 10.
- BOOST_REQUIRE(sortedOutput[10].size() == 4);
- BOOST_REQUIRE(sortedOutput[10][0].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (1.22 * 1.22), 1e-5);
- BOOST_REQUIRE(sortedOutput[10][1].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (2.30 * 2.30), 1e-5);
- BOOST_REQUIRE(sortedOutput[10][2].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, (3.00 * 3.00), 1e-5);
- BOOST_REQUIRE(sortedOutput[10][3].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, (4.05 * 4.05), 1e-5);
-
- // Clean the memory.
- delete rs;
- }
-}
-
-/**
- * Test the dual-tree nearest-neighbors method with the naive method. This
- * uses both a query and reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with.
- arma::mat dualQuery(dataForTree);
- arma::mat dualReferences(dataForTree);
- arma::mat naiveQuery(dataForTree);
- arma::mat naiveReferences(dataForTree);
-
- RangeSearch<> rs(dualQuery, dualReferences);
-
- RangeSearch<> naive(naiveQuery, naiveReferences, true);
-
- vector<vector<size_t> > neighborsTree;
- vector<vector<double> > distancesTree;
- rs.Search(Range(0.25, 1.05), neighborsTree, distancesTree);
- vector<vector<pair<double, size_t> > > sortedTree;
- SortResults(neighborsTree, distancesTree, sortedTree);
-
- vector<vector<size_t> > neighborsNaive;
- vector<vector<double> > distancesNaive;
- naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
- vector<vector<pair<double, size_t> > > sortedNaive;
- SortResults(neighborsNaive, distancesNaive, sortedNaive);
-
- for (size_t i = 0; i < sortedTree.size(); i++)
- {
- BOOST_REQUIRE(sortedTree[i].size() == sortedNaive[i].size());
-
- for (size_t j = 0; j < sortedTree[i].size(); j++)
- {
- BOOST_REQUIRE(sortedTree[i][j].second == sortedNaive[i][j].second);
- BOOST_REQUIRE_CLOSE(sortedTree[i][j].first, sortedNaive[i][j].first,
- 1e-5);
- }
- }
-}
-
-/**
- * Test the dual-tree nearest-neighbors method with the naive method. This uses
- * only a reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- // Code duplication: also bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
- arma::mat dualQuery(dataForTree);
- arma::mat naiveQuery(dataForTree);
-
- RangeSearch<> rs(dualQuery);
-
- // Set naive mode.
- RangeSearch<> naive(naiveQuery, true);
-
- vector<vector<size_t> > neighborsTree;
- vector<vector<double> > distancesTree;
- rs.Search(Range(0.25, 1.05), neighborsTree, distancesTree);
- vector<vector<pair<double, size_t> > > sortedTree;
- SortResults(neighborsTree, distancesTree, sortedTree);
-
- vector<vector<size_t> > neighborsNaive;
- vector<vector<double> > distancesNaive;
- naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
- vector<vector<pair<double, size_t> > > sortedNaive;
- SortResults(neighborsNaive, distancesNaive, sortedNaive);
-
- for (size_t i = 0; i < sortedTree.size(); i++)
- {
- BOOST_REQUIRE(sortedTree[i].size() == sortedNaive[i].size());
-
- for (size_t j = 0; j < sortedTree[i].size(); j++)
- {
- BOOST_REQUIRE(sortedTree[i][j].second == sortedNaive[i][j].second);
- BOOST_REQUIRE_CLOSE(sortedTree[i][j].first, sortedNaive[i][j].first,
- 1e-5);
- }
- }
-}
-
-/**
- * Test the single-tree nearest-neighbors method with the naive method. This
- * uses only a reference dataset.
- *
- * Errors are produced if the results are not identical.
- */
-BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
-{
- arma::mat dataForTree;
-
- // Hard-coded filename: bad!
- // Code duplication: also bad!
- if (!data::Load("test_data_3_1000.csv", dataForTree))
- BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
-
- // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
- arma::mat singleQuery(dataForTree);
- arma::mat naiveQuery(dataForTree);
-
- RangeSearch<> single(singleQuery, false, true);
-
- // Set up computation for naive mode.
- RangeSearch<> naive(naiveQuery, true);
-
- vector<vector<size_t> > neighborsSingle;
- vector<vector<double> > distancesSingle;
- single.Search(Range(0.25, 1.05), neighborsSingle, distancesSingle);
- vector<vector<pair<double, size_t> > > sortedTree;
- SortResults(neighborsSingle, distancesSingle, sortedTree);
-
- vector<vector<size_t> > neighborsNaive;
- vector<vector<double> > distancesNaive;
- naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
- vector<vector<pair<double, size_t> > > sortedNaive;
- SortResults(neighborsNaive, distancesNaive, sortedNaive);
-
- for (size_t i = 0; i < sortedTree.size(); i++)
- {
- BOOST_REQUIRE(sortedTree[i].size() == sortedNaive[i].size());
-
- for (size_t j = 0; j < sortedTree[i].size(); j++)
- {
- BOOST_REQUIRE(sortedTree[i][j].second == sortedNaive[i][j].second);
- BOOST_REQUIRE_CLOSE(sortedTree[i][j].first, sortedNaive[i][j].first,
- 1e-5);
- }
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/range_search_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/range_search_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/range_search_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/range_search_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,587 @@
+/**
+ * @file range_search_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test file for RangeSearch<> class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/range_search/range_search.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::range;
+using namespace mlpack::math;
+using namespace std;
+
+BOOST_AUTO_TEST_SUITE(RangeSearchTest);
+
+// Get our results into a sorted format, so we can actually then test for
+// correctness.
+void SortResults(const vector<vector<size_t> >& neighbors,
+ const vector<vector<double> >& distances,
+ vector<vector<pair<double, size_t> > >& output)
+{
+ output.resize(neighbors.size());
+ for (size_t i = 0; i < neighbors.size(); i++)
+ {
+ output[i].resize(neighbors[i].size());
+ for (size_t j = 0; j < neighbors[i].size(); j++)
+ output[i][j] = make_pair(distances[i][j], neighbors[i][j]);
+
+ // Now that it's constructed, sort it.
+ sort(output[i].begin(), output[i].end());
+ }
+}
+
+/**
+ * Simple range-search test with small, synthetic dataset. This is an
+ * exhaustive test, which checks that each method for performing the calculation
+ * (dual-tree, single-tree, naive) produces the correct results. An
+ * eleven-point dataset and the points within three ranges are taken. The
+ * dataset is in one dimension for simplicity -- the correct functionality of
+ * distance functions is not tested here.
+ */
+BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
+{
+ // Set up our data.
+ arma::mat data(1, 11);
+ data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
+ data[1] = 0.35;
+ data[2] = 0.15;
+ data[3] = 1.25;
+ data[4] = 5.05;
+ data[5] = -0.22;
+ data[6] = -2.00;
+ data[7] = -1.30;
+ data[8] = 0.45;
+ data[9] = 0.90;
+ data[10] = 1.00;
+
+ // We will loop through three times, one for each method of performing the
+ // calculation.
+ for (int i = 0; i < 3; i++)
+ {
+ RangeSearch<>* rs;
+ arma::mat dataMutable = data;
+ switch (i)
+ {
+ case 0: // Use the dual-tree method.
+ rs = new RangeSearch<>(dataMutable, false, false, 1);
+ break;
+ case 1: // Use the single-tree method.
+ rs = new RangeSearch<>(dataMutable, false, true, 1);
+ break;
+ case 2: // Use the naive method.
+ rs = new RangeSearch<>(dataMutable, true);
+ break;
+ }
+
+ // Now perform the first calculation. Points within 0.50.
+ vector<vector<size_t> > neighbors;
+ vector<vector<double> > distances;
+ rs->Search(Range(0.0, 0.50), neighbors, distances);
+
+ // Now the exhaustive check for correctness. This will be long. We must
+ // also remember that the distances returned are squared distances. As a
+ // result, distance comparisons are written out as (distance * distance) for
+ // readability.
+ vector<vector<pair<double, size_t> > > sortedOutput;
+ SortResults(neighbors, distances, sortedOutput);
+
+ // Neighbors of point 0.
+ BOOST_REQUIRE(sortedOutput[0].size() == 4);
+ BOOST_REQUIRE(sortedOutput[0][0].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(sortedOutput[0][1].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (0.27 * 0.27), 1e-5);
+ BOOST_REQUIRE(sortedOutput[0][2].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE(sortedOutput[0][3].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, (0.40 * 0.40), 1e-5);
+
+ // Neighbors of point 1.
+ BOOST_REQUIRE(sortedOutput[1].size() == 6);
+ BOOST_REQUIRE(sortedOutput[1][0].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(sortedOutput[1][1].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, (0.20 * 0.20), 1e-5);
+ BOOST_REQUIRE(sortedOutput[1][2].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE(sortedOutput[1][3].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][3].first, (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE(sortedOutput[1][4].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][4].first, (0.57 * 0.57), 1e-5);
+ BOOST_REQUIRE(sortedOutput[1][5].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][5].first, (0.65 * 0.65), 1e-5);
+
+ // Neighbors of point 2.
+ BOOST_REQUIRE(sortedOutput[2].size() == 4);
+ BOOST_REQUIRE(sortedOutput[2][0].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(sortedOutput[2][1].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (0.20 * 0.20), 1e-5);
+ BOOST_REQUIRE(sortedOutput[2][2].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE(sortedOutput[2][3].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, (0.37 * 0.37), 1e-5);
+
+ // Neighbors of point 3.
+ BOOST_REQUIRE(sortedOutput[3].size() == 2);
+ BOOST_REQUIRE(sortedOutput[3][0].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (0.25 * 0.25), 1e-5);
+ BOOST_REQUIRE(sortedOutput[3][1].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (0.35 * 0.35), 1e-5);
+
+ // Neighbors of point 4.
+ BOOST_REQUIRE(sortedOutput[4].size() == 0);
+
+ // Neighbors of point 5.
+ BOOST_REQUIRE(sortedOutput[5].size() == 4);
+ BOOST_REQUIRE(sortedOutput[5][0].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, (0.27 * 0.27), 1e-5);
+ BOOST_REQUIRE(sortedOutput[5][1].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, (0.37 * 0.37), 1e-5);
+ BOOST_REQUIRE(sortedOutput[5][2].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, (0.57 * 0.57), 1e-5);
+ BOOST_REQUIRE(sortedOutput[5][3].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, (0.67 * 0.67), 1e-5);
+
+ // Neighbors of point 6.
+ BOOST_REQUIRE(sortedOutput[6].size() == 1);
+ BOOST_REQUIRE(sortedOutput[6][0].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, (0.70 * 0.70), 1e-5);
+
+ // Neighbors of point 7.
+ BOOST_REQUIRE(sortedOutput[7].size() == 1);
+ BOOST_REQUIRE(sortedOutput[7][0].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, (0.70 * 0.70), 1e-5);
+
+ // Neighbors of point 8.
+ BOOST_REQUIRE(sortedOutput[8].size() == 6);
+ BOOST_REQUIRE(sortedOutput[8][0].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(sortedOutput[8][1].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE(sortedOutput[8][2].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, (0.40 * 0.40), 1e-5);
+ BOOST_REQUIRE(sortedOutput[8][3].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][3].first, (0.45 * 0.45), 1e-5);
+ BOOST_REQUIRE(sortedOutput[8][4].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][4].first, (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE(sortedOutput[8][5].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][5].first, (0.67 * 0.67), 1e-5);
+
+ // Neighbors of point 9.
+ BOOST_REQUIRE(sortedOutput[9].size() == 4);
+ BOOST_REQUIRE(sortedOutput[9][0].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(sortedOutput[9][1].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (0.35 * 0.35), 1e-5);
+ BOOST_REQUIRE(sortedOutput[9][2].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, (0.45 * 0.45), 1e-5);
+ BOOST_REQUIRE(sortedOutput[9][3].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, (0.55 * 0.55), 1e-5);
+
+ // Neighbors of point 10.
+ BOOST_REQUIRE(sortedOutput[10].size() == 4);
+ BOOST_REQUIRE(sortedOutput[10][0].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE(sortedOutput[10][1].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (0.25 * 0.25), 1e-5);
+ BOOST_REQUIRE(sortedOutput[10][2].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE(sortedOutput[10][3].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, (0.65 * 0.65), 1e-5);
+
+ // Now do it again with a different range: [0.5 1.0].
+ rs->Search(Range(0.5, 1.0), neighbors, distances);
+ SortResults(neighbors, distances, sortedOutput);
+
+ // Neighbors of point 0.
+ BOOST_REQUIRE(sortedOutput[0].size() == 2);
+ BOOST_REQUIRE(sortedOutput[0][0].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE(sortedOutput[0][1].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (0.95 * 0.95), 1e-5);
+
+ // Neighbors of point 1.
+ BOOST_REQUIRE(sortedOutput[1].size() == 1);
+ BOOST_REQUIRE(sortedOutput[1][0].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (0.90 * 0.90), 1e-5);
+
+ // Neighbors of point 2.
+ BOOST_REQUIRE(sortedOutput[2].size() == 2);
+ BOOST_REQUIRE(sortedOutput[2][0].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (0.75 * 0.75), 1e-5);
+ BOOST_REQUIRE(sortedOutput[2][1].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (0.85 * 0.85), 1e-5);
+
+ // Neighbors of point 3.
+ BOOST_REQUIRE(sortedOutput[3].size() == 2);
+ BOOST_REQUIRE(sortedOutput[3][0].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (0.80 * 0.80), 1e-5);
+ BOOST_REQUIRE(sortedOutput[3][1].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (0.90 * 0.90), 1e-5);
+
+ // Neighbors of point 4.
+ BOOST_REQUIRE(sortedOutput[4].size() == 0);
+
+ // Neighbors of point 5.
+ BOOST_REQUIRE(sortedOutput[5].size() == 0);
+
+ // Neighbors of point 6.
+ BOOST_REQUIRE(sortedOutput[6].size() == 0);
+
+ // Neighbors of point 7.
+ BOOST_REQUIRE(sortedOutput[7].size() == 0);
+
+ // Neighbors of point 8.
+ BOOST_REQUIRE(sortedOutput[8].size() == 1);
+ BOOST_REQUIRE(sortedOutput[8][0].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (0.80 * 0.80), 1e-5);
+
+ // Neighbors of point 9.
+ BOOST_REQUIRE(sortedOutput[9].size() == 2);
+ BOOST_REQUIRE(sortedOutput[9][0].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (0.75 * 0.75), 1e-5);
+ BOOST_REQUIRE(sortedOutput[9][1].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (0.85 * 0.85), 1e-5);
+
+ // Neighbors of point 10.
+ BOOST_REQUIRE(sortedOutput[10].size() == 2);
+ BOOST_REQUIRE(sortedOutput[10][0].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE(sortedOutput[10][1].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (0.95 * 0.95), 1e-5);
+
+ // Now do it again with a different range: [1.0 inf].
+ rs->Search(Range(1.0, numeric_limits<double>::infinity()), neighbors,
+ distances);
+ SortResults(neighbors, distances, sortedOutput);
+
+ // Neighbors of point 0.
+ BOOST_REQUIRE(sortedOutput[0].size() == 4);
+ BOOST_REQUIRE(sortedOutput[0][0].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (1.20 * 1.20), 1e-5);
+ BOOST_REQUIRE(sortedOutput[0][1].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (1.35 * 1.35), 1e-5);
+ BOOST_REQUIRE(sortedOutput[0][2].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, (2.05 * 2.05), 1e-5);
+ BOOST_REQUIRE(sortedOutput[0][3].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, (5.00 * 5.00), 1e-5);
+
+ // Neighbors of point 1.
+ BOOST_REQUIRE(sortedOutput[1].size() == 3);
+ BOOST_REQUIRE(sortedOutput[1][0].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (1.65 * 1.65), 1e-5);
+ BOOST_REQUIRE(sortedOutput[1][1].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, (2.35 * 2.35), 1e-5);
+ BOOST_REQUIRE(sortedOutput[1][2].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, (4.70 * 4.70), 1e-5);
+
+ // Neighbors of point 2.
+ BOOST_REQUIRE(sortedOutput[2].size() == 4);
+ BOOST_REQUIRE(sortedOutput[2][0].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (1.10 * 1.10), 1e-5);
+ BOOST_REQUIRE(sortedOutput[2][1].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (1.45 * 1.45), 1e-5);
+ BOOST_REQUIRE(sortedOutput[2][2].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, (2.15 * 2.15), 1e-5);
+ BOOST_REQUIRE(sortedOutput[2][3].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, (4.90 * 4.90), 1e-5);
+
+ // Neighbors of point 3.
+ BOOST_REQUIRE(sortedOutput[3].size() == 6);
+ BOOST_REQUIRE(sortedOutput[3][0].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (1.10 * 1.10), 1e-5);
+ BOOST_REQUIRE(sortedOutput[3][1].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (1.20 * 1.20), 1e-5);
+ BOOST_REQUIRE(sortedOutput[3][2].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][2].first, (1.47 * 1.47), 1e-5);
+ BOOST_REQUIRE(sortedOutput[3][3].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][3].first, (2.55 * 2.55), 1e-5);
+ BOOST_REQUIRE(sortedOutput[3][4].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][4].first, (3.25 * 3.25), 1e-5);
+ BOOST_REQUIRE(sortedOutput[3][5].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][5].first, (3.80 * 3.80), 1e-5);
+
+ // Neighbors of point 4.
+ BOOST_REQUIRE(sortedOutput[4].size() == 10);
+ BOOST_REQUIRE(sortedOutput[4][0].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][0].first, (3.80 * 3.80), 1e-5);
+ BOOST_REQUIRE(sortedOutput[4][1].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][1].first, (4.05 * 4.05), 1e-5);
+ BOOST_REQUIRE(sortedOutput[4][2].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][2].first, (4.15 * 4.15), 1e-5);
+ BOOST_REQUIRE(sortedOutput[4][3].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][3].first, (4.60 * 4.60), 1e-5);
+ BOOST_REQUIRE(sortedOutput[4][4].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][4].first, (4.70 * 4.70), 1e-5);
+ BOOST_REQUIRE(sortedOutput[4][5].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][5].first, (4.90 * 4.90), 1e-5);
+ BOOST_REQUIRE(sortedOutput[4][6].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][6].first, (5.00 * 5.00), 1e-5);
+ BOOST_REQUIRE(sortedOutput[4][7].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][7].first, (5.27 * 5.27), 1e-5);
+ BOOST_REQUIRE(sortedOutput[4][8].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][8].first, (6.35 * 6.35), 1e-5);
+ BOOST_REQUIRE(sortedOutput[4][9].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][9].first, (7.05 * 7.05), 1e-5);
+
+ // Neighbors of point 5.
+ BOOST_REQUIRE(sortedOutput[5].size() == 6);
+ BOOST_REQUIRE(sortedOutput[5][0].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, (1.08 * 1.08), 1e-5);
+ BOOST_REQUIRE(sortedOutput[5][1].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, (1.12 * 1.12), 1e-5);
+ BOOST_REQUIRE(sortedOutput[5][2].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, (1.22 * 1.22), 1e-5);
+ BOOST_REQUIRE(sortedOutput[5][3].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, (1.47 * 1.47), 1e-5);
+ BOOST_REQUIRE(sortedOutput[5][4].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][4].first, (1.78 * 1.78), 1e-5);
+ BOOST_REQUIRE(sortedOutput[5][5].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][5].first, (5.27 * 5.27), 1e-5);
+
+ // Neighbors of point 6.
+ BOOST_REQUIRE(sortedOutput[6].size() == 9);
+ BOOST_REQUIRE(sortedOutput[6][0].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, (1.78 * 1.78), 1e-5);
+ BOOST_REQUIRE(sortedOutput[6][1].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][1].first, (2.05 * 2.05), 1e-5);
+ BOOST_REQUIRE(sortedOutput[6][2].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][2].first, (2.15 * 2.15), 1e-5);
+ BOOST_REQUIRE(sortedOutput[6][3].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][3].first, (2.35 * 2.35), 1e-5);
+ BOOST_REQUIRE(sortedOutput[6][4].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][4].first, (2.45 * 2.45), 1e-5);
+ BOOST_REQUIRE(sortedOutput[6][5].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][5].first, (2.90 * 2.90), 1e-5);
+ BOOST_REQUIRE(sortedOutput[6][6].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][6].first, (3.00 * 3.00), 1e-5);
+ BOOST_REQUIRE(sortedOutput[6][7].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][7].first, (3.25 * 3.25), 1e-5);
+ BOOST_REQUIRE(sortedOutput[6][8].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][8].first, (7.05 * 7.05), 1e-5);
+
+ // Neighbors of point 7.
+ BOOST_REQUIRE(sortedOutput[7].size() == 9);
+ BOOST_REQUIRE(sortedOutput[7][0].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, (1.08 * 1.08), 1e-5);
+ BOOST_REQUIRE(sortedOutput[7][1].second == 0);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][1].first, (1.35 * 1.35), 1e-5);
+ BOOST_REQUIRE(sortedOutput[7][2].second == 2);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][2].first, (1.45 * 1.45), 1e-5);
+ BOOST_REQUIRE(sortedOutput[7][3].second == 1);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][3].first, (1.65 * 1.65), 1e-5);
+ BOOST_REQUIRE(sortedOutput[7][4].second == 8);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][4].first, (1.75 * 1.75), 1e-5);
+ BOOST_REQUIRE(sortedOutput[7][5].second == 9);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][5].first, (2.20 * 2.20), 1e-5);
+ BOOST_REQUIRE(sortedOutput[7][6].second == 10);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][6].first, (2.30 * 2.30), 1e-5);
+ BOOST_REQUIRE(sortedOutput[7][7].second == 3);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][7].first, (2.55 * 2.55), 1e-5);
+ BOOST_REQUIRE(sortedOutput[7][8].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][8].first, (6.35 * 6.35), 1e-5);
+
+ // Neighbors of point 8.
+ BOOST_REQUIRE(sortedOutput[8].size() == 3);
+ BOOST_REQUIRE(sortedOutput[8][0].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (1.75 * 1.75), 1e-5);
+ BOOST_REQUIRE(sortedOutput[8][1].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, (2.45 * 2.45), 1e-5);
+ BOOST_REQUIRE(sortedOutput[8][2].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, (4.60 * 4.60), 1e-5);
+
+ // Neighbors of point 9.
+ BOOST_REQUIRE(sortedOutput[9].size() == 4);
+ BOOST_REQUIRE(sortedOutput[9][0].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (1.12 * 1.12), 1e-5);
+ BOOST_REQUIRE(sortedOutput[9][1].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (2.20 * 2.20), 1e-5);
+ BOOST_REQUIRE(sortedOutput[9][2].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, (2.90 * 2.90), 1e-5);
+ BOOST_REQUIRE(sortedOutput[9][3].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, (4.15 * 4.15), 1e-5);
+
+ // Neighbors of point 10.
+ BOOST_REQUIRE(sortedOutput[10].size() == 4);
+ BOOST_REQUIRE(sortedOutput[10][0].second == 5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (1.22 * 1.22), 1e-5);
+ BOOST_REQUIRE(sortedOutput[10][1].second == 7);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (2.30 * 2.30), 1e-5);
+ BOOST_REQUIRE(sortedOutput[10][2].second == 6);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, (3.00 * 3.00), 1e-5);
+ BOOST_REQUIRE(sortedOutput[10][3].second == 4);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, (4.05 * 4.05), 1e-5);
+
+ // Clean the memory.
+ delete rs;
+ }
+}
+
+/**
+ * Test the dual-tree nearest-neighbors method with the naive method. This
+ * uses both a query and reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with.
+ arma::mat dualQuery(dataForTree);
+ arma::mat dualReferences(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+ arma::mat naiveReferences(dataForTree);
+
+ RangeSearch<> rs(dualQuery, dualReferences);
+
+ RangeSearch<> naive(naiveQuery, naiveReferences, true);
+
+ vector<vector<size_t> > neighborsTree;
+ vector<vector<double> > distancesTree;
+ rs.Search(Range(0.25, 1.05), neighborsTree, distancesTree);
+ vector<vector<pair<double, size_t> > > sortedTree;
+ SortResults(neighborsTree, distancesTree, sortedTree);
+
+ vector<vector<size_t> > neighborsNaive;
+ vector<vector<double> > distancesNaive;
+ naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
+ vector<vector<pair<double, size_t> > > sortedNaive;
+ SortResults(neighborsNaive, distancesNaive, sortedNaive);
+
+ for (size_t i = 0; i < sortedTree.size(); i++)
+ {
+ BOOST_REQUIRE(sortedTree[i].size() == sortedNaive[i].size());
+
+ for (size_t j = 0; j < sortedTree[i].size(); j++)
+ {
+ BOOST_REQUIRE(sortedTree[i][j].second == sortedNaive[i][j].second);
+ BOOST_REQUIRE_CLOSE(sortedTree[i][j].first, sortedNaive[i][j].first,
+ 1e-5);
+ }
+ }
+}
+
+/**
+ * Test the dual-tree nearest-neighbors method with the naive method. This uses
+ * only a reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ // Code duplication: also bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
+ arma::mat dualQuery(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+
+ RangeSearch<> rs(dualQuery);
+
+ // Set naive mode.
+ RangeSearch<> naive(naiveQuery, true);
+
+ vector<vector<size_t> > neighborsTree;
+ vector<vector<double> > distancesTree;
+ rs.Search(Range(0.25, 1.05), neighborsTree, distancesTree);
+ vector<vector<pair<double, size_t> > > sortedTree;
+ SortResults(neighborsTree, distancesTree, sortedTree);
+
+ vector<vector<size_t> > neighborsNaive;
+ vector<vector<double> > distancesNaive;
+ naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
+ vector<vector<pair<double, size_t> > > sortedNaive;
+ SortResults(neighborsNaive, distancesNaive, sortedNaive);
+
+ for (size_t i = 0; i < sortedTree.size(); i++)
+ {
+ BOOST_REQUIRE(sortedTree[i].size() == sortedNaive[i].size());
+
+ for (size_t j = 0; j < sortedTree[i].size(); j++)
+ {
+ BOOST_REQUIRE(sortedTree[i][j].second == sortedNaive[i][j].second);
+ BOOST_REQUIRE_CLOSE(sortedTree[i][j].first, sortedNaive[i][j].first,
+ 1e-5);
+ }
+ }
+}
+
+/**
+ * Test the single-tree nearest-neighbors method with the naive method. This
+ * uses only a reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
+{
+ arma::mat dataForTree;
+
+ // Hard-coded filename: bad!
+ // Code duplication: also bad!
+ if (!data::Load("test_data_3_1000.csv", dataForTree))
+ BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+ // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
+ arma::mat singleQuery(dataForTree);
+ arma::mat naiveQuery(dataForTree);
+
+ RangeSearch<> single(singleQuery, false, true);
+
+ // Set up computation for naive mode.
+ RangeSearch<> naive(naiveQuery, true);
+
+ vector<vector<size_t> > neighborsSingle;
+ vector<vector<double> > distancesSingle;
+ single.Search(Range(0.25, 1.05), neighborsSingle, distancesSingle);
+ vector<vector<pair<double, size_t> > > sortedTree;
+ SortResults(neighborsSingle, distancesSingle, sortedTree);
+
+ vector<vector<size_t> > neighborsNaive;
+ vector<vector<double> > distancesNaive;
+ naive.Search(Range(0.25, 1.05), neighborsNaive, distancesNaive);
+ vector<vector<pair<double, size_t> > > sortedNaive;
+ SortResults(neighborsNaive, distancesNaive, sortedNaive);
+
+ for (size_t i = 0; i < sortedTree.size(); i++)
+ {
+ BOOST_REQUIRE(sortedTree[i].size() == sortedNaive[i].size());
+
+ for (size_t j = 0; j < sortedTree[i].size(); j++)
+ {
+ BOOST_REQUIRE(sortedTree[i][j].second == sortedNaive[i][j].second);
+ BOOST_REQUIRE_CLOSE(sortedTree[i][j].first, sortedNaive[i][j].first,
+ 1e-5);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/save_restore_utility_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/save_restore_utility_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/save_restore_utility_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,189 +0,0 @@
-/**
- * @file save_restore_model_test.cpp
- * @author Neil Slagle
- *
- * Here we have tests for the SaveRestoreModel class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core/util/save_restore_utility.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-#define ARGSTR(a) a,#a
-
-using namespace mlpack;
-using namespace mlpack::util;
-
-BOOST_AUTO_TEST_SUITE(SaveRestoreUtilityTests);
-
-/*
- * Exhibit proper save restore utility usage of child class proper usage.
- */
-class SaveRestoreTest
-{
- private:
- size_t anInt;
- SaveRestoreUtility saveRestore;
-
- public:
- SaveRestoreTest()
- {
- saveRestore = SaveRestoreUtility();
- anInt = 0;
- }
-
- bool SaveModel(std::string filename)
- {
- saveRestore.SaveParameter(anInt, "anInt");
- return saveRestore.WriteFile(filename);
- }
-
- bool LoadModel(std::string filename)
- {
- bool success = saveRestore.ReadFile(filename);
- if (success)
- anInt = saveRestore.LoadParameter(anInt, "anInt");
-
- return success;
- }
-
- size_t AnInt() { return anInt; }
- void AnInt(size_t s) { this->anInt = s; }
-};
-
-/**
- * Perform a save and restore on basic types.
- */
-BOOST_AUTO_TEST_CASE(SaveBasicTypes)
-{
- bool b = false;
- char c = 67;
- unsigned u = 34;
- size_t s = 12;
- short sh = 100;
- int i = -23;
- float f = -2.34f;
- double d = 3.14159;
- std::string cc = "Hello world!";
-
- SaveRestoreUtility* sRM = new SaveRestoreUtility();
-
- sRM->SaveParameter(ARGSTR(b));
- sRM->SaveParameter(ARGSTR(c));
- sRM->SaveParameter(ARGSTR(u));
- sRM->SaveParameter(ARGSTR(s));
- sRM->SaveParameter(ARGSTR(sh));
- sRM->SaveParameter(ARGSTR(i));
- sRM->SaveParameter(ARGSTR(f));
- sRM->SaveParameter(ARGSTR(d));
- sRM->SaveParameter(ARGSTR(cc));
- sRM->WriteFile("test_basic_types.xml");
-
- sRM->ReadFile("test_basic_types.xml");
-
- bool b2 = sRM->LoadParameter(ARGSTR(b));
- char c2 = sRM->LoadParameter(ARGSTR(c));
- unsigned u2 = sRM->LoadParameter(ARGSTR(u));
- size_t s2 = sRM->LoadParameter(ARGSTR(s));
- short sh2 = sRM->LoadParameter(ARGSTR(sh));
- int i2 = sRM->LoadParameter(ARGSTR(i));
- float f2 = sRM->LoadParameter(ARGSTR(f));
- double d2 = sRM->LoadParameter(ARGSTR(d));
- std::string cc2 = sRM->LoadParameter(ARGSTR(cc));
-
- BOOST_REQUIRE(b == b2);
- BOOST_REQUIRE(c == c2);
- BOOST_REQUIRE(u == u2);
- BOOST_REQUIRE(s == s2);
- BOOST_REQUIRE(sh == sh2);
- BOOST_REQUIRE(i == i2);
- BOOST_REQUIRE(cc == cc2);
- BOOST_REQUIRE_CLOSE(f, f2, 1e-5);
- BOOST_REQUIRE_CLOSE(d, d2, 1e-5);
-
- delete sRM;
-}
-
-BOOST_AUTO_TEST_CASE(SaveRestoreStdVector)
-{
- size_t numbers[] = {0,3,6,2,6};
- std::vector<size_t> vec (numbers,
- numbers + sizeof (numbers) / sizeof (size_t));
- SaveRestoreUtility* sRM = new SaveRestoreUtility();
-
- sRM->SaveParameter(ARGSTR(vec));
-
- sRM->WriteFile("test_std_vector_type.xml");
-
- sRM->ReadFile("test_std_vector_type.xml");
-
- std::vector<size_t> loadee = sRM->LoadParameter(ARGSTR(vec));
-
- for (size_t index = 0; index < loadee.size(); ++index)
- BOOST_REQUIRE_EQUAL(numbers[index], loadee[index]);
-}
-
-/**
- * Test the arma::mat functionality.
- */
-BOOST_AUTO_TEST_CASE(SaveArmaMat)
-{
- arma::mat matrix;
- matrix << 1.2 << 2.3 << -0.1 << arma::endr
- << 3.5 << 2.4 << -1.2 << arma::endr
- << -0.1 << 3.4 << -7.8 << arma::endr;
-
- SaveRestoreUtility* sRM = new SaveRestoreUtility();
-
- sRM->SaveParameter(ARGSTR(matrix));
-
- sRM->WriteFile("test_arma_mat_type.xml");
-
- sRM->ReadFile("test_arma_mat_type.xml");
-
- arma::mat matrix2 = sRM->LoadParameter(ARGSTR(matrix));
-
- for (size_t row = 0; row < matrix.n_rows; ++row)
- for (size_t column = 0; column < matrix.n_cols; ++column)
- BOOST_REQUIRE_CLOSE(matrix(row,column), matrix2(row,column), 1e-5);
-
- delete sRM;
-}
-
-/**
- * Test SaveRestoreModel proper usage in child classes and loading from
- * separately defined objects
- */
-BOOST_AUTO_TEST_CASE(SaveRestoreModelChildClassUsage)
-{
- SaveRestoreTest* saver = new SaveRestoreTest();
- SaveRestoreTest* loader = new SaveRestoreTest();
- size_t s = 1200;
- const char* filename = "anInt.xml";
-
- saver->AnInt(s);
- saver->SaveModel(filename);
- delete saver;
-
- loader->LoadModel(filename);
-
- BOOST_REQUIRE(loader->AnInt() == s);
-
- delete loader;
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/save_restore_utility_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/save_restore_utility_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/save_restore_utility_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/save_restore_utility_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,189 @@
+/**
+ * @file save_restore_model_test.cpp
+ * @author Neil Slagle
+ *
+ * Here we have tests for the SaveRestoreModel class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core/util/save_restore_utility.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+#define ARGSTR(a) a,#a
+
+using namespace mlpack;
+using namespace mlpack::util;
+
+BOOST_AUTO_TEST_SUITE(SaveRestoreUtilityTests);
+
+/*
+ * Exhibit proper save restore utility usage of child class proper usage.
+ */
+class SaveRestoreTest
+{
+ private:
+ size_t anInt;
+ SaveRestoreUtility saveRestore;
+
+ public:
+ SaveRestoreTest()
+ {
+ saveRestore = SaveRestoreUtility();
+ anInt = 0;
+ }
+
+ bool SaveModel(std::string filename)
+ {
+ saveRestore.SaveParameter(anInt, "anInt");
+ return saveRestore.WriteFile(filename);
+ }
+
+ bool LoadModel(std::string filename)
+ {
+ bool success = saveRestore.ReadFile(filename);
+ if (success)
+ anInt = saveRestore.LoadParameter(anInt, "anInt");
+
+ return success;
+ }
+
+ size_t AnInt() { return anInt; }
+ void AnInt(size_t s) { this->anInt = s; }
+};
+
+/**
+ * Perform a save and restore on basic types.
+ */
+BOOST_AUTO_TEST_CASE(SaveBasicTypes)
+{
+ bool b = false;
+ char c = 67;
+ unsigned u = 34;
+ size_t s = 12;
+ short sh = 100;
+ int i = -23;
+ float f = -2.34f;
+ double d = 3.14159;
+ std::string cc = "Hello world!";
+
+ SaveRestoreUtility* sRM = new SaveRestoreUtility();
+
+ sRM->SaveParameter(ARGSTR(b));
+ sRM->SaveParameter(ARGSTR(c));
+ sRM->SaveParameter(ARGSTR(u));
+ sRM->SaveParameter(ARGSTR(s));
+ sRM->SaveParameter(ARGSTR(sh));
+ sRM->SaveParameter(ARGSTR(i));
+ sRM->SaveParameter(ARGSTR(f));
+ sRM->SaveParameter(ARGSTR(d));
+ sRM->SaveParameter(ARGSTR(cc));
+ sRM->WriteFile("test_basic_types.xml");
+
+ sRM->ReadFile("test_basic_types.xml");
+
+ bool b2 = sRM->LoadParameter(ARGSTR(b));
+ char c2 = sRM->LoadParameter(ARGSTR(c));
+ unsigned u2 = sRM->LoadParameter(ARGSTR(u));
+ size_t s2 = sRM->LoadParameter(ARGSTR(s));
+ short sh2 = sRM->LoadParameter(ARGSTR(sh));
+ int i2 = sRM->LoadParameter(ARGSTR(i));
+ float f2 = sRM->LoadParameter(ARGSTR(f));
+ double d2 = sRM->LoadParameter(ARGSTR(d));
+ std::string cc2 = sRM->LoadParameter(ARGSTR(cc));
+
+ BOOST_REQUIRE(b == b2);
+ BOOST_REQUIRE(c == c2);
+ BOOST_REQUIRE(u == u2);
+ BOOST_REQUIRE(s == s2);
+ BOOST_REQUIRE(sh == sh2);
+ BOOST_REQUIRE(i == i2);
+ BOOST_REQUIRE(cc == cc2);
+ BOOST_REQUIRE_CLOSE(f, f2, 1e-5);
+ BOOST_REQUIRE_CLOSE(d, d2, 1e-5);
+
+ delete sRM;
+}
+
+BOOST_AUTO_TEST_CASE(SaveRestoreStdVector)
+{
+ size_t numbers[] = {0,3,6,2,6};
+ std::vector<size_t> vec (numbers,
+ numbers + sizeof (numbers) / sizeof (size_t));
+ SaveRestoreUtility* sRM = new SaveRestoreUtility();
+
+ sRM->SaveParameter(ARGSTR(vec));
+
+ sRM->WriteFile("test_std_vector_type.xml");
+
+ sRM->ReadFile("test_std_vector_type.xml");
+
+ std::vector<size_t> loadee = sRM->LoadParameter(ARGSTR(vec));
+
+ for (size_t index = 0; index < loadee.size(); ++index)
+ BOOST_REQUIRE_EQUAL(numbers[index], loadee[index]);
+}
+
+/**
+ * Test the arma::mat functionality.
+ */
+BOOST_AUTO_TEST_CASE(SaveArmaMat)
+{
+ arma::mat matrix;
+ matrix << 1.2 << 2.3 << -0.1 << arma::endr
+ << 3.5 << 2.4 << -1.2 << arma::endr
+ << -0.1 << 3.4 << -7.8 << arma::endr;
+
+ SaveRestoreUtility* sRM = new SaveRestoreUtility();
+
+ sRM->SaveParameter(ARGSTR(matrix));
+
+ sRM->WriteFile("test_arma_mat_type.xml");
+
+ sRM->ReadFile("test_arma_mat_type.xml");
+
+ arma::mat matrix2 = sRM->LoadParameter(ARGSTR(matrix));
+
+ for (size_t row = 0; row < matrix.n_rows; ++row)
+ for (size_t column = 0; column < matrix.n_cols; ++column)
+ BOOST_REQUIRE_CLOSE(matrix(row,column), matrix2(row,column), 1e-5);
+
+ delete sRM;
+}
+
+/**
+ * Test SaveRestoreModel proper usage in child classes and loading from
+ * separately defined objects
+ */
+BOOST_AUTO_TEST_CASE(SaveRestoreModelChildClassUsage)
+{
+ SaveRestoreTest* saver = new SaveRestoreTest();
+ SaveRestoreTest* loader = new SaveRestoreTest();
+ size_t s = 1200;
+ const char* filename = "anInt.xml";
+
+ saver->AnInt(s);
+ saver->SaveModel(filename);
+ delete saver;
+
+ loader->LoadModel(filename);
+
+ BOOST_REQUIRE(loader->AnInt() == s);
+
+ delete loader;
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sgd_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/sgd_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sgd_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,71 +0,0 @@
-/**
- * @file sgd_test.cpp
- * @author Ryan Curtin
- *
- * Test file for SGD (stochastic gradient descent).
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/optimizers/sgd/sgd.hpp>
-#include <mlpack/core/optimizers/lbfgs/test_functions.hpp>
-#include <mlpack/core/optimizers/sgd/test_function.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace std;
-using namespace arma;
-using namespace mlpack;
-using namespace mlpack::optimization;
-using namespace mlpack::optimization::test;
-
-BOOST_AUTO_TEST_SUITE(SGDTest);
-
-BOOST_AUTO_TEST_CASE(SimpleSGDTestFunction)
-{
- SGDTestFunction f;
- SGD<SGDTestFunction> s(f, 0.0003, 5000000, 1e-9, true);
-
- arma::mat coordinates = f.GetInitialPoint();
- double result = s.Optimize(coordinates);
-
- BOOST_REQUIRE_CLOSE(result, -1.0, 0.05);
- BOOST_REQUIRE_SMALL(coordinates[0], 1e-3);
- BOOST_REQUIRE_SMALL(coordinates[1], 1e-7);
- BOOST_REQUIRE_SMALL(coordinates[2], 1e-7);
-}
-
-BOOST_AUTO_TEST_CASE(GeneralizedRosenbrockTest)
-{
- // Loop over several variants.
- for (size_t i = 10; i < 50; i += 5)
- {
- // Create the generalized Rosenbrock function.
- GeneralizedRosenbrockFunction f(i);
-
- SGD<GeneralizedRosenbrockFunction> s(f, 0.001, 0, 1e-15, true);
-
- arma::mat coordinates = f.GetInitialPoint();
- double result = s.Optimize(coordinates);
-
- BOOST_REQUIRE_SMALL(result, 1e-10);
- for (size_t j = 0; j < i; ++j)
- BOOST_REQUIRE_CLOSE(coordinates[j], (double) 1.0, 1e-3);
- }
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sgd_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/sgd_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sgd_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sgd_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,71 @@
+/**
+ * @file sgd_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test file for SGD (stochastic gradient descent).
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/sgd/sgd.hpp>
+#include <mlpack/core/optimizers/lbfgs/test_functions.hpp>
+#include <mlpack/core/optimizers/sgd/test_function.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace std;
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::optimization;
+using namespace mlpack::optimization::test;
+
+BOOST_AUTO_TEST_SUITE(SGDTest);
+
+BOOST_AUTO_TEST_CASE(SimpleSGDTestFunction)
+{
+ SGDTestFunction f;
+ SGD<SGDTestFunction> s(f, 0.0003, 5000000, 1e-9, true);
+
+ arma::mat coordinates = f.GetInitialPoint();
+ double result = s.Optimize(coordinates);
+
+ BOOST_REQUIRE_CLOSE(result, -1.0, 0.05);
+ BOOST_REQUIRE_SMALL(coordinates[0], 1e-3);
+ BOOST_REQUIRE_SMALL(coordinates[1], 1e-7);
+ BOOST_REQUIRE_SMALL(coordinates[2], 1e-7);
+}
+
+BOOST_AUTO_TEST_CASE(GeneralizedRosenbrockTest)
+{
+ // Loop over several variants.
+ for (size_t i = 10; i < 50; i += 5)
+ {
+ // Create the generalized Rosenbrock function.
+ GeneralizedRosenbrockFunction f(i);
+
+ SGD<GeneralizedRosenbrockFunction> s(f, 0.001, 0, 1e-15, true);
+
+ arma::mat coordinates = f.GetInitialPoint();
+ double result = s.Optimize(coordinates);
+
+ BOOST_REQUIRE_SMALL(result, 1e-10);
+ for (size_t j = 0; j < i; ++j)
+ BOOST_REQUIRE_CLOSE(coordinates[j], (double) 1.0, 1e-3);
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sort_policy_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/sort_policy_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sort_policy_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,351 +0,0 @@
-/**
- * @file sort_policy_test.cpp
- * @author Ryan Curtin
- *
- * Tests for each of the implementations of the SortPolicy class.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
-
-// Classes to test.
-#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
-#include <mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::neighbor;
-using namespace mlpack::bound;
-
-BOOST_AUTO_TEST_SUITE(SortPolicyTest);
-
-// Tests for NearestNeighborSort
-
-/**
- * Ensure the best distance for nearest neighbors is 0.
- */
-BOOST_AUTO_TEST_CASE(NnsBestDistance)
-{
- BOOST_REQUIRE(NearestNeighborSort::BestDistance() == 0);
-}
-
-/**
- * Ensure the worst distance for nearest neighbors is DBL_MAX.
- */
-BOOST_AUTO_TEST_CASE(NnsWorstDistance)
-{
- BOOST_REQUIRE(NearestNeighborSort::WorstDistance() == DBL_MAX);
-}
-
-/**
- * Make sure the comparison works for values strictly less than the reference.
- */
-BOOST_AUTO_TEST_CASE(NnsIsBetterStrict)
-{
- BOOST_REQUIRE(NearestNeighborSort::IsBetter(5.0, 6.0) == true);
-}
-
-/**
- * Warn in case the comparison is not strict.
- */
-BOOST_AUTO_TEST_CASE(NnsIsBetterNotStrict)
-{
- BOOST_WARN(NearestNeighborSort::IsBetter(6.0, 6.0) == true);
-}
-
-/**
- * A simple test case of where to insert when all the values in the list are
- * DBL_MAX.
- */
-BOOST_AUTO_TEST_CASE(NnsSortDistanceAllDblMax)
-{
- arma::vec list(5);
- list.fill(DBL_MAX);
-
- // Should be inserted at the head of the list.
- BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 5.0) == 0);
-}
-
-/**
- * Another test case, where we are just putting the new value in the middle of
- * the list.
- */
-BOOST_AUTO_TEST_CASE(NnsSortDistance2)
-{
- arma::vec list(3);
- list[0] = 0.66;
- list[1] = 0.89;
- list[2] = 1.14;
-
- // Run a couple possibilities through.
- BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 0.61) == 0);
- BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 0.76) == 1);
- BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 0.99) == 2);
- BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 1.22) ==
- (size_t() - 1));
-}
-
-/**
- * Very simple sanity check to ensure that bounds are working alright. We will
- * use a one-dimensional bound for simplicity.
- */
-BOOST_AUTO_TEST_CASE(NnsNodeToNodeDistance)
-{
- // Well, there's no easy way to make HRectBounds the way we want, so we have
- // to make them and then expand the region to include new points.
- arma::mat dataset("1");
- tree::BinarySpaceTree<HRectBound<2>, tree::EmptyStatistic, arma::mat>
- nodeOne(dataset);
- arma::vec utility(1);
- utility[0] = 0;
-
- nodeOne.Bound() = HRectBound<2>(1);
- nodeOne.Bound() |= utility;
- utility[0] = 1;
- nodeOne.Bound() |= utility;
-
- tree::BinarySpaceTree<HRectBound<2>, tree::EmptyStatistic, arma::mat>
- nodeTwo(dataset);
- nodeTwo.Bound() = HRectBound<2>(1);
-
- utility[0] = 5;
- nodeTwo.Bound() |= utility;
- utility[0] = 6;
- nodeTwo.Bound() |= utility;
-
- // This should use the L2 distance.
- BOOST_REQUIRE_CLOSE(NearestNeighborSort::BestNodeToNodeDistance(&nodeOne,
- &nodeTwo), 4.0, 1e-5);
-
- // And another just to be sure, from the other side.
- nodeTwo.Bound().Clear();
- utility[0] = -2;
- nodeTwo.Bound() |= utility;
- utility[0] = -1;
- nodeTwo.Bound() |= utility;
-
- // Again, the distance is the L2 distance.
- BOOST_REQUIRE_CLOSE(NearestNeighborSort::BestNodeToNodeDistance(&nodeOne,
- &nodeTwo), 1.0, 1e-5);
-
- // Now, when the bounds overlap.
- nodeTwo.Bound().Clear();
- utility[0] = -0.5;
- nodeTwo.Bound() |= utility;
- utility[0] = 0.5;
- nodeTwo.Bound() |= utility;
-
- BOOST_REQUIRE_SMALL(NearestNeighborSort::BestNodeToNodeDistance(&nodeOne,
- &nodeTwo), 1e-5);
-}
-
-/**
- * Another very simple sanity check for the point-to-node case, again in one
- * dimension.
- */
-BOOST_AUTO_TEST_CASE(NnsPointToNodeDistance)
-{
- // Well, there's no easy way to make HRectBounds the way we want, so we have
- // to make them and then expand the region to include new points.
- arma::vec utility(1);
- utility[0] = 0;
-
- arma::mat dataset("1");
- tree::BinarySpaceTree<HRectBound<2> > node(dataset);
- node.Bound() = HRectBound<2>(1);
- node.Bound() |= utility;
- utility[0] = 1;
- node.Bound() |= utility;
-
- arma::vec point(1);
- point[0] = -0.5;
-
- // The distance is the L2 distance.
- BOOST_REQUIRE_CLOSE(NearestNeighborSort::BestPointToNodeDistance(point,
- &node), 0.5, 1e-5);
-
- // Now from the other side of the bound.
- point[0] = 1.5;
-
- BOOST_REQUIRE_CLOSE(NearestNeighborSort::BestPointToNodeDistance(point,
- &node), 0.5, 1e-5);
-
- // And now when the point is inside the bound.
- point[0] = 0.5;
-
- BOOST_REQUIRE_SMALL(NearestNeighborSort::BestPointToNodeDistance(point,
- &node), 1e-5);
-}
-
-// Tests for FurthestNeighborSort
-
-/**
- * Ensure the best distance for furthest neighbors is DBL_MAX.
- */
-BOOST_AUTO_TEST_CASE(FnsBestDistance)
-{
- BOOST_REQUIRE(FurthestNeighborSort::BestDistance() == DBL_MAX);
-}
-
-/**
- * Ensure the worst distance for furthest neighbors is 0.
- */
-BOOST_AUTO_TEST_CASE(FnsWorstDistance)
-{
- BOOST_REQUIRE(FurthestNeighborSort::WorstDistance() == 0);
-}
-
-/**
- * Make sure the comparison works for values strictly less than the reference.
- */
-BOOST_AUTO_TEST_CASE(FnsIsBetterStrict)
-{
- BOOST_REQUIRE(FurthestNeighborSort::IsBetter(5.0, 4.0) == true);
-}
-
-/**
- * Warn in case the comparison is not strict.
- */
-BOOST_AUTO_TEST_CASE(FnsIsBetterNotStrict)
-{
- BOOST_WARN(FurthestNeighborSort::IsBetter(6.0, 6.0) == true);
-}
-
-/**
- * A simple test case of where to insert when all the values in the list are
- * 0.
- */
-BOOST_AUTO_TEST_CASE(FnsSortDistanceAllZero)
-{
- arma::vec list(5);
- list.fill(0);
-
- // Should be inserted at the head of the list.
- BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 5.0) == 0);
-}
-
-/**
- * Another test case, where we are just putting the new value in the middle of
- * the list.
- */
-BOOST_AUTO_TEST_CASE(FnsSortDistance2)
-{
- arma::vec list(3);
- list[0] = 1.14;
- list[1] = 0.89;
- list[2] = 0.66;
-
- // Run a couple possibilities through.
- BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 1.22) == 0);
- BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 0.93) == 1);
- BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 0.68) == 2);
- BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 0.62) ==
- (size_t() - 1));
-}
-
-/**
- * Very simple sanity check to ensure that bounds are working alright. We will
- * use a one-dimensional bound for simplicity.
- */
-BOOST_AUTO_TEST_CASE(FnsNodeToNodeDistance)
-{
- // Well, there's no easy way to make HRectBounds the way we want, so we have
- // to make them and then expand the region to include new points.
- arma::vec utility(1);
- utility[0] = 0;
-
- arma::mat dataset("1");
- tree::BinarySpaceTree<HRectBound<2> > nodeOne(dataset);
- nodeOne.Bound() = HRectBound<2>(1);
- nodeOne.Bound() |= utility;
- utility[0] = 1;
- nodeOne.Bound() |= utility;
-
- tree::BinarySpaceTree<HRectBound<2> > nodeTwo(dataset);
- nodeTwo.Bound() = HRectBound<2>(1);
- utility[0] = 5;
- nodeTwo.Bound() |= utility;
- utility[0] = 6;
- nodeTwo.Bound() |= utility;
-
- // This should use the L2 distance.
- BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestNodeToNodeDistance(&nodeOne,
- &nodeTwo), 6.0, 1e-5);
-
- // And another just to be sure, from the other side.
- nodeTwo.Bound().Clear();
- utility[0] = -2;
- nodeTwo.Bound() |= utility;
- utility[0] = -1;
- nodeTwo.Bound() |= utility;
-
- // Again, the distance is the L2 distance.
- BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestNodeToNodeDistance(&nodeOne,
- &nodeTwo), 3.0, 1e-5);
-
- // Now, when the bounds overlap.
- nodeTwo.Bound().Clear();
- utility[0] = -0.5;
- nodeTwo.Bound() |= utility;
- utility[0] = 0.5;
- nodeTwo.Bound() |= utility;
-
- BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestNodeToNodeDistance(&nodeOne,
- &nodeTwo), 1.5, 1e-5);
-}
-
-/**
- * Another very simple sanity check for the point-to-node case, again in one
- * dimension.
- */
-BOOST_AUTO_TEST_CASE(FnsPointToNodeDistance)
-{
- // Well, there's no easy way to make HRectBounds the way we want, so we have
- // to make them and then expand the region to include new points.
- arma::vec utility(1);
- utility[0] = 0;
-
- arma::mat dataset("1");
- tree::BinarySpaceTree<HRectBound<2> > node(dataset);
- node.Bound() = HRectBound<2>(1);
- node.Bound() |= utility;
- utility[0] = 1;
- node.Bound() |= utility;
-
- arma::vec point(1);
- point[0] = -0.5;
-
- // The distance is the L2 distance.
- BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestPointToNodeDistance(point,
- &node), 1.5, 1e-5);
-
- // Now from the other side of the bound.
- point[0] = 1.5;
-
- BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestPointToNodeDistance(point,
- &node), 1.5, 1e-5);
-
- // And now when the point is inside the bound.
- point[0] = 0.5;
-
- BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestPointToNodeDistance(point,
- &node), 0.5, 1e-5);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sort_policy_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/sort_policy_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sort_policy_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sort_policy_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,351 @@
+/**
+ * @file sort_policy_test.cpp
+ * @author Ryan Curtin
+ *
+ * Tests for each of the implementations of the SortPolicy class.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
+// Classes to test.
+#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
+#include <mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+using namespace mlpack::bound;
+
+BOOST_AUTO_TEST_SUITE(SortPolicyTest);
+
+// Tests for NearestNeighborSort
+
+/**
+ * Ensure the best distance for nearest neighbors is 0.
+ */
+BOOST_AUTO_TEST_CASE(NnsBestDistance)
+{
+ BOOST_REQUIRE(NearestNeighborSort::BestDistance() == 0);
+}
+
+/**
+ * Ensure the worst distance for nearest neighbors is DBL_MAX.
+ */
+BOOST_AUTO_TEST_CASE(NnsWorstDistance)
+{
+ BOOST_REQUIRE(NearestNeighborSort::WorstDistance() == DBL_MAX);
+}
+
+/**
+ * Make sure the comparison works for values strictly less than the reference.
+ */
+BOOST_AUTO_TEST_CASE(NnsIsBetterStrict)
+{
+ BOOST_REQUIRE(NearestNeighborSort::IsBetter(5.0, 6.0) == true);
+}
+
+/**
+ * Warn in case the comparison is not strict.
+ */
+BOOST_AUTO_TEST_CASE(NnsIsBetterNotStrict)
+{
+ BOOST_WARN(NearestNeighborSort::IsBetter(6.0, 6.0) == true);
+}
+
+/**
+ * A simple test case of where to insert when all the values in the list are
+ * DBL_MAX.
+ */
+BOOST_AUTO_TEST_CASE(NnsSortDistanceAllDblMax)
+{
+ arma::vec list(5);
+ list.fill(DBL_MAX);
+
+ // Should be inserted at the head of the list.
+ BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 5.0) == 0);
+}
+
+/**
+ * Another test case, where we are just putting the new value in the middle of
+ * the list.
+ */
+BOOST_AUTO_TEST_CASE(NnsSortDistance2)
+{
+ arma::vec list(3);
+ list[0] = 0.66;
+ list[1] = 0.89;
+ list[2] = 1.14;
+
+ // Run a couple possibilities through.
+ BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 0.61) == 0);
+ BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 0.76) == 1);
+ BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 0.99) == 2);
+ BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, 1.22) ==
+ (size_t() - 1));
+}
+
+/**
+ * Very simple sanity check to ensure that bounds are working alright. We will
+ * use a one-dimensional bound for simplicity.
+ */
+BOOST_AUTO_TEST_CASE(NnsNodeToNodeDistance)
+{
+ // Well, there's no easy way to make HRectBounds the way we want, so we have
+ // to make them and then expand the region to include new points.
+ arma::mat dataset("1");
+ tree::BinarySpaceTree<HRectBound<2>, tree::EmptyStatistic, arma::mat>
+ nodeOne(dataset);
+ arma::vec utility(1);
+ utility[0] = 0;
+
+ nodeOne.Bound() = HRectBound<2>(1);
+ nodeOne.Bound() |= utility;
+ utility[0] = 1;
+ nodeOne.Bound() |= utility;
+
+ tree::BinarySpaceTree<HRectBound<2>, tree::EmptyStatistic, arma::mat>
+ nodeTwo(dataset);
+ nodeTwo.Bound() = HRectBound<2>(1);
+
+ utility[0] = 5;
+ nodeTwo.Bound() |= utility;
+ utility[0] = 6;
+ nodeTwo.Bound() |= utility;
+
+ // This should use the L2 distance.
+ BOOST_REQUIRE_CLOSE(NearestNeighborSort::BestNodeToNodeDistance(&nodeOne,
+ &nodeTwo), 4.0, 1e-5);
+
+ // And another just to be sure, from the other side.
+ nodeTwo.Bound().Clear();
+ utility[0] = -2;
+ nodeTwo.Bound() |= utility;
+ utility[0] = -1;
+ nodeTwo.Bound() |= utility;
+
+ // Again, the distance is the L2 distance.
+ BOOST_REQUIRE_CLOSE(NearestNeighborSort::BestNodeToNodeDistance(&nodeOne,
+ &nodeTwo), 1.0, 1e-5);
+
+ // Now, when the bounds overlap.
+ nodeTwo.Bound().Clear();
+ utility[0] = -0.5;
+ nodeTwo.Bound() |= utility;
+ utility[0] = 0.5;
+ nodeTwo.Bound() |= utility;
+
+ BOOST_REQUIRE_SMALL(NearestNeighborSort::BestNodeToNodeDistance(&nodeOne,
+ &nodeTwo), 1e-5);
+}
+
+/**
+ * Another very simple sanity check for the point-to-node case, again in one
+ * dimension.
+ */
+BOOST_AUTO_TEST_CASE(NnsPointToNodeDistance)
+{
+ // Well, there's no easy way to make HRectBounds the way we want, so we have
+ // to make them and then expand the region to include new points.
+ arma::vec utility(1);
+ utility[0] = 0;
+
+ arma::mat dataset("1");
+ tree::BinarySpaceTree<HRectBound<2> > node(dataset);
+ node.Bound() = HRectBound<2>(1);
+ node.Bound() |= utility;
+ utility[0] = 1;
+ node.Bound() |= utility;
+
+ arma::vec point(1);
+ point[0] = -0.5;
+
+ // The distance is the L2 distance.
+ BOOST_REQUIRE_CLOSE(NearestNeighborSort::BestPointToNodeDistance(point,
+ &node), 0.5, 1e-5);
+
+ // Now from the other side of the bound.
+ point[0] = 1.5;
+
+ BOOST_REQUIRE_CLOSE(NearestNeighborSort::BestPointToNodeDistance(point,
+ &node), 0.5, 1e-5);
+
+ // And now when the point is inside the bound.
+ point[0] = 0.5;
+
+ BOOST_REQUIRE_SMALL(NearestNeighborSort::BestPointToNodeDistance(point,
+ &node), 1e-5);
+}
+
+// Tests for FurthestNeighborSort
+
+/**
+ * Ensure the best distance for furthest neighbors is DBL_MAX.
+ */
+BOOST_AUTO_TEST_CASE(FnsBestDistance)
+{
+ BOOST_REQUIRE(FurthestNeighborSort::BestDistance() == DBL_MAX);
+}
+
+/**
+ * Ensure the worst distance for furthest neighbors is 0.
+ */
+BOOST_AUTO_TEST_CASE(FnsWorstDistance)
+{
+ BOOST_REQUIRE(FurthestNeighborSort::WorstDistance() == 0);
+}
+
+/**
+ * Make sure the comparison works for values strictly less than the reference.
+ */
+BOOST_AUTO_TEST_CASE(FnsIsBetterStrict)
+{
+ BOOST_REQUIRE(FurthestNeighborSort::IsBetter(5.0, 4.0) == true);
+}
+
+/**
+ * Warn in case the comparison is not strict.
+ */
+BOOST_AUTO_TEST_CASE(FnsIsBetterNotStrict)
+{
+ BOOST_WARN(FurthestNeighborSort::IsBetter(6.0, 6.0) == true);
+}
+
+/**
+ * A simple test case of where to insert when all the values in the list are
+ * 0.
+ */
+BOOST_AUTO_TEST_CASE(FnsSortDistanceAllZero)
+{
+ arma::vec list(5);
+ list.fill(0);
+
+ // Should be inserted at the head of the list.
+ BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 5.0) == 0);
+}
+
+/**
+ * Another test case, where we are just putting the new value in the middle of
+ * the list.
+ */
+BOOST_AUTO_TEST_CASE(FnsSortDistance2)
+{
+ arma::vec list(3);
+ list[0] = 1.14;
+ list[1] = 0.89;
+ list[2] = 0.66;
+
+ // Run a couple possibilities through.
+ BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 1.22) == 0);
+ BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 0.93) == 1);
+ BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 0.68) == 2);
+ BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, 0.62) ==
+ (size_t() - 1));
+}
+
+/**
+ * Very simple sanity check to ensure that bounds are working alright. We will
+ * use a one-dimensional bound for simplicity.
+ */
+BOOST_AUTO_TEST_CASE(FnsNodeToNodeDistance)
+{
+ // Well, there's no easy way to make HRectBounds the way we want, so we have
+ // to make them and then expand the region to include new points.
+ arma::vec utility(1);
+ utility[0] = 0;
+
+ arma::mat dataset("1");
+ tree::BinarySpaceTree<HRectBound<2> > nodeOne(dataset);
+ nodeOne.Bound() = HRectBound<2>(1);
+ nodeOne.Bound() |= utility;
+ utility[0] = 1;
+ nodeOne.Bound() |= utility;
+
+ tree::BinarySpaceTree<HRectBound<2> > nodeTwo(dataset);
+ nodeTwo.Bound() = HRectBound<2>(1);
+ utility[0] = 5;
+ nodeTwo.Bound() |= utility;
+ utility[0] = 6;
+ nodeTwo.Bound() |= utility;
+
+ // This should use the L2 distance.
+ BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestNodeToNodeDistance(&nodeOne,
+ &nodeTwo), 6.0, 1e-5);
+
+ // And another just to be sure, from the other side.
+ nodeTwo.Bound().Clear();
+ utility[0] = -2;
+ nodeTwo.Bound() |= utility;
+ utility[0] = -1;
+ nodeTwo.Bound() |= utility;
+
+ // Again, the distance is the L2 distance.
+ BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestNodeToNodeDistance(&nodeOne,
+ &nodeTwo), 3.0, 1e-5);
+
+ // Now, when the bounds overlap.
+ nodeTwo.Bound().Clear();
+ utility[0] = -0.5;
+ nodeTwo.Bound() |= utility;
+ utility[0] = 0.5;
+ nodeTwo.Bound() |= utility;
+
+ BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestNodeToNodeDistance(&nodeOne,
+ &nodeTwo), 1.5, 1e-5);
+}
+
+/**
+ * Another very simple sanity check for the point-to-node case, again in one
+ * dimension.
+ */
+BOOST_AUTO_TEST_CASE(FnsPointToNodeDistance)
+{
+ // Well, there's no easy way to make HRectBounds the way we want, so we have
+ // to make them and then expand the region to include new points.
+ arma::vec utility(1);
+ utility[0] = 0;
+
+ arma::mat dataset("1");
+ tree::BinarySpaceTree<HRectBound<2> > node(dataset);
+ node.Bound() = HRectBound<2>(1);
+ node.Bound() |= utility;
+ utility[0] = 1;
+ node.Bound() |= utility;
+
+ arma::vec point(1);
+ point[0] = -0.5;
+
+ // The distance is the L2 distance.
+ BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestPointToNodeDistance(point,
+ &node), 1.5, 1e-5);
+
+ // Now from the other side of the bound.
+ point[0] = 1.5;
+
+ BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestPointToNodeDistance(point,
+ &node), 1.5, 1e-5);
+
+ // And now when the point is inside the bound.
+ point[0] = 0.5;
+
+ BOOST_REQUIRE_CLOSE(FurthestNeighborSort::BestPointToNodeDistance(point,
+ &node), 0.5, 1e-5);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sparse_coding_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/sparse_coding_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sparse_coding_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,154 +0,0 @@
-/**
- * @file sparse_coding_test.cpp
- *
- * Test for Sparse Coding
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-
-// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
-// to use FPC_WEAK, and it's not at all intuitive how to do that.
-
-#include <mlpack/core.hpp>
-#include <mlpack/methods/sparse_coding/sparse_coding.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace arma;
-using namespace mlpack;
-using namespace mlpack::regression;
-using namespace mlpack::sparse_coding;
-
-BOOST_AUTO_TEST_SUITE(SparseCodingTest);
-
-void SCVerifyCorrectness(vec beta, vec errCorr, double lambda)
-{
- const double tol = 1e-12;
- size_t nDims = beta.n_elem;
- for(size_t j = 0; j < nDims; j++)
- {
- if (beta(j) == 0)
- {
- // Make sure that errCorr(j) <= lambda.
- BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
- }
- else if (beta(j) < 0)
- {
- // Make sure that errCorr(j) == lambda.
- BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
- }
- else // beta(j) > 0.
- {
- // Make sure that errCorr(j) == -lambda.
- BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
- }
- }
-}
-
-BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepLasso)
-{
- double lambda1 = 0.1;
- uword nAtoms = 25;
-
- mat X;
- X.load("mnist_first250_training_4s_and_9s.arm");
- uword nPoints = X.n_cols;
-
- // Normalize each point since these are images.
- for (uword i = 0; i < nPoints; ++i) {
- X.col(i) /= norm(X.col(i), 2);
- }
-
- SparseCoding<> sc(X, nAtoms, lambda1);
- sc.OptimizeCode();
-
- mat D = sc.Dictionary();
- mat Z = sc.Codes();
-
- for (uword i = 0; i < nPoints; ++i)
- {
- vec errCorr = trans(D) * (D * Z.unsafe_col(i) - X.unsafe_col(i));
- SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
- }
-}
-
-BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepElasticNet)
-{
- double lambda1 = 0.1;
- double lambda2 = 0.2;
- uword nAtoms = 25;
-
- mat X;
- X.load("mnist_first250_training_4s_and_9s.arm");
- uword nPoints = X.n_cols;
-
- // Normalize each point since these are images.
- for (uword i = 0; i < nPoints; ++i)
- X.col(i) /= norm(X.col(i), 2);
-
- SparseCoding<> sc(X, nAtoms, lambda1, lambda2);
- sc.OptimizeCode();
-
- mat D = sc.Dictionary();
- mat Z = sc.Codes();
-
- for(uword i = 0; i < nPoints; ++i)
- {
- vec errCorr =
- (trans(D) * D + lambda2 * eye(nAtoms, nAtoms)) * Z.unsafe_col(i)
- - trans(D) * X.unsafe_col(i);
-
- SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
- }
-}
-
-BOOST_AUTO_TEST_CASE(SparseCodingTestDictionaryStep)
-{
- const double tol = 2e-7;
-
- double lambda1 = 0.1;
- uword nAtoms = 25;
-
- mat X;
- X.load("mnist_first250_training_4s_and_9s.arm");
- uword nPoints = X.n_cols;
-
- // Normalize each point since these are images.
- for (uword i = 0; i < nPoints; ++i)
- X.col(i) /= norm(X.col(i), 2);
-
- SparseCoding<> sc(X, nAtoms, lambda1);
- sc.OptimizeCode();
-
- mat D = sc.Dictionary();
- mat Z = sc.Codes();
-
- uvec adjacencies = find(Z);
- double normGradient = sc.OptimizeDictionary(adjacencies, 1e-12);
-
- BOOST_REQUIRE_SMALL(normGradient, tol);
-}
-
-/*
-BOOST_AUTO_TEST_CASE(SparseCodingTestWhole)
-{
-
-}
-*/
-
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sparse_coding_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/sparse_coding_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sparse_coding_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/sparse_coding_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,154 @@
+/**
+ * @file sparse_coding_test.cpp
+ *
+ * Test for Sparse Coding
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
+// to use FPC_WEAK, and it's not at all intuitive how to do that.
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/sparse_coding/sparse_coding.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace mlpack::sparse_coding;
+
+BOOST_AUTO_TEST_SUITE(SparseCodingTest);
+
+void SCVerifyCorrectness(vec beta, vec errCorr, double lambda)
+{
+ const double tol = 1e-12;
+ size_t nDims = beta.n_elem;
+ for(size_t j = 0; j < nDims; j++)
+ {
+ if (beta(j) == 0)
+ {
+ // Make sure that errCorr(j) <= lambda.
+ BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
+ }
+ else if (beta(j) < 0)
+ {
+ // Make sure that errCorr(j) == lambda.
+ BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
+ }
+ else // beta(j) > 0.
+ {
+ // Make sure that errCorr(j) == -lambda.
+ BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
+ }
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepLasso)
+{
+ double lambda1 = 0.1;
+ uword nAtoms = 25;
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm");
+ uword nPoints = X.n_cols;
+
+ // Normalize each point since these are images.
+ for (uword i = 0; i < nPoints; ++i) {
+ X.col(i) /= norm(X.col(i), 2);
+ }
+
+ SparseCoding<> sc(X, nAtoms, lambda1);
+ sc.OptimizeCode();
+
+ mat D = sc.Dictionary();
+ mat Z = sc.Codes();
+
+ for (uword i = 0; i < nPoints; ++i)
+ {
+ vec errCorr = trans(D) * (D * Z.unsafe_col(i) - X.unsafe_col(i));
+ SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepElasticNet)
+{
+ double lambda1 = 0.1;
+ double lambda2 = 0.2;
+ uword nAtoms = 25;
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm");
+ uword nPoints = X.n_cols;
+
+ // Normalize each point since these are images.
+ for (uword i = 0; i < nPoints; ++i)
+ X.col(i) /= norm(X.col(i), 2);
+
+ SparseCoding<> sc(X, nAtoms, lambda1, lambda2);
+ sc.OptimizeCode();
+
+ mat D = sc.Dictionary();
+ mat Z = sc.Codes();
+
+ for(uword i = 0; i < nPoints; ++i)
+ {
+ vec errCorr =
+ (trans(D) * D + lambda2 * eye(nAtoms, nAtoms)) * Z.unsafe_col(i)
+ - trans(D) * X.unsafe_col(i);
+
+ SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SparseCodingTestDictionaryStep)
+{
+ const double tol = 2e-7;
+
+ double lambda1 = 0.1;
+ uword nAtoms = 25;
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm");
+ uword nPoints = X.n_cols;
+
+ // Normalize each point since these are images.
+ for (uword i = 0; i < nPoints; ++i)
+ X.col(i) /= norm(X.col(i), 2);
+
+ SparseCoding<> sc(X, nAtoms, lambda1);
+ sc.OptimizeCode();
+
+ mat D = sc.Dictionary();
+ mat Z = sc.Codes();
+
+ uvec adjacencies = find(Z);
+ double normGradient = sc.OptimizeDictionary(adjacencies, 1e-12);
+
+ BOOST_REQUIRE_SMALL(normGradient, tol);
+}
+
+/*
+BOOST_AUTO_TEST_CASE(SparseCodingTestWhole)
+{
+
+}
+*/
+
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/tree_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,2209 +0,0 @@
-/**
- * @file tree_test.cpp
- *
- * Tests for tree-building methods.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::math;
-using namespace mlpack::tree;
-using namespace mlpack::metric;
-using namespace mlpack::bound;
-
-BOOST_AUTO_TEST_SUITE(TreeTest);
-
-/**
- * Ensure that a bound, by default, is empty and has no dimensionality.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundEmptyConstructor)
-{
- HRectBound<2> b;
-
- BOOST_REQUIRE_EQUAL((int) b.Dim(), 0);
-}
-
-/**
- * Ensure that when we specify the dimensionality in the constructor, it is
- * correct, and the bounds are all the empty set.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundDimConstructor)
-{
- HRectBound<2> b(2); // We'll do this with 2 and 5 dimensions.
-
- BOOST_REQUIRE_EQUAL(b.Dim(), 2);
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
-
- b = HRectBound<2>(5);
-
- BOOST_REQUIRE_EQUAL(b.Dim(), 5);
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[2].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[3].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[4].Width(), 1e-5);
-}
-
-/**
- * Test the copy constructor.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundCopyConstructor)
-{
- HRectBound<2> b(2);
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 3.0);
-
- HRectBound<2> c(b);
-
- BOOST_REQUIRE_EQUAL(c.Dim(), 2);
- BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
-}
-
-/**
- * Test the assignment operator.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundAssignmentOperator)
-{
- HRectBound<2> b(2);
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 3.0);
-
- HRectBound<2> c(4);
-
- c = b;
-
- BOOST_REQUIRE_EQUAL(c.Dim(), 2);
- BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
-}
-
-/**
- * Test that clearing the dimensions resets the bound to empty.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundClear)
-{
- HRectBound<2> b(2); // We'll do this with two dimensions only.
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 4.0);
-
- // Now we just need to make sure that we clear the range.
- b.Clear();
-
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
-}
-
-/**
- * Ensure that we get the correct centroid for our bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundCentroid)
-{
- // Create a simple 3-dimensional bound.
- HRectBound<2> b(3);
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(-2.0, -1.0);
- b[2] = Range(-10.0, 50.0);
-
- arma::vec centroid;
-
- b.Centroid(centroid);
-
- BOOST_REQUIRE_EQUAL(centroid.n_elem, 3);
- BOOST_REQUIRE_CLOSE(centroid[0], 2.5, 1e-5);
- BOOST_REQUIRE_CLOSE(centroid[1], -1.5, 1e-5);
- BOOST_REQUIRE_CLOSE(centroid[2], 20.0, 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct minimum distance between a point and a
- * bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundMinDistancePoint)
-{
- // We'll do the calculation in five dimensions, and we'll use three cases for
- // the point: point is outside the bound; point is on the edge of the bound;
- // point is inside the bound. In the latter two cases, the distance should be
- // zero.
- HRectBound<2> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
-
- // This will be the Euclidean distance.
- BOOST_REQUIRE_CLOSE(b.MinDistance(point), sqrt(95.0), 1e-5);
-
- point = "2.0 5.0 2.0 -5.0 1.0";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-
- point = "1.0 2.0 0.0 -2.0 1.5";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct minimum distance between a bound and
- * another bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundMinDistanceBound)
-{
- // We'll do the calculation in five dimensions, and we can use six cases.
- // The other bound is completely outside the bound; the other bound is on the
- // edge of the bound; the other bound partially overlaps the bound; the other
- // bound fully overlaps the bound; the other bound is entirely inside the
- // bound; the other bound entirely envelops the bound.
- HRectBound<2> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- HRectBound<2> c(5);
-
- // The other bound is completely outside the bound.
- c[0] = Range(-5.0, -2.0);
- c[1] = Range(6.0, 7.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(2.0, 5.0);
- c[4] = Range(3.0, 4.0);
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(c), sqrt(22.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MinDistance(b), sqrt(22.0), 1e-5);
-
- // The other bound is on the edge of the bound.
- c[0] = Range(-2.0, 0.0);
- c[1] = Range(0.0, 1.0);
- c[2] = Range(-3.0, -2.0);
- c[3] = Range(-10.0, -5.0);
- c[4] = Range(2.0, 3.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
-
- // The other bound partially overlaps the bound.
- c[0] = Range(-2.0, 1.0);
- c[1] = Range(0.0, 2.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(-8.0, -4.0);
- c[4] = Range(0.0, 4.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
-
- // The other bound fully overlaps the bound.
- BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
-
- // The other bound is entirely inside the bound / the other bound entirely
- // envelops the bound.
- c[0] = Range(-1.0, 3.0);
- c[1] = Range(0.0, 6.0);
- c[2] = Range(-3.0, 3.0);
- c[3] = Range(-7.0, 0.0);
- c[4] = Range(0.0, 5.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
-
- // Now we must be sure that the minimum distance to itself is 0.
- BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct maximum distance between a bound and a
- * point. This uses the same test cases as the MinDistance test.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundMaxDistancePoint)
-{
- // We'll do the calculation in five dimensions, and we'll use three cases for
- // the point: point is outside the bound; point is on the edge of the bound;
- // point is inside the bound. In the latter two cases, the distance should be
- // zero.
- HRectBound<2> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
-
- // This will be the Euclidean distance.
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(253.0), 1e-5);
-
- point = "2.0 5.0 2.0 -5.0 1.0";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(46.0), 1e-5);
-
- point = "1.0 2.0 0.0 -2.0 1.5";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(23.25), 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct maximum distance between a bound and
- * another bound. This uses the same test cases as the MinDistance test.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundMaxDistanceBound)
-{
- // We'll do the calculation in five dimensions, and we can use six cases.
- // The other bound is completely outside the bound; the other bound is on the
- // edge of the bound; the other bound partially overlaps the bound; the other
- // bound fully overlaps the bound; the other bound is entirely inside the
- // bound; the other bound entirely envelops the bound.
- HRectBound<2> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- HRectBound<2> c(5);
-
- // The other bound is completely outside the bound.
- c[0] = Range(-5.0, -2.0);
- c[1] = Range(6.0, 7.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(2.0, 5.0);
- c[4] = Range(3.0, 4.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(210.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(210.0), 1e-5);
-
- // The other bound is on the edge of the bound.
- c[0] = Range(-2.0, 0.0);
- c[1] = Range(0.0, 1.0);
- c[2] = Range(-3.0, -2.0);
- c[3] = Range(-10.0, -5.0);
- c[4] = Range(2.0, 3.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(134.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(134.0), 1e-5);
-
- // The other bound partially overlaps the bound.
- c[0] = Range(-2.0, 1.0);
- c[1] = Range(0.0, 2.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(-8.0, -4.0);
- c[4] = Range(0.0, 4.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(102.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(102.0), 1e-5);
-
- // The other bound fully overlaps the bound.
- BOOST_REQUIRE_CLOSE(b.MaxDistance(b), sqrt(46.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(c), sqrt(61.0), 1e-5);
-
- // The other bound is entirely inside the bound / the other bound entirely
- // envelops the bound.
- c[0] = Range(-1.0, 3.0);
- c[1] = Range(0.0, 6.0);
- c[2] = Range(-3.0, 3.0);
- c[3] = Range(-7.0, 0.0);
- c[4] = Range(0.0, 5.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(100.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(100.0), 1e-5);
-
- // Identical bounds. This will be the sum of the squared widths in each
- // dimension.
- BOOST_REQUIRE_CLOSE(b.MaxDistance(b), sqrt(46.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(c), sqrt(162.0), 1e-5);
-
- // One last additional case. If the bound encloses only one point, the
- // maximum distance between it and itself is 0.
- HRectBound<2> d(2);
-
- d[0] = Range(2.0, 2.0);
- d[1] = Range(3.0, 3.0);
-
- BOOST_REQUIRE_SMALL(d.MaxDistance(d), 1e-5);
-}
-
-/**
- * Ensure that the ranges returned by RangeDistance() are equal to the minimum
- * and maximum distance. We will perform this test by creating random bounds
- * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
- * is assuming that those passed and operate correctly.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRangeDistanceBound)
-{
- for (int i = 0; i < 50; i++)
- {
- size_t dim = math::RandInt(20);
-
- HRectBound<2> a(dim);
- HRectBound<2> b(dim);
-
- // We will set the low randomly and the width randomly for each dimension of
- // each bound.
- arma::vec loA(dim);
- arma::vec widthA(dim);
-
- loA.randu();
- widthA.randu();
-
- arma::vec lo_b(dim);
- arma::vec width_b(dim);
-
- lo_b.randu();
- width_b.randu();
-
- for (size_t j = 0; j < dim; j++)
- {
- a[j] = Range(loA[j], loA[j] + widthA[j]);
- b[j] = Range(lo_b[j], lo_b[j] + width_b[j]);
- }
-
- // Now ensure that MinDistance and MaxDistance report the same.
- Range r = a.RangeDistance(b);
- Range s = b.RangeDistance(a);
-
- BOOST_REQUIRE_CLOSE(r.Lo(), s.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(r.Hi(), s.Hi(), 1e-5);
-
- BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(b), 1e-5);
- BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(b), 1e-5);
-
- BOOST_REQUIRE_CLOSE(s.Lo(), b.MinDistance(a), 1e-5);
- BOOST_REQUIRE_CLOSE(s.Hi(), b.MaxDistance(a), 1e-5);
- }
-}
-
-/**
- * Ensure that the ranges returned by RangeDistance() are equal to the minimum
- * and maximum distance. We will perform this test by creating random bounds
- * and comparing the bheavior to MinDistance() and MaxDistance() -- so this test
- * is assuming that those passed and operate correctly. This is for the
- * bound-to-point case.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRangeDistancePoint)
-{
- for (int i = 0; i < 20; i++)
- {
- size_t dim = math::RandInt(20);
-
- HRectBound<2> a(dim);
-
- // We will set the low randomly and the width randomly for each dimension of
- // each bound.
- arma::vec loA(dim);
- arma::vec widthA(dim);
-
- loA.randu();
- widthA.randu();
-
- for (size_t j = 0; j < dim; j++)
- a[j] = Range(loA[j], loA[j] + widthA[j]);
-
- // Now run the test on a few points.
- for (int j = 0; j < 10; j++)
- {
- arma::vec point(dim);
-
- point.randu();
-
- Range r = a.RangeDistance(point);
-
- BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(point), 1e-5);
- BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(point), 1e-5);
- }
- }
-}
-
-/**
- * Test that we can expand the bound to include a new point.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundOrOperatorPoint)
-{
- // Because this should be independent in each dimension, we can essentially
- // run five test cases at once.
- HRectBound<2> b(5);
-
- b[0] = Range(1.0, 3.0);
- b[1] = Range(2.0, 4.0);
- b[2] = Range(-2.0, -1.0);
- b[3] = Range(0.0, 0.0);
- b[4] = Range(); // Empty range.
-
- arma::vec point = "2.0 4.0 2.0 -1.0 6.0";
-
- b |= point;
-
- BOOST_REQUIRE_CLOSE(b[0].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[0].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[1].Hi(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[2].Lo(), -2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[2].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[3].Lo(), -1.0, 1e-5);
- BOOST_REQUIRE_SMALL(b[3].Hi(), 1e-5);
- BOOST_REQUIRE_CLOSE(b[4].Lo(), 6.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[4].Hi(), 6.0, 1e-5);
-}
-
-/**
- * Test that we can expand the bound to include another bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundOrOperatorBound)
-{
- // Because this should be independent in each dimension, we can run many tests
- // at once.
- HRectBound<2> b(8);
-
- b[0] = Range(1.0, 3.0);
- b[1] = Range(2.0, 4.0);
- b[2] = Range(-2.0, -1.0);
- b[3] = Range(4.0, 5.0);
- b[4] = Range(2.0, 4.0);
- b[5] = Range(0.0, 0.0);
- b[6] = Range();
- b[7] = Range(1.0, 3.0);
-
- HRectBound<2> c(8);
-
- c[0] = Range(-3.0, -1.0); // Entirely less than the other bound.
- c[1] = Range(0.0, 2.0); // Touching edges.
- c[2] = Range(-3.0, -1.5); // Partially overlapping.
- c[3] = Range(4.0, 5.0); // Identical.
- c[4] = Range(1.0, 5.0); // Entirely enclosing.
- c[5] = Range(2.0, 2.0); // A single point.
- c[6] = Range(1.0, 3.0);
- c[7] = Range(); // Empty set.
-
- HRectBound<2> d = c;
-
- b |= c;
- d |= b;
-
- BOOST_REQUIRE_CLOSE(b[0].Lo(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[0].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[0].Lo(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[0].Hi(), 3.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[1].Lo(), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[1].Hi(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[1].Lo(), 0.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[1].Hi(), 4.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[2].Lo(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[2].Hi(), -1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[2].Lo(), -3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[2].Hi(), -1.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[3].Lo(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[3].Hi(), 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[3].Lo(), 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[3].Hi(), 5.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[4].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[4].Hi(), 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[4].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[4].Hi(), 5.0, 1e-5);
-
- BOOST_REQUIRE_SMALL(b[5].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(b[5].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_SMALL(d[5].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(d[5].Hi(), 2.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[6].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[6].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[6].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[6].Hi(), 3.0, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b[7].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b[7].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[7].Lo(), 1.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d[7].Hi(), 3.0, 1e-5);
-}
-
-/**
- * Test that the Contains() function correctly figures out whether or not a
- * point is in a bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundContains)
-{
- // We can test a couple different points: completely outside the bound,
- // adjacent in one dimension to the bound, adjacent in all dimensions to the
- // bound, and inside the bound.
- HRectBound<2> b(3);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(0.0, 2.0);
- b[2] = Range(0.0, 2.0);
-
- // Completely outside the range.
- arma::vec point = "-1.0 4.0 4.0";
- BOOST_REQUIRE(!b.Contains(point));
-
- // Completely outside, but one dimension is in the range.
- point = "-1.0 4.0 1.0";
- BOOST_REQUIRE(!b.Contains(point));
-
- // Outside, but one dimension is on the edge.
- point = "-1.0 0.0 3.0";
- BOOST_REQUIRE(!b.Contains(point));
-
- // Two dimensions are on the edge, but one is outside.
- point = "0.0 0.0 3.0";
- BOOST_REQUIRE(!b.Contains(point));
-
- // Completely on the edge (should be contained).
- point = "0.0 0.0 0.0";
- BOOST_REQUIRE(b.Contains(point));
-
- // Inside the range.
- point = "0.3 1.0 0.4";
- BOOST_REQUIRE(b.Contains(point));
-}
-
-BOOST_AUTO_TEST_CASE(TestBallBound)
-{
- BallBound<> b1;
- BallBound<> b2;
-
- // Create two balls with a center distance of 1 from each other.
- // Give the first one a radius of 0.3 and the second a radius of 0.4.
- b1.Center().set_size(3);
- b1.Center()[0] = 1;
- b1.Center()[1] = 2;
- b1.Center()[2] = 3;
- b1.Radius() = 0.3;
-
- b2.Center().set_size(3);
- b2.Center()[0] = 1;
- b2.Center()[1] = 2;
- b2.Center()[2] = 4;
- b2.Radius() = 0.4;
-
- BOOST_REQUIRE_CLOSE(b1.MinDistance(b2), 1-0.3-0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Hi(), 1+0.3+0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Lo(), 1-0.3-0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Hi(), 1+0.3+0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Lo(), 1-0.3-0.4, 1e-5);
-
- BOOST_REQUIRE_CLOSE(b2.MinDistance(b1), 1-0.3-0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.MaxDistance(b1), 1+0.3+0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.RangeDistance(b1).Hi(), 1+0.3+0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.RangeDistance(b1).Lo(), 1-0.3-0.4, 1e-5);
-
- BOOST_REQUIRE(b1.Contains(b1.Center()));
- BOOST_REQUIRE(!b1.Contains(b2.Center()));
-
- BOOST_REQUIRE(!b2.Contains(b1.Center()));
- BOOST_REQUIRE(b2.Contains(b2.Center()));
- arma::vec b2point(3); // A point that's within the radius but not the center.
- b2point[0] = 1.1;
- b2point[1] = 2.1;
- b2point[2] = 4.1;
-
- BOOST_REQUIRE(b2.Contains(b2point));
-
- BOOST_REQUIRE_SMALL(b1.MinDistance(b1.Center()), 1e-5);
- BOOST_REQUIRE_CLOSE(b1.MinDistance(b2.Center()), 1 - 0.3, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.MinDistance(b1.Center()), 1 - 0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b2.MaxDistance(b1.Center()), 1 + 0.4, 1e-5);
- BOOST_REQUIRE_CLOSE(b1.MaxDistance(b2.Center()), 1 + 0.3, 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct minimum distance between a point and a
- * bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRootMinDistancePoint)
-{
- // We'll do the calculation in five dimensions, and we'll use three cases for
- // the point: point is outside the bound; point is on the edge of the bound;
- // point is inside the bound. In the latter two cases, the distance should be
- // zero.
- HRectBound<2, true> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
-
- // This will be the Euclidean distance.
- BOOST_REQUIRE_CLOSE(b.MinDistance(point), sqrt(95.0), 1e-5);
-
- point = "2.0 5.0 2.0 -5.0 1.0";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-
- point = "1.0 2.0 0.0 -2.0 1.5";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct minimum distance between a bound and
- * another bound.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRootMinDistanceBound)
-{
- // We'll do the calculation in five dimensions, and we can use six cases.
- // The other bound is completely outside the bound; the other bound is on the
- // edge of the bound; the other bound partially overlaps the bound; the other
- // bound fully overlaps the bound; the other bound is entirely inside the
- // bound; the other bound entirely envelops the bound.
- HRectBound<2, true> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- HRectBound<2, true> c(5);
-
- // The other bound is completely outside the bound.
- c[0] = Range(-5.0, -2.0);
- c[1] = Range(6.0, 7.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(2.0, 5.0);
- c[4] = Range(3.0, 4.0);
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(c), sqrt(22.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MinDistance(b), sqrt(22.0), 1e-5);
-
- // The other bound is on the edge of the bound.
- c[0] = Range(-2.0, 0.0);
- c[1] = Range(0.0, 1.0);
- c[2] = Range(-3.0, -2.0);
- c[3] = Range(-10.0, -5.0);
- c[4] = Range(2.0, 3.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
-
- // The other bound partially overlaps the bound.
- c[0] = Range(-2.0, 1.0);
- c[1] = Range(0.0, 2.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(-8.0, -4.0);
- c[4] = Range(0.0, 4.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
-
- // The other bound fully overlaps the bound.
- BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
-
- // The other bound is entirely inside the bound / the other bound entirely
- // envelops the bound.
- c[0] = Range(-1.0, 3.0);
- c[1] = Range(0.0, 6.0);
- c[2] = Range(-3.0, 3.0);
- c[3] = Range(-7.0, 0.0);
- c[4] = Range(0.0, 5.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
-
- // Now we must be sure that the minimum distance to itself is 0.
- BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
- BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct maximum distance between a bound and a
- * point. This uses the same test cases as the MinDistance test.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRootMaxDistancePoint)
-{
- // We'll do the calculation in five dimensions, and we'll use three cases for
- // the point: point is outside the bound; point is on the edge of the bound;
- // point is inside the bound. In the latter two cases, the distance should be
- // zero.
- HRectBound<2, true> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
-
- // This will be the Euclidean distance.
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(253.0), 1e-5);
-
- point = "2.0 5.0 2.0 -5.0 1.0";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(46.0), 1e-5);
-
- point = "1.0 2.0 0.0 -2.0 1.5";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(23.25), 1e-5);
-}
-
-/**
- * Ensure that we calculate the correct maximum distance between a bound and
- * another bound. This uses the same test cases as the MinDistance test.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRootMaxDistanceBound)
-{
- // We'll do the calculation in five dimensions, and we can use six cases.
- // The other bound is completely outside the bound; the other bound is on the
- // edge of the bound; the other bound partially overlaps the bound; the other
- // bound fully overlaps the bound; the other bound is entirely inside the
- // bound; the other bound entirely envelops the bound.
- HRectBound<2, true> b(5);
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(1.0, 5.0);
- b[2] = Range(-2.0, 2.0);
- b[3] = Range(-5.0, -2.0);
- b[4] = Range(1.0, 2.0);
-
- HRectBound<2, true> c(5);
-
- // The other bound is completely outside the bound.
- c[0] = Range(-5.0, -2.0);
- c[1] = Range(6.0, 7.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(2.0, 5.0);
- c[4] = Range(3.0, 4.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(210.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(210.0), 1e-5);
-
- // The other bound is on the edge of the bound.
- c[0] = Range(-2.0, 0.0);
- c[1] = Range(0.0, 1.0);
- c[2] = Range(-3.0, -2.0);
- c[3] = Range(-10.0, -5.0);
- c[4] = Range(2.0, 3.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(134.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(134.0), 1e-5);
-
- // The other bound partially overlaps the bound.
- c[0] = Range(-2.0, 1.0);
- c[1] = Range(0.0, 2.0);
- c[2] = Range(-2.0, 2.0);
- c[3] = Range(-8.0, -4.0);
- c[4] = Range(0.0, 4.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(102.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(102.0), 1e-5);
-
- // The other bound fully overlaps the bound.
- BOOST_REQUIRE_CLOSE(b.MaxDistance(b), sqrt(46.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(c), sqrt(61.0), 1e-5);
-
- // The other bound is entirely inside the bound / the other bound entirely
- // envelops the bound.
- c[0] = Range(-1.0, 3.0);
- c[1] = Range(0.0, 6.0);
- c[2] = Range(-3.0, 3.0);
- c[3] = Range(-7.0, 0.0);
- c[4] = Range(0.0, 5.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(100.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(100.0), 1e-5);
-
- // Identical bounds. This will be the sum of the squared widths in each
- // dimension.
- BOOST_REQUIRE_CLOSE(b.MaxDistance(b), sqrt(46.0), 1e-5);
- BOOST_REQUIRE_CLOSE(c.MaxDistance(c), sqrt(162.0), 1e-5);
-
- // One last additional case. If the bound encloses only one point, the
- // maximum distance between it and itself is 0.
- HRectBound<2, true> d(2);
-
- d[0] = Range(2.0, 2.0);
- d[1] = Range(3.0, 3.0);
-
- BOOST_REQUIRE_SMALL(d.MaxDistance(d), 1e-5);
-}
-
-/**
- * Ensure that the ranges returned by RangeDistance() are equal to the minimum
- * and maximum distance. We will perform this test by creating random bounds
- * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
- * is assuming that those passed and operate correctly.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRootRangeDistanceBound)
-{
- for (int i = 0; i < 50; i++)
- {
- size_t dim = math::RandInt(20);
-
- HRectBound<2, true> a(dim);
- HRectBound<2, true> b(dim);
-
- // We will set the low randomly and the width randomly for each dimension of
- // each bound.
- arma::vec loA(dim);
- arma::vec widthA(dim);
-
- loA.randu();
- widthA.randu();
-
- arma::vec lo_b(dim);
- arma::vec width_b(dim);
-
- lo_b.randu();
- width_b.randu();
-
- for (size_t j = 0; j < dim; j++)
- {
- a[j] = Range(loA[j], loA[j] + widthA[j]);
- b[j] = Range(lo_b[j], lo_b[j] + width_b[j]);
- }
-
- // Now ensure that MinDistance and MaxDistance report the same.
- Range r = a.RangeDistance(b);
- Range s = b.RangeDistance(a);
-
- BOOST_REQUIRE_CLOSE(r.Lo(), s.Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(r.Hi(), s.Hi(), 1e-5);
-
- BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(b), 1e-5);
- BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(b), 1e-5);
-
- BOOST_REQUIRE_CLOSE(s.Lo(), b.MinDistance(a), 1e-5);
- BOOST_REQUIRE_CLOSE(s.Hi(), b.MaxDistance(a), 1e-5);
- }
-}
-
-/**
- * Ensure that the ranges returned by RangeDistance() are equal to the minimum
- * and maximum distance. We will perform this test by creating random bounds
- * and comparing the bheavior to MinDistance() and MaxDistance() -- so this test
- * is assuming that those passed and operate correctly. This is for the
- * bound-to-point case.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundRootRangeDistancePoint)
-{
- for (int i = 0; i < 20; i++)
- {
- size_t dim = math::RandInt(20);
-
- HRectBound<2, true> a(dim);
-
- // We will set the low randomly and the width randomly for each dimension of
- // each bound.
- arma::vec loA(dim);
- arma::vec widthA(dim);
-
- loA.randu();
- widthA.randu();
-
- for (size_t j = 0; j < dim; j++)
- a[j] = Range(loA[j], loA[j] + widthA[j]);
-
- // Now run the test on a few points.
- for (int j = 0; j < 10; j++)
- {
- arma::vec point(dim);
-
- point.randu();
-
- Range r = a.RangeDistance(point);
-
- BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(point), 1e-5);
- BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(point), 1e-5);
- }
- }
-}
-
-/**
- * Ensure that HRectBound::Diameter() works properly.
- */
-BOOST_AUTO_TEST_CASE(HRectBoundDiameter)
-{
- HRectBound<3> b(4);
- b[0] = math::Range(0.0, 1.0);
- b[1] = math::Range(-1.0, 0.0);
- b[2] = math::Range(2.0, 3.0);
- b[3] = math::Range(7.0, 7.0);
-
- BOOST_REQUIRE_CLOSE(b.Diameter(), std::pow(3.0, 1.0 / 3.0), 1e-5);
-
- HRectBound<2, false> c(4);
- c[0] = math::Range(0.0, 1.0);
- c[1] = math::Range(-1.0, 0.0);
- c[2] = math::Range(2.0, 3.0);
- c[3] = math::Range(0.0, 0.0);
-
- BOOST_REQUIRE_CLOSE(c.Diameter(), 3.0, 1e-5);
-
- HRectBound<5> d(2);
- d[0] = math::Range(2.2, 2.2);
- d[1] = math::Range(1.0, 1.0);
-
- BOOST_REQUIRE_SMALL(d.Diameter(), 1e-5);
-}
-
-/**
- * Ensure that a bound, by default, is empty and has no dimensionality, and the
- * box size vector is empty.
- */
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundEmptyConstructor)
-{
- PeriodicHRectBound<2> b;
-
- BOOST_REQUIRE_EQUAL(b.Dim(), 0);
- BOOST_REQUIRE_EQUAL(b.Box().n_elem, 0);
-}
-
-/**
- * Ensure that when we specify the dimensionality in the constructor, it is
- * correct, and the bounds are all the empty set.
- */
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundBoxConstructor)
-{
- PeriodicHRectBound<2> b(arma::vec("5 6"));
-
- BOOST_REQUIRE_EQUAL(b.Dim(), 2);
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
- BOOST_REQUIRE_EQUAL(b.Box().n_elem, 2);
- BOOST_REQUIRE_CLOSE(b.Box()[0], 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b.Box()[1], 6.0, 1e-5);
-
- PeriodicHRectBound<2> d(arma::vec("2 3 4 5 6"));
-
- BOOST_REQUIRE_EQUAL(d.Dim(), 5);
- BOOST_REQUIRE_SMALL(d[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(d[1].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(d[2].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(d[3].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(d[4].Width(), 1e-5);
- BOOST_REQUIRE_EQUAL(d.Box().n_elem, 5);
- BOOST_REQUIRE_CLOSE(d.Box()[0], 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Box()[1], 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Box()[2], 4.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Box()[3], 5.0, 1e-5);
- BOOST_REQUIRE_CLOSE(d.Box()[4], 6.0, 1e-5);
-}
-
-/**
- * Test the copy constructor.
- */
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundCopyConstructor)
-{
- PeriodicHRectBound<2> b(arma::vec("3 4"));
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 3.0);
-
- PeriodicHRectBound<2> c(b);
-
- BOOST_REQUIRE_EQUAL(c.Dim(), 2);
- BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_EQUAL(c.Box().n_elem, 2);
- BOOST_REQUIRE_CLOSE(c.Box()[0], 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.Box()[1], 4.0, 1e-5);
-}
-
-/**
- * Test the assignment operator.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundAssignmentOperator)
-{
- PeriodicHRectBound<2> b(arma::vec("3 4"));
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 3.0);
-
- PeriodicHRectBound<2> c(arma::vec("3 4 5 6"));
-
- c = b;
-
- BOOST_REQUIRE_EQUAL(c.Dim(), 2);
- BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
- BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
- BOOST_REQUIRE_EQUAL(c.Box().n_elem, 2);
- BOOST_REQUIRE_CLOSE(c.Box()[0], 3.0, 1e-5);
- BOOST_REQUIRE_CLOSE(c.Box()[1], 4.0, 1e-5);
-}*/
-
-/**
- * Ensure that we can set the box size correctly.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundSetBoxSize)
-{
- PeriodicHRectBound<2> b(arma::vec("1 2"));
-
- b.SetBoxSize(arma::vec("10 12"));
-
- BOOST_REQUIRE_EQUAL(b.Box().n_elem, 2);
- BOOST_REQUIRE_CLOSE(b.Box()[0], 10.0, 1e-5);
- BOOST_REQUIRE_CLOSE(b.Box()[1], 12.0, 1e-5);
-}*/
-
-/**
- * Ensure that we can clear the dimensions correctly. This does not involve the
- * box size at all, so the test can be identical to the HRectBound test.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundClear)
-{
- // We'll do this with two dimensions only.
- PeriodicHRectBound<2> b(arma::vec("5 5"));
-
- b[0] = Range(0.0, 2.0);
- b[1] = Range(2.0, 4.0);
-
- // Now we just need to make sure that we clear the range.
- b.Clear();
-
- BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
- BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
-}*/
-
-/**
- * Ensure that we get the correct centroid for our bound.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundCentroid) {
- // Create a simple 3-dimensional bound. The centroid is not affected by the
- // periodic coordinates.
- PeriodicHRectBound<2> b(arma::vec("100 100 100"));
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(-2.0, -1.0);
- b[2] = Range(-10.0, 50.0);
-
- arma::vec centroid;
-
- b.Centroid(centroid);
-
- BOOST_REQUIRE_EQUAL(centroid.n_elem, 3);
- BOOST_REQUIRE_CLOSE(centroid[0], 2.5, 1e-5);
- BOOST_REQUIRE_CLOSE(centroid[1], -1.5, 1e-5);
- BOOST_REQUIRE_CLOSE(centroid[2], 20.0, 1e-5);
-}*/
-
-/**
- * Correctly calculate the minimum distance between the bound and a point in
- * periodic coordinates. We have to account for the shifts necessary in
- * periodic coordinates too, so that makes testing this a little more difficult.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMinDistancePoint)
-{
- // First, we'll start with a simple 2-dimensional case where the point is
- // inside the bound, then on the edge of the bound, then barely outside the
- // bound. The box size will be large enough that this is basically the
- // HRectBound case.
- PeriodicHRectBound<2> b(arma::vec("100 100"));
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(2.0, 4.0);
-
- // Inside the bound.
- arma::vec point = "2.5 3.0";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-
- // On the edge.
- point = "5.0 4.0";
-
- BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
-
- // And just a little outside the bound.
- point = "6.0 5.0";
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(point), 2.0, 1e-5);
-
- // Now we start to invoke the periodicity. This point will "alias" to (-1,
- // -1).
- point = "99.0 99.0";
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(point), 10.0, 1e-5);
-
- // We will perform several tests on a one-dimensional bound.
- PeriodicHRectBound<2> a(arma::vec("5.0"));
- point.set_size(1);
-
- a[0] = Range(2.0, 4.0); // Entirely inside box.
- point[0] = 7.5; // Inside first right image of the box.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
-
- a[0] = Range(0.0, 5.0); // Fills box fully.
- point[1] = 19.3; // Inside the box, which covers everything.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
-
- a[0] = Range(-10.0, 10.0); // Larger than the box.
- point[0] = -500.0; // Inside the box, which covers everything.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
-
- a[0] = Range(-2.0, 1.0); // Crosses over an edge.
- point[0] = 2.9; // The first right image of the bound starts at 3.0.
-
- BOOST_REQUIRE_CLOSE(a.MinDistance(point), 0.01, 1e-5);
-
- a[0] = Range(2.0, 4.0); // Inside box.
- point[0] = 0.0; // Closest to the first left image of the bound.
-
- BOOST_REQUIRE_CLOSE(a.MinDistance(point), 1.0, 1e-5);
-
- a[0] = Range(0.0, 2.0); // On edge of box.
- point[0] = 7.1; // 0.1 away from the first right image of the bound.
-
- BOOST_REQUIRE_CLOSE(a.MinDistance(point), 0.01, 1e-5);
-
- PeriodicHRectBound<2> d(arma::vec("0.0"));
- d[0] = Range(-10.0, 10.0); // Box is of infinite size.
- point[0] = 810.0; // 800 away from the only image of the box.
-
- BOOST_REQUIRE_CLOSE(d.MinDistance(point), 640000.0, 1e-5);
-
- PeriodicHRectBound<2> e(arma::vec("-5.0"));
- e[0] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
- point[0] = -10.8; // Should alias to 4.2.
-
- BOOST_REQUIRE_CLOSE(e.MinDistance(point), 0.04, 1e-5);
-
- // Switch our bound to a higher dimensionality. This should ensure that the
- // dimensions are independent like they should be.
- PeriodicHRectBound<2> c(arma::vec("5.0 5.0 5.0 5.0 5.0 5.0 0.0 -5.0"));
-
- c[0] = Range(2.0, 4.0); // Entirely inside box.
- c[1] = Range(0.0, 5.0); // Fills box fully.
- c[2] = Range(-10.0, 10.0); // Larger than the box.
- c[3] = Range(-2.0, 1.0); // Crosses over an edge.
- c[4] = Range(2.0, 4.0); // Inside box.
- c[5] = Range(0.0, 2.0); // On edge of box.
- c[6] = Range(-10.0, 10.0); // Box is of infinite size.
- c[7] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
-
- point.set_size(8);
- point[0] = 7.5; // Inside first right image of the box.
- point[1] = 19.3; // Inside the box, which covers everything.
- point[2] = -500.0; // Inside the box, which covers everything.
- point[3] = 2.9; // The first right image of the bound starts at 3.0.
- point[4] = 0.0; // Closest to the first left image of the bound.
- point[5] = 7.1; // 0.1 away from the first right image of the bound.
- point[6] = 810.0; // 800 away from the only image of the box.
- point[7] = -10.8; // Should alias to 4.2.
-
- BOOST_REQUIRE_CLOSE(c.MinDistance(point), 640001.06, 1e-10);
-}*/
-
-/**
- * Correctly calculate the minimum distance between the bound and another bound in
- * periodic coordinates. We have to account for the shifts necessary in
- * periodic coordinates too, so that makes testing this a little more difficult.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMinDistanceBound)
-{
- // First, we'll start with a simple 2-dimensional case where the bounds are nonoverlapping,
- // then one bound is on the edge of the other bound,
- // then overlapping, then one range entirely covering the other. The box size will be large enough that this is basically the
- // HRectBound case.
- PeriodicHRectBound<2> b(arma::vec("100 100"));
- PeriodicHRectBound<2> c(arma::vec("100 100"));
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(2.0, 4.0);
-
- // Inside the bound.
- c[0] = Range(7.0, 9.0);
- c[1] = Range(10.0,12.0);
-
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(c), 40.0, 1e-5);
-
- // On the edge.
- c[0] = Range(5.0, 8.0);
- c[1] = Range(4.0, 6.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
-
- // Overlapping the bound.
- c[0] = Range(3.0, 6.0);
- c[1] = Range(1.0, 3.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
-
- // One range entirely covering the other
-
- c[0] = Range(0.0, 6.0);
- c[1] = Range(1.0, 7.0);
-
- BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
-
- // Now we start to invoke the periodicity. These bounds "alias" to (-3.0,
- // -1.0) and (5,0,6.0).
-
- c[0] = Range(97.0, 99.0);
- c[1] = Range(105.0, 106.0);
-
- BOOST_REQUIRE_CLOSE(b.MinDistance(c), 2.0, 1e-5);
-
- // We will perform several tests on a one-dimensional bound and smaller box size and mostly overlapping.
- PeriodicHRectBound<2> a(arma::vec("5.0"));
- PeriodicHRectBound<2> d(a);
-
- a[0] = Range(2.0, 4.0); // Entirely inside box.
- d[0] = Range(7.5, 10.0); // In the right image of the box, overlapping ranges.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
-
- a[0] = Range(0.0, 5.0); // Fills box fully.
- d[0] = Range(19.3, 21.0); // Two intervals inside the box, same as range of b[0].
-
- BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
-
- a[0] = Range(-10.0, 10.0); // Larger than the box.
- d[0] = Range(-500.0, -498.0); // Inside the box, which covers everything.
-
- BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
-
- a[0] = Range(-2.0, 1.0); // Crosses over an edge.
- d[0] = Range(2.9, 5.1); // The first right image of the bound starts at 3.0. Overlapping
-
- BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
-
- a[0] = Range(-1.0, 1.0); // Crosses over an edge.
- d[0] = Range(11.9, 12.5); // The first right image of the bound starts at 4.0.
- BOOST_REQUIRE_CLOSE(a.MinDistance(d), 0.81, 1e-5);
-
- a[0] = Range(2.0, 3.0);
- d[0] = Range(9.5, 11);
- BOOST_REQUIRE_CLOSE(a.MinDistance(d), 1.0, 1e-5);
-
-}*/
-
-/**
- * Correctly calculate the maximum distance between the bound and a point in
- * periodic coordinates. We have to account for the shifts necessary in
- * periodic coordinates too, so that makes testing this a little more difficult.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMaxDistancePoint)
-{
- // First, we'll start with a simple 2-dimensional case where the point is
- // inside the bound, then on the edge of the bound, then barely outside the
- // bound. The box size will be large enough that this is basically the
- // HRectBound case.
- PeriodicHRectBound<2> b(arma::vec("100 100"));
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(2.0, 4.0);
-
- // Inside the bound.
- arma::vec point = "2.5 3.0";
-
- //BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 7.25, 1e-5);
-
- // On the edge.
- point = "5.0 4.0";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 29.0, 1e-5);
-
- // And just a little outside the bound.
- point = "6.0 5.0";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 45.0, 1e-5);
-
- // Now we start to invoke the periodicity. This point will "alias" to (-1,
- // -1).
- point = "99.0 99.0";
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 61.0, 1e-5);
-
- // We will perform several tests on a one-dimensional bound and smaller box size.
- PeriodicHRectBound<2> a(arma::vec("5.0"));
- point.set_size(1);
-
- a[0] = Range(2.0, 4.0); // Entirely inside box.
- point[0] = 7.5; // Inside first right image of the box.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 2.25, 1e-5);
-
- a[0] = Range(0.0, 5.0); // Fills box fully.
- point[1] = 19.3; // Inside the box, which covers everything.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 18.49, 1e-5);
-
- a[0] = Range(-10.0, 10.0); // Larger than the box.
- point[0] = -500.0; // Inside the box, which covers everything.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 25.0, 1e-5);
-
- a[0] = Range(-2.0, 1.0); // Crosses over an edge.
- point[0] = 2.9; // The first right image of the bound starts at 3.0.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 8.41, 1e-5);
-
- a[0] = Range(2.0, 4.0); // Inside box.
- point[0] = 0.0; // Farthest from the first right image of the bound.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 25.0, 1e-5);
-
- a[0] = Range(0.0, 2.0); // On edge of box.
- point[0] = 7.1; // 2.1 away from the first left image of the bound.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 4.41, 1e-5);
-
- PeriodicHRectBound<2> d(arma::vec("0.0"));
- d[0] = Range(-10.0, 10.0); // Box is of infinite size.
- point[0] = 810.0; // 820 away from the only left image of the box.
-
- BOOST_REQUIRE_CLOSE(d.MinDistance(point), 672400.0, 1e-5);
-
- PeriodicHRectBound<2> e(arma::vec("-5.0"));
- e[0] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
- point[0] = -10.8; // Should alias to 4.2.
-
- BOOST_REQUIRE_CLOSE(e.MaxDistance(point), 4.84, 1e-5);
-
- // Switch our bound to a higher dimensionality. This should ensure that the
- // dimensions are independent like they should be.
- PeriodicHRectBound<2> c(arma::vec("5.0 5.0 5.0 5.0 5.0 5.0 0.0 -5.0"));
-
- c[0] = Range(2.0, 4.0); // Entirely inside box.
- c[1] = Range(0.0, 5.0); // Fills box fully.
- c[2] = Range(-10.0, 10.0); // Larger than the box.
- c[3] = Range(-2.0, 1.0); // Crosses over an edge.
- c[4] = Range(2.0, 4.0); // Inside box.
- c[5] = Range(0.0, 2.0); // On edge of box.
- c[6] = Range(-10.0, 10.0); // Box is of infinite size.
- c[7] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
-
- point.set_size(8);
- point[0] = 7.5; // Inside first right image of the box.
- point[1] = 19.3; // Inside the box, which covers everything.
- point[2] = -500.0; // Inside the box, which covers everything.
- point[3] = 2.9; // The first right image of the bound starts at 3.0.
- point[4] = 0.0; // Closest to the first left image of the bound.
- point[5] = 7.1; // 0.1 away from the first right image of the bound.
- point[6] = 810.0; // 800 away from the only image of the box.
- point[7] = -10.8; // Should alias to 4.2.
-
- BOOST_REQUIRE_CLOSE(c.MaxDistance(point), 672630.65, 1e-10);
-}*/
-
-/**
- * Correctly calculate the maximum distance between the bound and another bound in
- * periodic coordinates. We have to account for the shifts necessary in
- * periodic coordinates too, so that makes testing this a little more difficult.
- *
-BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMaxDistanceBound)
-{
- // First, we'll start with a simple 2-dimensional case where the bounds are nonoverlapping,
- // then one bound is on the edge of the other bound,
- // then overlapping, then one range entirely covering the other. The box size will be large enough that this is basically the
- // HRectBound case.
- PeriodicHRectBound<2> b(arma::vec("100 100"));
- PeriodicHRectBound<2> c(b);
-
- b[0] = Range(0.0, 5.0);
- b[1] = Range(2.0, 4.0);
-
- // Inside the bound.
-
- c[0] = Range(7.0, 9.0);
- c[1] = Range(10.0,12.0);
-
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 181.0, 1e-5);
-
- // On the edge.
-
- c[0] = Range(5.0, 8.0);
- c[1] = Range(4.0, 6.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 80.0, 1e-5);
-
- // Overlapping the bound.
-
- c[0] = Range(3.0, 6.0);
- c[1] = Range(1.0, 3.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 45.0, 1e-5);
-
- // One range entirely covering the other
-
- c[0] = Range(0.0, 6.0);
- c[1] = Range(1.0, 7.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 61.0, 1e-5);
-
- // Now we start to invoke the periodicity. Thess bounds "alias" to (-3.0,
- // -1.0) and (5,0,6.0).
-
- c[0] = Range(97.0, 99.0);
- c[1] = Range(105.0, 106.0);
-
- BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 80.0, 1e-5);
-
- // We will perform several tests on a one-dimensional bound and smaller box size.
- PeriodicHRectBound<2> a(arma::vec("5.0"));
- PeriodicHRectBound<2> d(a);
-
- a[0] = Range(2.0, 4.0); // Entirely inside box.
- d[0] = Range(7.5, 10); // In the right image of the box, overlapping ranges.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 9.0, 1e-5);
-
- a[0] = Range(0.0, 5.0); // Fills box fully.
- d[0] = Range(19.3, 21.0); // Two intervals inside the box, same as range of b[0].
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 18.49, 1e-5);
-
- a[0] = Range(-10.0, 10.0); // Larger than the box.
- d[0] = Range(-500.0, -498.0); // Inside the box, which covers everything.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 9.0, 1e-5);
-
- a[0] = Range(-2.0, 1.0); // Crosses over an edge.
- d[0] = Range(2.9, 5.1); // The first right image of the bound starts at 3.0.
-
- BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 24.01, 1e-5);
-}*/
-
-
-/**
- * It seems as though Bill has stumbled across a bug where
- * BinarySpaceTree<>::count() returns something different than
- * BinarySpaceTree<>::count_. So, let's build a simple tree and make sure they
- * are the same.
- */
-BOOST_AUTO_TEST_CASE(TreeCountMismatch)
-{
- arma::mat dataset = "2.0 5.0 9.0 4.0 8.0 7.0;"
- "3.0 4.0 6.0 7.0 1.0 2.0 ";
-
- // Leaf size of 1.
- BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
-
- BOOST_REQUIRE(rootNode.Count() == 6);
- BOOST_REQUIRE(rootNode.Left()->Count() == 3);
- BOOST_REQUIRE(rootNode.Left()->Left()->Count() == 2);
- BOOST_REQUIRE(rootNode.Left()->Left()->Left()->Count() == 1);
- BOOST_REQUIRE(rootNode.Left()->Left()->Right()->Count() == 1);
- BOOST_REQUIRE(rootNode.Left()->Right()->Count() == 1);
- BOOST_REQUIRE(rootNode.Right()->Count() == 3);
- BOOST_REQUIRE(rootNode.Right()->Left()->Count() == 2);
- BOOST_REQUIRE(rootNode.Right()->Left()->Left()->Count() == 1);
- BOOST_REQUIRE(rootNode.Right()->Left()->Right()->Count() == 1);
- BOOST_REQUIRE(rootNode.Right()->Right()->Count() == 1);
-}
-
-BOOST_AUTO_TEST_CASE(CheckParents)
-{
- arma::mat dataset = "2.0 5.0 9.0 4.0 8.0 7.0;"
- "3.0 4.0 6.0 7.0 1.0 2.0 ";
-
- // Leaf size of 1.
- BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
-
- BOOST_REQUIRE_EQUAL(rootNode.Parent(),
- (BinarySpaceTree<HRectBound<2> >*) NULL);
- BOOST_REQUIRE_EQUAL(&rootNode, rootNode.Left()->Parent());
- BOOST_REQUIRE_EQUAL(&rootNode, rootNode.Right()->Parent());
- BOOST_REQUIRE_EQUAL(rootNode.Left(), rootNode.Left()->Left()->Parent());
- BOOST_REQUIRE_EQUAL(rootNode.Left(), rootNode.Left()->Right()->Parent());
- BOOST_REQUIRE_EQUAL(rootNode.Left()->Left(),
- rootNode.Left()->Left()->Left()->Parent());
- BOOST_REQUIRE_EQUAL(rootNode.Left()->Left(),
- rootNode.Left()->Left()->Right()->Parent());
- BOOST_REQUIRE_EQUAL(rootNode.Right(), rootNode.Right()->Left()->Parent());
- BOOST_REQUIRE_EQUAL(rootNode.Right(), rootNode.Right()->Right()->Parent());
- BOOST_REQUIRE_EQUAL(rootNode.Right()->Left(),
- rootNode.Right()->Left()->Left()->Parent());
- BOOST_REQUIRE_EQUAL(rootNode.Right()->Left(),
- rootNode.Right()->Left()->Right()->Parent());
-}
-
-BOOST_AUTO_TEST_CASE(CheckDataset)
-{
- arma::mat dataset = "2.0 5.0 9.0 4.0 8.0 7.0;"
- "3.0 4.0 6.0 7.0 1.0 2.0 ";
-
- // Leaf size of 1.
- BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
-
- BOOST_REQUIRE_EQUAL(&rootNode.Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Right()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Right()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Right()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Left()->Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Right()->Dataset(), &dataset);
-}
-
-// Ensure FurthestDescendantDistance() works.
-BOOST_AUTO_TEST_CASE(FurthestDescendantDistanceTest)
-{
- arma::mat dataset = "1; 3"; // One point.
- BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
-
- BOOST_REQUIRE_SMALL(rootNode.FurthestDescendantDistance(), 1e-5);
-
- dataset = "1 -1; 1 -1"; // Square of size [2, 2].
-
- // Both points are contained in the one node.
- BinarySpaceTree<HRectBound<2> > twoPoint(dataset);
- BOOST_REQUIRE_CLOSE(twoPoint.FurthestDescendantDistance(), sqrt(2.0), 1e-5);
-}
-
-// Forward declaration of methods we need for the next test.
-template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType* node, const MatType& data);
-
-template<typename TreeType>
-void GenerateVectorOfTree(TreeType* node,
- size_t depth,
- std::vector<TreeType*>& v);
-
-template<int t_pow>
-bool DoBoundsIntersect(HRectBound<t_pow>& a,
- HRectBound<t_pow>& b,
- size_t ia,
- size_t ib);
-
-/**
- * Exhaustive kd-tree test based on #125.
- *
- * - Generate a random dataset of a random size.
- * - Build a tree on that dataset.
- * - Ensure all the permutation indices map back to the correct points.
- * - Verify that each point is contained inside all of the bounds of its parent
- * nodes.
- * - Verify that each bound at a particular level of the tree does not overlap
- * with any other bounds at that level.
- *
- * Then, we do that whole process a handful of times.
- */
-BOOST_AUTO_TEST_CASE(KdTreeTest)
-{
- typedef BinarySpaceTree<HRectBound<2> > TreeType;
-
- size_t maxRuns = 10; // Ten total tests.
- size_t pointIncrements = 1000; // Range is from 2000 points to 11000.
-
- // We use the default leaf size of 20.
- for(size_t run = 0; run < maxRuns; run++)
- {
- size_t dimensions = run + 2;
- size_t maxPoints = (run + 1) * pointIncrements;
-
- size_t size = maxPoints;
- arma::mat dataset = arma::mat(dimensions, size);
- arma::mat datacopy; // Used to test mappings.
-
- // Mappings for post-sort verification of data.
- std::vector<size_t> newToOld;
- std::vector<size_t> oldToNew;
-
- // Generate data.
- dataset.randu();
- datacopy = dataset; // Save a copy.
-
- // Build the tree itself.
- TreeType root(dataset, newToOld, oldToNew);
-
- // Ensure the size of the tree is correct.
- BOOST_REQUIRE_EQUAL(root.Count(), size);
-
- // Check the forward and backward mappings for correctness.
- for(size_t i = 0; i < size; i++)
- {
- for(size_t j = 0; j < dimensions; j++)
- {
- BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
- BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
- }
- }
-
- // Now check that each point is contained inside of all bounds above it.
- CheckPointBounds(&root, dataset);
-
- // Now check that no peers overlap.
- std::vector<TreeType*> v;
- GenerateVectorOfTree(&root, 1, v);
-
- // Start with the first pair.
- size_t depth = 2;
- // Compare each peer against every other peer.
- while (depth < v.size())
- {
- for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
- for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
- if (v[i] != NULL && v[j] != NULL)
- BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
- i, j));
-
- depth *= 2;
- }
- }
-
- arma::mat dataset = arma::mat(25, 1000);
- for (size_t col = 0; col < dataset.n_cols; ++col)
- for (size_t row = 0; row < dataset.n_rows; ++row)
- dataset(row, col) = row + col;
-
- TreeType root(dataset);
- // Check the tree size.
- BOOST_REQUIRE_EQUAL(root.TreeSize(), 127);
- // Check the tree depth.
- BOOST_REQUIRE_EQUAL(root.TreeDepth(), 7);
-}
-
-// Recursively checks that each node contains all points that it claims to have.
-template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType* node, const MatType& data)
-{
- if (node == NULL) // We have passed a leaf node.
- return true;
-
- TreeType* left = node->Left();
- TreeType* right = node->Right();
-
- size_t begin = node->Begin();
- size_t count = node->Count();
-
- // Check that each point which this tree claims is actually inside the tree.
- for (size_t index = begin; index < begin + count; index++)
- if (!node->Bound().Contains(data.col(index)))
- return false;
-
- return CheckPointBounds(left, data) && CheckPointBounds(right, data);
-}
-
-template<int t_pow>
-bool DoBoundsIntersect(HRectBound<t_pow>& a,
- HRectBound<t_pow>& b,
- size_t /* ia */,
- size_t /* ib */)
-{
- size_t dimensionality = a.Dim();
-
- Range r_a;
- Range r_b;
-
- for (size_t i = 0; i < dimensionality; i++)
- {
- r_a = a[i];
- r_b = b[i];
- if (r_a < r_b || r_a > r_b) // If a does not overlap b at all.
- return false;
- }
-
- return true;
-}
-
-template<typename TreeType>
-void GenerateVectorOfTree(TreeType* node,
- size_t depth,
- std::vector<TreeType*>& v)
-{
- if (node == NULL)
- return;
-
- if (depth >= v.size())
- v.resize(2 * depth + 1, NULL); // Resize to right size; fill with NULL.
-
- v[depth] = node;
-
- // Recurse to the left and right children.
- GenerateVectorOfTree(node->Left(), depth * 2, v);
- GenerateVectorOfTree(node->Right(), depth * 2 + 1, v);
-
- return;
-}
-
-#ifdef ARMA_HAS_SPMAT
-// Only run sparse tree tests if we are using Armadillo 3.6. Armadillo 3.4 has
-// some bugs that cause the kd-tree splitting procedure to fail terribly. Soon,
-// that version will be obsolete, though.
-#if !((ARMA_VERSION_MAJOR == 3) && (ARMA_VERSION_MINOR == 4))
-
-/**
- * Exhaustive sparse kd-tree test based on #125.
- *
- * - Generate a random dataset of a random size.
- * - Build a tree on that dataset.
- * - Ensure all the permutation indices map back to the correct points.
- * - Verify that each point is contained inside all of the bounds of its parent
- * nodes.
- * - Verify that each bound at a particular level of the tree does not overlap
- * with any other bounds at that level.
- *
- * Then, we do that whole process a handful of times.
- */
-BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
-{
- typedef BinarySpaceTree<HRectBound<2>, EmptyStatistic, arma::SpMat<double> >
- TreeType;
-
- size_t maxRuns = 2; // Two total tests.
- size_t pointIncrements = 200; // Range is from 200 points to 400.
-
- // We use the default leaf size of 20.
- for(size_t run = 0; run < maxRuns; run++)
- {
- size_t dimensions = run + 2;
- size_t maxPoints = (run + 1) * pointIncrements;
-
- size_t size = maxPoints;
- arma::SpMat<double> dataset = arma::SpMat<double>(dimensions, size);
- arma::SpMat<double> datacopy; // Used to test mappings.
-
- // Mappings for post-sort verification of data.
- std::vector<size_t> newToOld;
- std::vector<size_t> oldToNew;
-
- // Generate data.
- dataset.sprandu(dimensions, size, 0.1);
- datacopy = dataset; // Save a copy.
-
- // Build the tree itself.
- TreeType root(dataset, newToOld, oldToNew);
-
- // Ensure the size of the tree is correct.
- BOOST_REQUIRE_EQUAL(root.Count(), size);
-
- // Check the forward and backward mappings for correctness.
- for(size_t i = 0; i < size; i++)
- {
- for(size_t j = 0; j < dimensions; j++)
- {
- BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
- BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
- }
- }
-
- // Now check that each point is contained inside of all bounds above it.
- CheckPointBounds(&root, dataset);
-
- // Now check that no peers overlap.
- std::vector<TreeType*> v;
- GenerateVectorOfTree(&root, 1, v);
-
- // Start with the first pair.
- size_t depth = 2;
- // Compare each peer against every other peer.
- while (depth < v.size())
- {
- for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
- for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
- if (v[i] != NULL && v[j] != NULL)
- BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
- i, j));
-
- depth *= 2;
- }
- }
-
- arma::SpMat<double> dataset(25, 1000);
- for (size_t col = 0; col < dataset.n_cols; ++col)
- for (size_t row = 0; row < dataset.n_rows; ++row)
- dataset(row, col) = row + col;
-
- TreeType root(dataset);
- // Check the tree size.
- BOOST_REQUIRE_EQUAL(root.TreeSize(), 127);
- // Check the tree depth.
- BOOST_REQUIRE_EQUAL(root.TreeDepth(), 7);
-}
-
-#endif // Using Armadillo 3.4.
-#endif // ARMA_HAS_SPMAT
-
-template<typename TreeType>
-void RecurseTreeCountLeaves(const TreeType& node, arma::vec& counts)
-{
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- if (node.Child(i).NumChildren() == 0)
- counts[node.Child(i).Point()]++;
- else
- RecurseTreeCountLeaves<TreeType>(node.Child(i), counts);
- }
-}
-
-template<typename TreeType>
-void CheckSelfChild(const TreeType& node)
-{
- if (node.NumChildren() == 0)
- return; // No self-child applicable here.
-
- bool found = false;
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- if (node.Child(i).Point() == node.Point())
- found = true;
-
- // Recursively check the children.
- CheckSelfChild(node.Child(i));
- }
-
- // Ensure this has its own self-child.
- BOOST_REQUIRE_EQUAL(found, true);
-}
-
-template<typename TreeType, typename MetricType>
-void CheckCovering(const TreeType& node)
-{
- // Return if a leaf. No checking necessary.
- if (node.NumChildren() == 0)
- return;
-
- const arma::mat& dataset = node.Dataset();
- const size_t nodePoint = node.Point();
-
- // To ensure that this node satisfies the covering principle, we must ensure
- // that the distance to each child is less than pow(base, scale).
- double maxDistance = pow(node.Base(), node.Scale());
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- const size_t childPoint = node.Child(i).Point();
-
- double distance = MetricType::Evaluate(dataset.col(nodePoint),
- dataset.col(childPoint));
-
- BOOST_REQUIRE_LE(distance, maxDistance);
-
- // Check the child.
- CheckCovering<TreeType, MetricType>(node.Child(i));
- }
-}
-
-template<typename TreeType, typename MetricType>
-void CheckIndividualSeparation(const TreeType& constantNode,
- const TreeType& node)
-{
- // Don't check points at a lower scale.
- if (node.Scale() < constantNode.Scale())
- return;
-
- // If at a higher scale, recurse.
- if (node.Scale() > constantNode.Scale())
- {
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- // Don't recurse into leaves.
- if (node.Child(i).NumChildren() > 0)
- CheckIndividualSeparation<TreeType, MetricType>(constantNode,
- node.Child(i));
- }
-
- return;
- }
-
- // Don't compare the same point against itself.
- if (node.Point() == constantNode.Point())
- return;
-
- // Now we know we are at the same scale, so make the comparison.
- const arma::mat& dataset = constantNode.Dataset();
- const size_t constantPoint = constantNode.Point();
- const size_t nodePoint = node.Point();
-
- // Make sure the distance is at least the following value (in accordance with
- // the separation principle of cover trees).
- double minDistance = pow(constantNode.Base(),
- constantNode.Scale());
-
- double distance = MetricType::Evaluate(dataset.col(constantPoint),
- dataset.col(nodePoint));
-
- BOOST_REQUIRE_GE(distance, minDistance);
-}
-
-template<typename TreeType, typename MetricType>
-void CheckSeparation(const TreeType& node, const TreeType& root)
-{
- // Check the separation between this point and all other points on this scale.
- CheckIndividualSeparation<TreeType, MetricType>(node, root);
-
- // Check the children, but only if they are not leaves. Leaves don't need to
- // be checked.
- for (size_t i = 0; i < node.NumChildren(); ++i)
- if (node.Child(i).NumChildren() > 0)
- CheckSeparation<TreeType, MetricType>(node.Child(i), root);
-}
-
-
-/**
- * Create a simple cover tree and then make sure it is valid.
- */
-BOOST_AUTO_TEST_CASE(SimpleCoverTreeConstructionTest)
-{
- // 20-point dataset.
- arma::mat data = arma::trans(arma::mat("0.0 0.0;"
- "1.0 0.0;"
- "0.5 0.5;"
- "2.0 2.0;"
- "-1.0 2.0;"
- "3.0 0.0;"
- "1.5 5.5;"
- "-2.0 -2.0;"
- "-1.5 1.5;"
- "0.0 4.0;"
- "2.0 1.0;"
- "2.0 1.2;"
- "-3.0 -2.5;"
- "-5.0 -5.0;"
- "3.5 1.5;"
- "2.0 2.5;"
- "-1.0 -1.0;"
- "-3.5 1.5;"
- "3.5 -1.5;"
- "2.0 1.0;"));
-
- // The root point will be the first point, (0, 0).
- CoverTree<> tree(data); // Expansion constant of 2.0.
-
- // The furthest point from the root will be (-5, -5), with a distance of
- // of sqrt(50). This means the scale of the root node should be 3 (because
- // 2^3 = 8).
- BOOST_REQUIRE_EQUAL(tree.Scale(), 3);
-
- // Now loop through the tree and ensure that each leaf is only created once.
- arma::vec counts;
- counts.zeros(20);
- RecurseTreeCountLeaves(tree, counts);
-
- // Each point should only have one leaf node representing it.
- for (size_t i = 0; i < 20; ++i)
- BOOST_REQUIRE_EQUAL(counts[i], 1);
-
- // Each non-leaf should have a self-child.
- CheckSelfChild<CoverTree<> >(tree);
-
- // Each node must satisfy the covering principle (its children must be less
- // than or equal to a certain distance apart).
- CheckCovering<CoverTree<>, LMetric<2, true> >(tree);
-
- // Each node's children must be separated by at least a certain value.
- CheckSeparation<CoverTree<>, LMetric<2, true> >(tree, tree);
-}
-
-/**
- * Create a large cover tree and make sure it's accurate.
- */
-BOOST_AUTO_TEST_CASE(CoverTreeConstructionTest)
-{
- arma::mat dataset;
- // 50-dimensional, 1000 point.
- dataset.randu(50, 1000);
-
- CoverTree<> tree(dataset);
-
- // Ensure each leaf is only created once.
- arma::vec counts;
- counts.zeros(1000);
- RecurseTreeCountLeaves(tree, counts);
-
- for (size_t i = 0; i < 1000; ++i)
- BOOST_REQUIRE_EQUAL(counts[i], 1);
-
- // Each non-leaf should have a self-child.
- CheckSelfChild<CoverTree<> >(tree);
-
- // Each node must satisfy the covering principle (its children must be less
- // than or equal to a certain distance apart).
- CheckCovering<CoverTree<>, LMetric<2, true> >(tree);
-
- // Each node's children must be separated by at least a certain value.
- CheckSeparation<CoverTree<>, LMetric<2, true> >(tree, tree);
-}
-
-/**
- * Test the manual constructor.
- */
-BOOST_AUTO_TEST_CASE(CoverTreeManualConstructorTest)
-{
- arma::mat dataset;
- dataset.zeros(10, 10);
-
- CoverTree<> node(dataset, 1.3, 3, 2, NULL, 1.5, 2.75);
-
- BOOST_REQUIRE_EQUAL(&node.Dataset(), &dataset);
- BOOST_REQUIRE_EQUAL(node.Base(), 1.3);
- BOOST_REQUIRE_EQUAL(node.Point(), 3);
- BOOST_REQUIRE_EQUAL(node.Scale(), 2);
- BOOST_REQUIRE_EQUAL(node.Parent(), (CoverTree<>*) NULL);
- BOOST_REQUIRE_EQUAL(node.ParentDistance(), 1.5);
- BOOST_REQUIRE_EQUAL(node.FurthestDescendantDistance(), 2.75);
-}
-
-/**
- * Make sure cover trees work in different metric spaces.
- */
-BOOST_AUTO_TEST_CASE(CoverTreeAlternateMetricTest)
-{
- arma::mat dataset;
- // 5-dimensional, 300-point dataset.
- dataset.randu(5, 300);
-
- CoverTree<LMetric<1, true> > tree(dataset);
-
- // Ensure each leaf is only created once.
- arma::vec counts;
- counts.zeros(300);
- RecurseTreeCountLeaves<CoverTree<LMetric<1, true> > >(tree, counts);
-
- for (size_t i = 0; i < 300; ++i)
- BOOST_REQUIRE_EQUAL(counts[i], 1);
-
- // Each non-leaf should have a self-child.
- CheckSelfChild<CoverTree<LMetric<1, true> > >(tree);
-
- // Each node must satisfy the covering principle (its children must be less
- // than or equal to a certain distance apart).
- CheckCovering<CoverTree<LMetric<1, true> >, LMetric<1, true> >(tree);
-
- // Each node's children must be separated by at least a certain value.
- CheckSeparation<CoverTree<LMetric<1, true> >, LMetric<1, true> >(tree, tree);
-}
-
-/**
- * Make sure copy constructor works for the cover tree.
- */
-BOOST_AUTO_TEST_CASE(CoverTreeCopyConstructor)
-{
- arma::mat dataset;
- dataset.randu(10, 10); // dataset is irrelevant.
- CoverTree<> c(dataset, 1.3, 0, 5, NULL, 1.45, 5.2); // Random parameters.
- c.Children().push_back(new CoverTree<>(dataset, 1.3, 1, 4, &c, 1.3, 2.45));
- c.Children().push_back(new CoverTree<>(dataset, 1.5, 2, 3, &c, 1.2, 5.67));
-
- CoverTree<> d = c;
-
- // Check that everything is the same.
- BOOST_REQUIRE_EQUAL(c.Dataset().memptr(), d.Dataset().memptr());
- BOOST_REQUIRE_CLOSE(c.Base(), d.Base(), 1e-50);
- BOOST_REQUIRE_EQUAL(c.Point(), d.Point());
- BOOST_REQUIRE_EQUAL(c.Scale(), d.Scale());
- BOOST_REQUIRE_EQUAL(c.Parent(), d.Parent());
- BOOST_REQUIRE_EQUAL(c.ParentDistance(), d.ParentDistance());
- BOOST_REQUIRE_EQUAL(c.FurthestDescendantDistance(),
- d.FurthestDescendantDistance());
- BOOST_REQUIRE_EQUAL(c.NumChildren(), d.NumChildren());
- BOOST_REQUIRE_NE(&c.Child(0), &d.Child(0));
- BOOST_REQUIRE_NE(&c.Child(1), &d.Child(1));
-
- BOOST_REQUIRE_EQUAL(c.Child(0).Parent(), &c);
- BOOST_REQUIRE_EQUAL(c.Child(1).Parent(), &c);
- BOOST_REQUIRE_EQUAL(d.Child(0).Parent(), &d);
- BOOST_REQUIRE_EQUAL(d.Child(1).Parent(), &d);
-
- // Check that the children are okay.
- BOOST_REQUIRE_EQUAL(c.Child(0).Dataset().memptr(),
- d.Child(0).Dataset().memptr());
- BOOST_REQUIRE_CLOSE(c.Child(0).Base(), d.Child(0).Base(), 1e-50);
- BOOST_REQUIRE_EQUAL(c.Child(0).Point(), d.Child(0).Point());
- BOOST_REQUIRE_EQUAL(c.Child(0).Scale(), d.Child(0).Scale());
- BOOST_REQUIRE_EQUAL(c.Child(0).ParentDistance(), d.Child(0).ParentDistance());
- BOOST_REQUIRE_EQUAL(c.Child(0).FurthestDescendantDistance(),
- d.Child(0).FurthestDescendantDistance());
- BOOST_REQUIRE_EQUAL(c.Child(0).NumChildren(), d.Child(0).NumChildren());
-
- BOOST_REQUIRE_EQUAL(c.Child(1).Dataset().memptr(),
- d.Child(1).Dataset().memptr());
- BOOST_REQUIRE_CLOSE(c.Child(1).Base(), d.Child(1).Base(), 1e-50);
- BOOST_REQUIRE_EQUAL(c.Child(1).Point(), d.Child(1).Point());
- BOOST_REQUIRE_EQUAL(c.Child(1).Scale(), d.Child(1).Scale());
- BOOST_REQUIRE_EQUAL(c.Child(1).ParentDistance(), d.Child(1).ParentDistance());
- BOOST_REQUIRE_EQUAL(c.Child(1).FurthestDescendantDistance(),
- d.Child(1).FurthestDescendantDistance());
- BOOST_REQUIRE_EQUAL(c.Child(1).NumChildren(), d.Child(1).NumChildren());
-}
-
-/**
- * Make sure copy constructor works right for the binary space tree.
- */
-BOOST_AUTO_TEST_CASE(BinarySpaceTreeCopyConstructor)
-{
- arma::mat data("1");
- BinarySpaceTree<HRectBound<2> > b(data);
- b.Begin() = 10;
- b.Count() = 50;
- b.Left() = new BinarySpaceTree<HRectBound<2> >(data);
- b.Left()->Begin() = 10;
- b.Left()->Count() = 30;
- b.Right() = new BinarySpaceTree<HRectBound<2> >(data);
- b.Right()->Begin() = 40;
- b.Right()->Count() = 20;
-
- // Copy the tree.
- BinarySpaceTree<HRectBound<2> > c(b);
-
- // Ensure everything copied correctly.
- BOOST_REQUIRE_EQUAL(b.Begin(), c.Begin());
- BOOST_REQUIRE_EQUAL(b.Count(), c.Count());
- BOOST_REQUIRE_NE(b.Left(), c.Left());
- BOOST_REQUIRE_NE(b.Right(), c.Right());
-
- // Check the children.
- BOOST_REQUIRE_EQUAL(b.Left()->Begin(), c.Left()->Begin());
- BOOST_REQUIRE_EQUAL(b.Left()->Count(), c.Left()->Count());
- BOOST_REQUIRE_EQUAL(b.Left()->Left(),
- (BinarySpaceTree<HRectBound<2> >*) NULL);
- BOOST_REQUIRE_EQUAL(b.Left()->Left(), c.Left()->Left());
- BOOST_REQUIRE_EQUAL(b.Left()->Right(),
- (BinarySpaceTree<HRectBound<2> >*) NULL);
- BOOST_REQUIRE_EQUAL(b.Left()->Right(), c.Left()->Right());
-
- BOOST_REQUIRE_EQUAL(b.Right()->Begin(), c.Right()->Begin());
- BOOST_REQUIRE_EQUAL(b.Right()->Count(), c.Right()->Count());
- BOOST_REQUIRE_EQUAL(b.Right()->Left(),
- (BinarySpaceTree<HRectBound<2> >*) NULL);
- BOOST_REQUIRE_EQUAL(b.Right()->Left(), c.Right()->Left());
- BOOST_REQUIRE_EQUAL(b.Right()->Right(),
- (BinarySpaceTree<HRectBound<2> >*) NULL);
- BOOST_REQUIRE_EQUAL(b.Right()->Right(), c.Right()->Right());
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/tree_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,2209 @@
+/**
+ * @file tree_test.cpp
+ *
+ * Tests for tree-building methods.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::math;
+using namespace mlpack::tree;
+using namespace mlpack::metric;
+using namespace mlpack::bound;
+
+BOOST_AUTO_TEST_SUITE(TreeTest);
+
+/**
+ * Ensure that a bound, by default, is empty and has no dimensionality.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundEmptyConstructor)
+{
+ HRectBound<2> b;
+
+ BOOST_REQUIRE_EQUAL((int) b.Dim(), 0);
+}
+
+/**
+ * Ensure that when we specify the dimensionality in the constructor, it is
+ * correct, and the bounds are all the empty set.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundDimConstructor)
+{
+ HRectBound<2> b(2); // We'll do this with 2 and 5 dimensions.
+
+ BOOST_REQUIRE_EQUAL(b.Dim(), 2);
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+
+ b = HRectBound<2>(5);
+
+ BOOST_REQUIRE_EQUAL(b.Dim(), 5);
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[2].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[3].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[4].Width(), 1e-5);
+}
+
+/**
+ * Test the copy constructor.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundCopyConstructor)
+{
+ HRectBound<2> b(2);
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 3.0);
+
+ HRectBound<2> c(b);
+
+ BOOST_REQUIRE_EQUAL(c.Dim(), 2);
+ BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
+}
+
+/**
+ * Test the assignment operator.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundAssignmentOperator)
+{
+ HRectBound<2> b(2);
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 3.0);
+
+ HRectBound<2> c(4);
+
+ c = b;
+
+ BOOST_REQUIRE_EQUAL(c.Dim(), 2);
+ BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
+}
+
+/**
+ * Test that clearing the dimensions resets the bound to empty.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundClear)
+{
+ HRectBound<2> b(2); // We'll do this with two dimensions only.
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Now we just need to make sure that we clear the range.
+ b.Clear();
+
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+}
+
+/**
+ * Ensure that we get the correct centroid for our bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundCentroid)
+{
+ // Create a simple 3-dimensional bound.
+ HRectBound<2> b(3);
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(-2.0, -1.0);
+ b[2] = Range(-10.0, 50.0);
+
+ arma::vec centroid;
+
+ b.Centroid(centroid);
+
+ BOOST_REQUIRE_EQUAL(centroid.n_elem, 3);
+ BOOST_REQUIRE_CLOSE(centroid[0], 2.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(centroid[1], -1.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(centroid[2], 20.0, 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct minimum distance between a point and a
+ * bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundMinDistancePoint)
+{
+ // We'll do the calculation in five dimensions, and we'll use three cases for
+ // the point: point is outside the bound; point is on the edge of the bound;
+ // point is inside the bound. In the latter two cases, the distance should be
+ // zero.
+ HRectBound<2> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
+
+ // This will be the Euclidean distance.
+ BOOST_REQUIRE_CLOSE(b.MinDistance(point), sqrt(95.0), 1e-5);
+
+ point = "2.0 5.0 2.0 -5.0 1.0";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+
+ point = "1.0 2.0 0.0 -2.0 1.5";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct minimum distance between a bound and
+ * another bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundMinDistanceBound)
+{
+ // We'll do the calculation in five dimensions, and we can use six cases.
+ // The other bound is completely outside the bound; the other bound is on the
+ // edge of the bound; the other bound partially overlaps the bound; the other
+ // bound fully overlaps the bound; the other bound is entirely inside the
+ // bound; the other bound entirely envelops the bound.
+ HRectBound<2> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ HRectBound<2> c(5);
+
+ // The other bound is completely outside the bound.
+ c[0] = Range(-5.0, -2.0);
+ c[1] = Range(6.0, 7.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(2.0, 5.0);
+ c[4] = Range(3.0, 4.0);
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(c), sqrt(22.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MinDistance(b), sqrt(22.0), 1e-5);
+
+ // The other bound is on the edge of the bound.
+ c[0] = Range(-2.0, 0.0);
+ c[1] = Range(0.0, 1.0);
+ c[2] = Range(-3.0, -2.0);
+ c[3] = Range(-10.0, -5.0);
+ c[4] = Range(2.0, 3.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
+
+ // The other bound partially overlaps the bound.
+ c[0] = Range(-2.0, 1.0);
+ c[1] = Range(0.0, 2.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(-8.0, -4.0);
+ c[4] = Range(0.0, 4.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
+
+ // The other bound fully overlaps the bound.
+ BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
+
+ // The other bound is entirely inside the bound / the other bound entirely
+ // envelops the bound.
+ c[0] = Range(-1.0, 3.0);
+ c[1] = Range(0.0, 6.0);
+ c[2] = Range(-3.0, 3.0);
+ c[3] = Range(-7.0, 0.0);
+ c[4] = Range(0.0, 5.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
+
+ // Now we must be sure that the minimum distance to itself is 0.
+ BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct maximum distance between a bound and a
+ * point. This uses the same test cases as the MinDistance test.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundMaxDistancePoint)
+{
+ // We'll do the calculation in five dimensions, and we'll use three cases for
+ // the point: point is outside the bound; point is on the edge of the bound;
+ // point is inside the bound. In the latter two cases, the distance should be
+ // zero.
+ HRectBound<2> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
+
+ // This will be the Euclidean distance.
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(253.0), 1e-5);
+
+ point = "2.0 5.0 2.0 -5.0 1.0";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(46.0), 1e-5);
+
+ point = "1.0 2.0 0.0 -2.0 1.5";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(23.25), 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct maximum distance between a bound and
+ * another bound. This uses the same test cases as the MinDistance test.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundMaxDistanceBound)
+{
+ // We'll do the calculation in five dimensions, and we can use six cases.
+ // The other bound is completely outside the bound; the other bound is on the
+ // edge of the bound; the other bound partially overlaps the bound; the other
+ // bound fully overlaps the bound; the other bound is entirely inside the
+ // bound; the other bound entirely envelops the bound.
+ HRectBound<2> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ HRectBound<2> c(5);
+
+ // The other bound is completely outside the bound.
+ c[0] = Range(-5.0, -2.0);
+ c[1] = Range(6.0, 7.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(2.0, 5.0);
+ c[4] = Range(3.0, 4.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(210.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(210.0), 1e-5);
+
+ // The other bound is on the edge of the bound.
+ c[0] = Range(-2.0, 0.0);
+ c[1] = Range(0.0, 1.0);
+ c[2] = Range(-3.0, -2.0);
+ c[3] = Range(-10.0, -5.0);
+ c[4] = Range(2.0, 3.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(134.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(134.0), 1e-5);
+
+ // The other bound partially overlaps the bound.
+ c[0] = Range(-2.0, 1.0);
+ c[1] = Range(0.0, 2.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(-8.0, -4.0);
+ c[4] = Range(0.0, 4.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(102.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(102.0), 1e-5);
+
+ // The other bound fully overlaps the bound.
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(b), sqrt(46.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(c), sqrt(61.0), 1e-5);
+
+ // The other bound is entirely inside the bound / the other bound entirely
+ // envelops the bound.
+ c[0] = Range(-1.0, 3.0);
+ c[1] = Range(0.0, 6.0);
+ c[2] = Range(-3.0, 3.0);
+ c[3] = Range(-7.0, 0.0);
+ c[4] = Range(0.0, 5.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(100.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(100.0), 1e-5);
+
+ // Identical bounds. This will be the square root of the sum of the squared
+ // widths in each dimension.
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(b), sqrt(46.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(c), sqrt(162.0), 1e-5);
+
+ // One last additional case. If the bound encloses only one point, the
+ // maximum distance between it and itself is 0.
+ HRectBound<2> d(2);
+
+ d[0] = Range(2.0, 2.0);
+ d[1] = Range(3.0, 3.0);
+
+ BOOST_REQUIRE_SMALL(d.MaxDistance(d), 1e-5);
+}
+
+/**
+ * Ensure that the ranges returned by RangeDistance() are equal to the minimum
+ * and maximum distance. We will perform this test by creating random bounds
+ * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
+ * is assuming that those passed and operate correctly.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRangeDistanceBound)
+{
+ for (int i = 0; i < 50; i++)
+ {
+ size_t dim = math::RandInt(20);
+
+ HRectBound<2> a(dim);
+ HRectBound<2> b(dim);
+
+ // We will set the low randomly and the width randomly for each dimension of
+ // each bound.
+ arma::vec loA(dim);
+ arma::vec widthA(dim);
+
+ loA.randu();
+ widthA.randu();
+
+ arma::vec lo_b(dim);
+ arma::vec width_b(dim);
+
+ lo_b.randu();
+ width_b.randu();
+
+ for (size_t j = 0; j < dim; j++)
+ {
+ a[j] = Range(loA[j], loA[j] + widthA[j]);
+ b[j] = Range(lo_b[j], lo_b[j] + width_b[j]);
+ }
+
+ // Now ensure that MinDistance and MaxDistance report the same.
+ Range r = a.RangeDistance(b);
+ Range s = b.RangeDistance(a);
+
+ BOOST_REQUIRE_CLOSE(r.Lo(), s.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(r.Hi(), s.Hi(), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(b), 1e-5);
+ BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(b), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(s.Lo(), b.MinDistance(a), 1e-5);
+ BOOST_REQUIRE_CLOSE(s.Hi(), b.MaxDistance(a), 1e-5);
+ }
+}
+
+/**
+ * Ensure that the ranges returned by RangeDistance() are equal to the minimum
+ * and maximum distance. We will perform this test by creating random bounds
+ * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
+ * is assuming that those passed and operate correctly. This is for the
+ * bound-to-point case.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRangeDistancePoint)
+{
+ for (int i = 0; i < 20; i++)
+ {
+ size_t dim = math::RandInt(20);
+
+ HRectBound<2> a(dim);
+
+ // We will set the low randomly and the width randomly for each dimension of
+ // each bound.
+ arma::vec loA(dim);
+ arma::vec widthA(dim);
+
+ loA.randu();
+ widthA.randu();
+
+ for (size_t j = 0; j < dim; j++)
+ a[j] = Range(loA[j], loA[j] + widthA[j]);
+
+ // Now run the test on a few points.
+ for (int j = 0; j < 10; j++)
+ {
+ arma::vec point(dim);
+
+ point.randu();
+
+ Range r = a.RangeDistance(point);
+
+ BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(point), 1e-5);
+ BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(point), 1e-5);
+ }
+ }
+}
+
+/**
+ * Test that we can expand the bound to include a new point.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundOrOperatorPoint)
+{
+ // Because this should be independent in each dimension, we can essentially
+ // run five test cases at once.
+ HRectBound<2> b(5);
+
+ b[0] = Range(1.0, 3.0);
+ b[1] = Range(2.0, 4.0);
+ b[2] = Range(-2.0, -1.0);
+ b[3] = Range(0.0, 0.0);
+ b[4] = Range(); // Empty range.
+
+ arma::vec point = "2.0 4.0 2.0 -1.0 6.0";
+
+ b |= point;
+
+ BOOST_REQUIRE_CLOSE(b[0].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[0].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[1].Hi(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[2].Lo(), -2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[2].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[3].Lo(), -1.0, 1e-5);
+ BOOST_REQUIRE_SMALL(b[3].Hi(), 1e-5);
+ BOOST_REQUIRE_CLOSE(b[4].Lo(), 6.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[4].Hi(), 6.0, 1e-5);
+}
+
+/**
+ * Test that we can expand the bound to include another bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundOrOperatorBound)
+{
+ // Because this should be independent in each dimension, we can run many tests
+ // at once.
+ HRectBound<2> b(8);
+
+ b[0] = Range(1.0, 3.0);
+ b[1] = Range(2.0, 4.0);
+ b[2] = Range(-2.0, -1.0);
+ b[3] = Range(4.0, 5.0);
+ b[4] = Range(2.0, 4.0);
+ b[5] = Range(0.0, 0.0);
+ b[6] = Range();
+ b[7] = Range(1.0, 3.0);
+
+ HRectBound<2> c(8);
+
+ c[0] = Range(-3.0, -1.0); // Entirely less than the other bound.
+ c[1] = Range(0.0, 2.0); // Touching edges.
+ c[2] = Range(-3.0, -1.5); // Partially overlapping.
+ c[3] = Range(4.0, 5.0); // Identical.
+ c[4] = Range(1.0, 5.0); // Entirely enclosing.
+ c[5] = Range(2.0, 2.0); // A single point.
+ c[6] = Range(1.0, 3.0);
+ c[7] = Range(); // Empty set.
+
+ HRectBound<2> d = c;
+
+ b |= c;
+ d |= b;
+
+ BOOST_REQUIRE_CLOSE(b[0].Lo(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[0].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[0].Lo(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[0].Hi(), 3.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[1].Lo(), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[1].Hi(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[1].Lo(), 0.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[1].Hi(), 4.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[2].Lo(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[2].Hi(), -1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[2].Lo(), -3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[2].Hi(), -1.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[3].Lo(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[3].Hi(), 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[3].Lo(), 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[3].Hi(), 5.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[4].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[4].Hi(), 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[4].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[4].Hi(), 5.0, 1e-5);
+
+ BOOST_REQUIRE_SMALL(b[5].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(b[5].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_SMALL(d[5].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(d[5].Hi(), 2.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[6].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[6].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[6].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[6].Hi(), 3.0, 1e-5);
+
+ BOOST_REQUIRE_CLOSE(b[7].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b[7].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[7].Lo(), 1.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d[7].Hi(), 3.0, 1e-5);
+}
+
+/**
+ * Test that the Contains() function correctly figures out whether or not a
+ * point is in a bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundContains)
+{
+ // We can test a couple different points: completely outside the bound,
+ // adjacent in one dimension to the bound, adjacent in all dimensions to the
+ // bound, and inside the bound.
+ HRectBound<2> b(3);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(0.0, 2.0);
+ b[2] = Range(0.0, 2.0);
+
+ // Completely outside the range.
+ arma::vec point = "-1.0 4.0 4.0";
+ BOOST_REQUIRE(!b.Contains(point));
+
+ // Completely outside, but one dimension is in the range.
+ point = "-1.0 4.0 1.0";
+ BOOST_REQUIRE(!b.Contains(point));
+
+ // Outside, but one dimension is on the edge.
+ point = "-1.0 0.0 3.0";
+ BOOST_REQUIRE(!b.Contains(point));
+
+ // Two dimensions are on the edge, but one is outside.
+ point = "0.0 0.0 3.0";
+ BOOST_REQUIRE(!b.Contains(point));
+
+ // Completely on the edge (should be contained).
+ point = "0.0 0.0 0.0";
+ BOOST_REQUIRE(b.Contains(point));
+
+ // Inside the range.
+ point = "0.3 1.0 0.4";
+ BOOST_REQUIRE(b.Contains(point));
+}
+
+BOOST_AUTO_TEST_CASE(TestBallBound)
+{
+ BallBound<> b1;
+ BallBound<> b2;
+
+ // Create two balls with a center distance of 1 from each other.
+ // Give the first one a radius of 0.3 and the second a radius of 0.4.
+ b1.Center().set_size(3);
+ b1.Center()[0] = 1;
+ b1.Center()[1] = 2;
+ b1.Center()[2] = 3;
+ b1.Radius() = 0.3;
+
+ b2.Center().set_size(3);
+ b2.Center()[0] = 1;
+ b2.Center()[1] = 2;
+ b2.Center()[2] = 4;
+ b2.Radius() = 0.4;
+
+ // The minimum distance between the balls is the center distance minus both
+ // radii; the maximum distance is the center distance plus both radii. Check
+ // MinDistance(), MaxDistance(), and RangeDistance() in both directions.
+ BOOST_REQUIRE_CLOSE(b1.MinDistance(b2), 1-0.3-0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.MaxDistance(b2), 1+0.3+0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Hi(), 1+0.3+0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.RangeDistance(b2).Lo(), 1-0.3-0.4, 1e-5);
+
+ // The same should hold in the other direction, by symmetry.
+ BOOST_REQUIRE_CLOSE(b2.MinDistance(b1), 1-0.3-0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.MaxDistance(b1), 1+0.3+0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.RangeDistance(b1).Hi(), 1+0.3+0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.RangeDistance(b1).Lo(), 1-0.3-0.4, 1e-5);
+
+ // Each ball contains its own center but not the other's, since the centers
+ // are farther apart (distance 1) than either radius.
+ BOOST_REQUIRE(b1.Contains(b1.Center()));
+ BOOST_REQUIRE(!b1.Contains(b2.Center()));
+
+ BOOST_REQUIRE(!b2.Contains(b1.Center()));
+ BOOST_REQUIRE(b2.Contains(b2.Center()));
+ arma::vec b2point(3); // A point that's within the radius but not the center.
+ b2point[0] = 1.1;
+ b2point[1] = 2.1;
+ b2point[2] = 4.1;
+
+ BOOST_REQUIRE(b2.Contains(b2point));
+
+ // Point-to-bound distances: the ball's radius is subtracted from the
+ // center-to-point distance for MinDistance() and added for MaxDistance().
+ BOOST_REQUIRE_SMALL(b1.MinDistance(b1.Center()), 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.MinDistance(b2.Center()), 1 - 0.3, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.MinDistance(b1.Center()), 1 - 0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b2.MaxDistance(b1.Center()), 1 + 0.4, 1e-5);
+ BOOST_REQUIRE_CLOSE(b1.MaxDistance(b2.Center()), 1 + 0.3, 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct minimum distance between a point and a
+ * bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRootMinDistancePoint)
+{
+ // We'll do the calculation in five dimensions, and we'll use three cases for
+ // the point: point is outside the bound; point is on the edge of the bound;
+ // point is inside the bound. In the latter two cases, the distance should be
+ // zero.
+ // The second template parameter (true) means the root of the distance is
+ // taken, so MinDistance() returns actual Euclidean distances below.
+ HRectBound<2, true> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
+
+ // This will be the Euclidean distance.
+ // Per-dimension gaps to the bound are 2, 1, 8, 5, and 1, so the distance is
+ // sqrt(4 + 1 + 64 + 25 + 1) = sqrt(95).
+ BOOST_REQUIRE_CLOSE(b.MinDistance(point), sqrt(95.0), 1e-5);
+
+ // On the edge of the bound (every coordinate sits on a range endpoint).
+ point = "2.0 5.0 2.0 -5.0 1.0";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+
+ // Strictly inside the bound.
+ point = "1.0 2.0 0.0 -2.0 1.5";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct minimum distance between a bound and
+ * another bound.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRootMinDistanceBound)
+{
+ // We'll do the calculation in five dimensions, and we can use six cases.
+ // The other bound is completely outside the bound; the other bound is on the
+ // edge of the bound; the other bound partially overlaps the bound; the other
+ // bound fully overlaps the bound; the other bound is entirely inside the
+ // bound; the other bound entirely envelops the bound.
+ HRectBound<2, true> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ HRectBound<2, true> c(5);
+
+ // The other bound is completely outside the bound.
+ c[0] = Range(-5.0, -2.0);
+ c[1] = Range(6.0, 7.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(2.0, 5.0);
+ c[4] = Range(3.0, 4.0);
+
+ // Per-dimension gaps between b and c are 2, 1, 0, 4, and 1, so the minimum
+ // distance is sqrt(4 + 1 + 0 + 16 + 1) = sqrt(22); it must be symmetric.
+ BOOST_REQUIRE_CLOSE(b.MinDistance(c), sqrt(22.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MinDistance(b), sqrt(22.0), 1e-5);
+
+ // The other bound is on the edge of the bound (touching in every dimension).
+ c[0] = Range(-2.0, 0.0);
+ c[1] = Range(0.0, 1.0);
+ c[2] = Range(-3.0, -2.0);
+ c[3] = Range(-10.0, -5.0);
+ c[4] = Range(2.0, 3.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
+
+ // The other bound partially overlaps the bound.
+ c[0] = Range(-2.0, 1.0);
+ c[1] = Range(0.0, 2.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(-8.0, -4.0);
+ c[4] = Range(0.0, 4.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
+
+ // The other bound fully overlaps the bound (a bound trivially overlaps
+ // itself, so the minimum distance to itself is zero).
+ BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
+
+ // The other bound is entirely inside the bound / the other bound entirely
+ // envelops the bound.
+ c[0] = Range(-1.0, 3.0);
+ c[1] = Range(0.0, 6.0);
+ c[2] = Range(-3.0, 3.0);
+ c[3] = Range(-7.0, 0.0);
+ c[4] = Range(0.0, 5.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(b), 1e-5);
+
+ // Now we must be sure that the minimum distance to itself is 0.
+ BOOST_REQUIRE_SMALL(b.MinDistance(b), 1e-5);
+ BOOST_REQUIRE_SMALL(c.MinDistance(c), 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct maximum distance between a bound and a
+ * point. This uses the same test cases as the MinDistance test.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRootMaxDistancePoint)
+{
+ // We'll do the calculation in five dimensions, and we'll use three cases for
+ // the point: point is outside the bound; point is on the edge of the bound;
+ // point is inside the bound. Unlike the MinDistance test, the maximum
+ // distance is nonzero in all three cases.
+ HRectBound<2, true> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ arma::vec point = "-2.0 0.0 10.0 3.0 3.0";
+
+ // This will be the Euclidean distance.
+ // The farthest corner of b from the point is (2, 5, -2, -5, 1), giving
+ // per-dimension distances of 4, 5, 12, 8, and 2:
+ // sqrt(16 + 25 + 144 + 64 + 4) = sqrt(253).
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(253.0), 1e-5);
+
+ // On the edge of the bound; the farthest corner is the opposite one.
+ point = "2.0 5.0 2.0 -5.0 1.0";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(46.0), 1e-5);
+
+ // Inside the bound; the distance to the farthest corner is still nonzero.
+ point = "1.0 2.0 0.0 -2.0 1.5";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), sqrt(23.25), 1e-5);
+}
+
+/**
+ * Ensure that we calculate the correct maximum distance between a bound and
+ * another bound. This uses the same test cases as the MinDistance test.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRootMaxDistanceBound)
+{
+ // We'll do the calculation in five dimensions, and we can use six cases.
+ // The other bound is completely outside the bound; the other bound is on the
+ // edge of the bound; the other bound partially overlaps the bound; the other
+ // bound fully overlaps the bound; the other bound is entirely inside the
+ // bound; the other bound entirely envelops the bound.
+ HRectBound<2, true> b(5);
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(1.0, 5.0);
+ b[2] = Range(-2.0, 2.0);
+ b[3] = Range(-5.0, -2.0);
+ b[4] = Range(1.0, 2.0);
+
+ HRectBound<2, true> c(5);
+
+ // The other bound is completely outside the bound.
+ c[0] = Range(-5.0, -2.0);
+ c[1] = Range(6.0, 7.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(2.0, 5.0);
+ c[4] = Range(3.0, 4.0);
+
+ // Maximum per-dimension separations are 7, 6, 4, 10, and 3, so the maximum
+ // distance is sqrt(49 + 36 + 16 + 100 + 9) = sqrt(210); must be symmetric.
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(210.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(210.0), 1e-5);
+
+ // The other bound is on the edge of the bound.
+ c[0] = Range(-2.0, 0.0);
+ c[1] = Range(0.0, 1.0);
+ c[2] = Range(-3.0, -2.0);
+ c[3] = Range(-10.0, -5.0);
+ c[4] = Range(2.0, 3.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(134.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(134.0), 1e-5);
+
+ // The other bound partially overlaps the bound.
+ c[0] = Range(-2.0, 1.0);
+ c[1] = Range(0.0, 2.0);
+ c[2] = Range(-2.0, 2.0);
+ c[3] = Range(-8.0, -4.0);
+ c[4] = Range(0.0, 4.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(102.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(102.0), 1e-5);
+
+ // The other bound fully overlaps the bound (the maximum distance from a
+ // bound to itself is the diagonal, i.e. the root of the summed squared
+ // widths).
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(b), sqrt(46.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(c), sqrt(61.0), 1e-5);
+
+ // The other bound is entirely inside the bound / the other bound entirely
+ // envelops the bound.
+ c[0] = Range(-1.0, 3.0);
+ c[1] = Range(0.0, 6.0);
+ c[2] = Range(-3.0, 3.0);
+ c[3] = Range(-7.0, 0.0);
+ c[4] = Range(0.0, 5.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), sqrt(100.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(b), sqrt(100.0), 1e-5);
+
+ // Identical bounds. This will be the square root of the sum of the squared
+ // widths in each dimension.
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(b), sqrt(46.0), 1e-5);
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(c), sqrt(162.0), 1e-5);
+
+ // One last additional case. If the bound encloses only one point, the
+ // maximum distance between it and itself is 0.
+ HRectBound<2, true> d(2);
+
+ d[0] = Range(2.0, 2.0);
+ d[1] = Range(3.0, 3.0);
+
+ BOOST_REQUIRE_SMALL(d.MaxDistance(d), 1e-5);
+}
+
+/**
+ * Ensure that the ranges returned by RangeDistance() are equal to the minimum
+ * and maximum distance. We will perform this test by creating random bounds
+ * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
+ * is assuming that those passed and operate correctly.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRootRangeDistanceBound)
+{
+ for (int i = 0; i < 50; i++)
+ {
+ size_t dim = math::RandInt(20);
+
+ HRectBound<2, true> a(dim);
+ HRectBound<2, true> b(dim);
+
+ // We will set the low randomly and the width randomly for each dimension of
+ // each bound. (Named in camelCase for consistency with loA / widthA.)
+ arma::vec loA(dim);
+ arma::vec widthA(dim);
+
+ loA.randu();
+ widthA.randu();
+
+ arma::vec loB(dim);
+ arma::vec widthB(dim);
+
+ loB.randu();
+ widthB.randu();
+
+ for (size_t j = 0; j < dim; j++)
+ {
+ a[j] = Range(loA[j], loA[j] + widthA[j]);
+ b[j] = Range(loB[j], loB[j] + widthB[j]);
+ }
+
+ // Now ensure that MinDistance and MaxDistance report the same, in both
+ // directions (the distances must be symmetric).
+ Range r = a.RangeDistance(b);
+ Range s = b.RangeDistance(a);
+
+ BOOST_REQUIRE_CLOSE(r.Lo(), s.Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(r.Hi(), s.Hi(), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(b), 1e-5);
+ BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(b), 1e-5);
+
+ BOOST_REQUIRE_CLOSE(s.Lo(), b.MinDistance(a), 1e-5);
+ BOOST_REQUIRE_CLOSE(s.Hi(), b.MaxDistance(a), 1e-5);
+ }
+}
+
+/**
+ * Ensure that the ranges returned by RangeDistance() are equal to the minimum
+ * and maximum distance. We will perform this test by creating random bounds
+ * and comparing the behavior to MinDistance() and MaxDistance() -- so this test
+ * is assuming that those passed and operate correctly. This is for the
+ * bound-to-point case.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundRootRangeDistancePoint)
+{
+ for (int i = 0; i < 20; i++)
+ {
+ size_t dim = math::RandInt(20);
+
+ HRectBound<2, true> a(dim);
+
+ // We will set the low randomly and the width randomly for each dimension of
+ // each bound.
+ arma::vec loA(dim);
+ arma::vec widthA(dim);
+
+ loA.randu();
+ widthA.randu();
+
+ for (size_t j = 0; j < dim; j++)
+ a[j] = Range(loA[j], loA[j] + widthA[j]);
+
+ // Now run the test on a few points (random in [0, 1]^dim).
+ for (int j = 0; j < 10; j++)
+ {
+ arma::vec point(dim);
+
+ point.randu();
+
+ Range r = a.RangeDistance(point);
+
+ BOOST_REQUIRE_CLOSE(r.Lo(), a.MinDistance(point), 1e-5);
+ BOOST_REQUIRE_CLOSE(r.Hi(), a.MaxDistance(point), 1e-5);
+ }
+ }
+}
+
+/**
+ * Ensure that HRectBound::Diameter() works properly.
+ */
+BOOST_AUTO_TEST_CASE(HRectBoundDiameter)
+{
+ // A 3-norm bound: widths are 1, 1, 1, and 0, so the diameter should be
+ // (1^3 + 1^3 + 1^3 + 0^3)^(1/3) = 3^(1/3).
+ HRectBound<3> b(4);
+ b[0] = math::Range(0.0, 1.0);
+ b[1] = math::Range(-1.0, 0.0);
+ b[2] = math::Range(2.0, 3.0);
+ b[3] = math::Range(7.0, 7.0);
+
+ BOOST_REQUIRE_CLOSE(b.Diameter(), std::pow(3.0, 1.0 / 3.0), 1e-5);
+
+ // With the root-taking parameter set to false, the reported diameter is the
+ // power of the distance: 1^2 + 1^2 + 1^2 + 0^2 = 3 (no square root).
+ HRectBound<2, false> c(4);
+ c[0] = math::Range(0.0, 1.0);
+ c[1] = math::Range(-1.0, 0.0);
+ c[2] = math::Range(2.0, 3.0);
+ c[3] = math::Range(0.0, 0.0);
+
+ BOOST_REQUIRE_CLOSE(c.Diameter(), 3.0, 1e-5);
+
+ // A bound with zero width in every dimension has zero diameter.
+ HRectBound<5> d(2);
+ d[0] = math::Range(2.2, 2.2);
+ d[1] = math::Range(1.0, 1.0);
+
+ BOOST_REQUIRE_SMALL(d.Diameter(), 1e-5);
+}
+
+/**
+ * Ensure that a bound, by default, is empty and has no dimensionality, and the
+ * box size vector is empty.
+ */
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundEmptyConstructor)
+{
+ // Default-constructed: zero dimensions and an empty periodic box.
+ PeriodicHRectBound<2> b;
+
+ BOOST_REQUIRE_EQUAL(b.Dim(), 0);
+ BOOST_REQUIRE_EQUAL(b.Box().n_elem, 0);
+}
+
+/**
+ * Ensure that when we specify the dimensionality in the constructor, it is
+ * correct, and the bounds are all the empty set.
+ */
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundBoxConstructor)
+{
+ // The constructor argument is the periodic box size; the dimensionality is
+ // taken from the number of elements, and each range starts with zero width.
+ PeriodicHRectBound<2> b(arma::vec("5 6"));
+
+ BOOST_REQUIRE_EQUAL(b.Dim(), 2);
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+ BOOST_REQUIRE_EQUAL(b.Box().n_elem, 2);
+ BOOST_REQUIRE_CLOSE(b.Box()[0], 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b.Box()[1], 6.0, 1e-5);
+
+ // The same should hold for a higher-dimensional box.
+ PeriodicHRectBound<2> d(arma::vec("2 3 4 5 6"));
+
+ BOOST_REQUIRE_EQUAL(d.Dim(), 5);
+ BOOST_REQUIRE_SMALL(d[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(d[1].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(d[2].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(d[3].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(d[4].Width(), 1e-5);
+ BOOST_REQUIRE_EQUAL(d.Box().n_elem, 5);
+ BOOST_REQUIRE_CLOSE(d.Box()[0], 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Box()[1], 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Box()[2], 4.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Box()[3], 5.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(d.Box()[4], 6.0, 1e-5);
+}
+
+/**
+ * Test the copy constructor.
+ */
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundCopyConstructor)
+{
+ PeriodicHRectBound<2> b(arma::vec("3 4"));
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 3.0);
+
+ // The copy must preserve both the per-dimension ranges and the box size.
+ PeriodicHRectBound<2> c(b);
+
+ BOOST_REQUIRE_EQUAL(c.Dim(), 2);
+ BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_EQUAL(c.Box().n_elem, 2);
+ BOOST_REQUIRE_CLOSE(c.Box()[0], 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.Box()[1], 4.0, 1e-5);
+}
+
+/**
+ * Test the assignment operator.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundAssignmentOperator)
+{
+ PeriodicHRectBound<2> b(arma::vec("3 4"));
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 3.0);
+
+ PeriodicHRectBound<2> c(arma::vec("3 4 5 6"));
+
+ c = b;
+
+ BOOST_REQUIRE_EQUAL(c.Dim(), 2);
+ BOOST_REQUIRE_SMALL(c[0].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(c[0].Hi(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Lo(), 2.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c[1].Hi(), 3.0, 1e-5);
+ BOOST_REQUIRE_EQUAL(c.Box().n_elem, 2);
+ BOOST_REQUIRE_CLOSE(c.Box()[0], 3.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(c.Box()[1], 4.0, 1e-5);
+}*/
+
+/**
+ * Ensure that we can set the box size correctly.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundSetBoxSize)
+{
+ PeriodicHRectBound<2> b(arma::vec("1 2"));
+
+ b.SetBoxSize(arma::vec("10 12"));
+
+ BOOST_REQUIRE_EQUAL(b.Box().n_elem, 2);
+ BOOST_REQUIRE_CLOSE(b.Box()[0], 10.0, 1e-5);
+ BOOST_REQUIRE_CLOSE(b.Box()[1], 12.0, 1e-5);
+}*/
+
+/**
+ * Ensure that we can clear the dimensions correctly. This does not involve the
+ * box size at all, so the test can be identical to the HRectBound test.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundClear)
+{
+ // We'll do this with two dimensions only.
+ PeriodicHRectBound<2> b(arma::vec("5 5"));
+
+ b[0] = Range(0.0, 2.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Now we just need to make sure that we clear the range.
+ b.Clear();
+
+ BOOST_REQUIRE_SMALL(b[0].Width(), 1e-5);
+ BOOST_REQUIRE_SMALL(b[1].Width(), 1e-5);
+}*/
+
+/**
+ * Ensure that we get the correct centroid for our bound.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundCentroid) {
+ // Create a simple 3-dimensional bound. The centroid is not affected by the
+ // periodic coordinates.
+ PeriodicHRectBound<2> b(arma::vec("100 100 100"));
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(-2.0, -1.0);
+ b[2] = Range(-10.0, 50.0);
+
+ arma::vec centroid;
+
+ b.Centroid(centroid);
+
+ BOOST_REQUIRE_EQUAL(centroid.n_elem, 3);
+ BOOST_REQUIRE_CLOSE(centroid[0], 2.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(centroid[1], -1.5, 1e-5);
+ BOOST_REQUIRE_CLOSE(centroid[2], 20.0, 1e-5);
+}*/
+
+/**
+ * Correctly calculate the minimum distance between the bound and a point in
+ * periodic coordinates. We have to account for the shifts necessary in
+ * periodic coordinates too, so that makes testing this a little more difficult.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMinDistancePoint)
+{
+ // First, we'll start with a simple 2-dimensional case where the point is
+ // inside the bound, then on the edge of the bound, then barely outside the
+ // bound. The box size will be large enough that this is basically the
+ // HRectBound case.
+ PeriodicHRectBound<2> b(arma::vec("100 100"));
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Inside the bound.
+ arma::vec point = "2.5 3.0";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+
+ // On the edge.
+ point = "5.0 4.0";
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(point), 1e-5);
+
+ // And just a little outside the bound.
+ point = "6.0 5.0";
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(point), 2.0, 1e-5);
+
+ // Now we start to invoke the periodicity. This point will "alias" to (-1,
+ // -1).
+ point = "99.0 99.0";
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(point), 10.0, 1e-5);
+
+ // We will perform several tests on a one-dimensional bound.
+ PeriodicHRectBound<2> a(arma::vec("5.0"));
+ point.set_size(1);
+
+ a[0] = Range(2.0, 4.0); // Entirely inside box.
+ point[0] = 7.5; // Inside first right image of the box.
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
+
+ a[0] = Range(0.0, 5.0); // Fills box fully.
+ point[0] = 19.3; // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
+
+ a[0] = Range(-10.0, 10.0); // Larger than the box.
+ point[0] = -500.0; // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(point), 1e-5);
+
+ a[0] = Range(-2.0, 1.0); // Crosses over an edge.
+ point[0] = 2.9; // The first right image of the bound starts at 3.0.
+
+ BOOST_REQUIRE_CLOSE(a.MinDistance(point), 0.01, 1e-5);
+
+ a[0] = Range(2.0, 4.0); // Inside box.
+ point[0] = 0.0; // Closest to the first left image of the bound.
+
+ BOOST_REQUIRE_CLOSE(a.MinDistance(point), 1.0, 1e-5);
+
+ a[0] = Range(0.0, 2.0); // On edge of box.
+ point[0] = 7.1; // 0.1 away from the first right image of the bound.
+
+ BOOST_REQUIRE_CLOSE(a.MinDistance(point), 0.01, 1e-5);
+
+ PeriodicHRectBound<2> d(arma::vec("0.0"));
+ d[0] = Range(-10.0, 10.0); // Box is of infinite size.
+ point[0] = 810.0; // 800 away from the only image of the box.
+
+ BOOST_REQUIRE_CLOSE(d.MinDistance(point), 640000.0, 1e-5);
+
+ PeriodicHRectBound<2> e(arma::vec("-5.0"));
+ e[0] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
+ point[0] = -10.8; // Should alias to 4.2.
+
+ BOOST_REQUIRE_CLOSE(e.MinDistance(point), 0.04, 1e-5);
+
+ // Switch our bound to a higher dimensionality. This should ensure that the
+ // dimensions are independent like they should be.
+ PeriodicHRectBound<2> c(arma::vec("5.0 5.0 5.0 5.0 5.0 5.0 0.0 -5.0"));
+
+ c[0] = Range(2.0, 4.0); // Entirely inside box.
+ c[1] = Range(0.0, 5.0); // Fills box fully.
+ c[2] = Range(-10.0, 10.0); // Larger than the box.
+ c[3] = Range(-2.0, 1.0); // Crosses over an edge.
+ c[4] = Range(2.0, 4.0); // Inside box.
+ c[5] = Range(0.0, 2.0); // On edge of box.
+ c[6] = Range(-10.0, 10.0); // Box is of infinite size.
+ c[7] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
+
+ point.set_size(8);
+ point[0] = 7.5; // Inside first right image of the box.
+ point[1] = 19.3; // Inside the box, which covers everything.
+ point[2] = -500.0; // Inside the box, which covers everything.
+ point[3] = 2.9; // The first right image of the bound starts at 3.0.
+ point[4] = 0.0; // Closest to the first left image of the bound.
+ point[5] = 7.1; // 0.1 away from the first right image of the bound.
+ point[6] = 810.0; // 800 away from the only image of the box.
+ point[7] = -10.8; // Should alias to 4.2.
+
+ BOOST_REQUIRE_CLOSE(c.MinDistance(point), 640001.06, 1e-10);
+}*/
+
+/**
+ * Correctly calculate the minimum distance between the bound and another bound in
+ * periodic coordinates. We have to account for the shifts necessary in
+ * periodic coordinates too, so that makes testing this a little more difficult.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMinDistanceBound)
+{
+ // First, we'll start with a simple 2-dimensional case where the bounds are nonoverlapping,
+ // then one bound is on the edge of the other bound,
+ // then overlapping, then one range entirely covering the other. The box size will be large enough that this is basically the
+ // HRectBound case.
+ PeriodicHRectBound<2> b(arma::vec("100 100"));
+ PeriodicHRectBound<2> c(arma::vec("100 100"));
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Inside the bound.
+ c[0] = Range(7.0, 9.0);
+ c[1] = Range(10.0,12.0);
+
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(c), 40.0, 1e-5);
+
+ // On the edge.
+ c[0] = Range(5.0, 8.0);
+ c[1] = Range(4.0, 6.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+
+ // Overlapping the bound.
+ c[0] = Range(3.0, 6.0);
+ c[1] = Range(1.0, 3.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+
+ // One range entirely covering the other
+
+ c[0] = Range(0.0, 6.0);
+ c[1] = Range(1.0, 7.0);
+
+ BOOST_REQUIRE_SMALL(b.MinDistance(c), 1e-5);
+
+ // Now we start to invoke the periodicity. These bounds "alias" to (-3.0,
+ // -1.0) and (5.0, 6.0).
+
+ c[0] = Range(97.0, 99.0);
+ c[1] = Range(105.0, 106.0);
+
+ BOOST_REQUIRE_CLOSE(b.MinDistance(c), 2.0, 1e-5);
+
+ // We will perform several tests on a one-dimensional bound and smaller box size and mostly overlapping.
+ PeriodicHRectBound<2> a(arma::vec("5.0"));
+ PeriodicHRectBound<2> d(a);
+
+ a[0] = Range(2.0, 4.0); // Entirely inside box.
+ d[0] = Range(7.5, 10.0); // In the right image of the box, overlapping ranges.
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
+
+ a[0] = Range(0.0, 5.0); // Fills box fully.
+ d[0] = Range(19.3, 21.0); // Two intervals inside the box, same as range of b[0].
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
+
+ a[0] = Range(-10.0, 10.0); // Larger than the box.
+ d[0] = Range(-500.0, -498.0); // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
+
+ a[0] = Range(-2.0, 1.0); // Crosses over an edge.
+ d[0] = Range(2.9, 5.1); // The first right image of the bound starts at 3.0. Overlapping
+
+ BOOST_REQUIRE_SMALL(a.MinDistance(d), 1e-5);
+
+ a[0] = Range(-1.0, 1.0); // Crosses over an edge.
+ d[0] = Range(11.9, 12.5); // The first right image of the bound starts at 4.0.
+ BOOST_REQUIRE_CLOSE(a.MinDistance(d), 0.81, 1e-5);
+
+ a[0] = Range(2.0, 3.0);
+ d[0] = Range(9.5, 11);
+ BOOST_REQUIRE_CLOSE(a.MinDistance(d), 1.0, 1e-5);
+
+}*/
+
+/**
+ * Correctly calculate the maximum distance between the bound and a point in
+ * periodic coordinates. We have to account for the shifts necessary in
+ * periodic coordinates too, so that makes testing this a little more difficult.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMaxDistancePoint)
+{
+ // First, we'll start with a simple 2-dimensional case where the point is
+ // inside the bound, then on the edge of the bound, then barely outside the
+ // bound. The box size will be large enough that this is basically the
+ // HRectBound case.
+ PeriodicHRectBound<2> b(arma::vec("100 100"));
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Inside the bound.
+ arma::vec point = "2.5 3.0";
+
+ //BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 7.25, 1e-5);
+
+ // On the edge.
+ point = "5.0 4.0";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 29.0, 1e-5);
+
+ // And just a little outside the bound.
+ point = "6.0 5.0";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 45.0, 1e-5);
+
+ // Now we start to invoke the periodicity. This point will "alias" to (-1,
+ // -1).
+ point = "99.0 99.0";
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(point), 61.0, 1e-5);
+
+ // We will perform several tests on a one-dimensional bound and smaller box size.
+ PeriodicHRectBound<2> a(arma::vec("5.0"));
+ point.set_size(1);
+
+ a[0] = Range(2.0, 4.0); // Entirely inside box.
+ point[0] = 7.5; // Inside first right image of the box.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 2.25, 1e-5);
+
+ a[0] = Range(0.0, 5.0); // Fills box fully.
+ point[0] = 19.3; // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 18.49, 1e-5);
+
+ a[0] = Range(-10.0, 10.0); // Larger than the box.
+ point[0] = -500.0; // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 25.0, 1e-5);
+
+ a[0] = Range(-2.0, 1.0); // Crosses over an edge.
+ point[0] = 2.9; // The first right image of the bound starts at 3.0.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 8.41, 1e-5);
+
+ a[0] = Range(2.0, 4.0); // Inside box.
+ point[0] = 0.0; // Farthest from the first right image of the bound.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 25.0, 1e-5);
+
+ a[0] = Range(0.0, 2.0); // On edge of box.
+ point[0] = 7.1; // 2.1 away from the first left image of the bound.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(point), 4.41, 1e-5);
+
+ PeriodicHRectBound<2> d(arma::vec("0.0"));
+ d[0] = Range(-10.0, 10.0); // Box is of infinite size.
+ point[0] = 810.0; // 820 away from the only left image of the box.
+
+ BOOST_REQUIRE_CLOSE(d.MinDistance(point), 672400.0, 1e-5);
+
+ PeriodicHRectBound<2> e(arma::vec("-5.0"));
+ e[0] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
+ point[0] = -10.8; // Should alias to 4.2.
+
+ BOOST_REQUIRE_CLOSE(e.MaxDistance(point), 4.84, 1e-5);
+
+ // Switch our bound to a higher dimensionality. This should ensure that the
+ // dimensions are independent like they should be.
+ PeriodicHRectBound<2> c(arma::vec("5.0 5.0 5.0 5.0 5.0 5.0 0.0 -5.0"));
+
+ c[0] = Range(2.0, 4.0); // Entirely inside box.
+ c[1] = Range(0.0, 5.0); // Fills box fully.
+ c[2] = Range(-10.0, 10.0); // Larger than the box.
+ c[3] = Range(-2.0, 1.0); // Crosses over an edge.
+ c[4] = Range(2.0, 4.0); // Inside box.
+ c[5] = Range(0.0, 2.0); // On edge of box.
+ c[6] = Range(-10.0, 10.0); // Box is of infinite size.
+ c[7] = Range(2.0, 4.0); // Box size of -5 should function the same as 5.
+
+ point.set_size(8);
+ point[0] = 7.5; // Inside first right image of the box.
+ point[1] = 19.3; // Inside the box, which covers everything.
+ point[2] = -500.0; // Inside the box, which covers everything.
+ point[3] = 2.9; // The first right image of the bound starts at 3.0.
+ point[4] = 0.0; // Closest to the first left image of the bound.
+ point[5] = 7.1; // 0.1 away from the first right image of the bound.
+ point[6] = 810.0; // 800 away from the only image of the box.
+ point[7] = -10.8; // Should alias to 4.2.
+
+ BOOST_REQUIRE_CLOSE(c.MaxDistance(point), 672630.65, 1e-10);
+}*/
+
+/**
+ * Correctly calculate the maximum distance between the bound and another bound in
+ * periodic coordinates. We have to account for the shifts necessary in
+ * periodic coordinates too, so that makes testing this a little more difficult.
+ *
+BOOST_AUTO_TEST_CASE(PeriodicHRectBoundMaxDistanceBound)
+{
+ // First, we'll start with a simple 2-dimensional case where the bounds are nonoverlapping,
+ // then one bound is on the edge of the other bound,
+ // then overlapping, then one range entirely covering the other. The box size will be large enough that this is basically the
+ // HRectBound case.
+ PeriodicHRectBound<2> b(arma::vec("100 100"));
+ PeriodicHRectBound<2> c(b);
+
+ b[0] = Range(0.0, 5.0);
+ b[1] = Range(2.0, 4.0);
+
+ // Inside the bound.
+
+ c[0] = Range(7.0, 9.0);
+ c[1] = Range(10.0,12.0);
+
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 181.0, 1e-5);
+
+ // On the edge.
+
+ c[0] = Range(5.0, 8.0);
+ c[1] = Range(4.0, 6.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 80.0, 1e-5);
+
+ // Overlapping the bound.
+
+ c[0] = Range(3.0, 6.0);
+ c[1] = Range(1.0, 3.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 45.0, 1e-5);
+
+ // One range entirely covering the other
+
+ c[0] = Range(0.0, 6.0);
+ c[1] = Range(1.0, 7.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 61.0, 1e-5);
+
+ // Now we start to invoke the periodicity. These bounds "alias" to (-3.0,
+ // -1.0) and (5.0, 6.0).
+
+ c[0] = Range(97.0, 99.0);
+ c[1] = Range(105.0, 106.0);
+
+ BOOST_REQUIRE_CLOSE(b.MaxDistance(c), 80.0, 1e-5);
+
+ // We will perform several tests on a one-dimensional bound and smaller box size.
+ PeriodicHRectBound<2> a(arma::vec("5.0"));
+ PeriodicHRectBound<2> d(a);
+
+ a[0] = Range(2.0, 4.0); // Entirely inside box.
+ d[0] = Range(7.5, 10); // In the right image of the box, overlapping ranges.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 9.0, 1e-5);
+
+ a[0] = Range(0.0, 5.0); // Fills box fully.
+ d[0] = Range(19.3, 21.0); // Two intervals inside the box, same as range of b[0].
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 18.49, 1e-5);
+
+ a[0] = Range(-10.0, 10.0); // Larger than the box.
+ d[0] = Range(-500.0, -498.0); // Inside the box, which covers everything.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 9.0, 1e-5);
+
+ a[0] = Range(-2.0, 1.0); // Crosses over an edge.
+ d[0] = Range(2.9, 5.1); // The first right image of the bound starts at 3.0.
+
+ BOOST_REQUIRE_CLOSE(a.MaxDistance(d), 24.01, 1e-5);
+}*/
+
+
+/**
+ * It seems as though Bill has stumbled across a bug where
+ * BinarySpaceTree<>::count() returns something different than
+ * BinarySpaceTree<>::count_. So, let's build a simple tree and make sure they
+ * are the same.
+ */
+BOOST_AUTO_TEST_CASE(TreeCountMismatch)
+{
+ // Six 2-dimensional points (one point per column).
+ arma::mat dataset = "2.0 5.0 9.0 4.0 8.0 7.0;"
+ "3.0 4.0 6.0 7.0 1.0 2.0 ";
+
+ // Leaf size of 1, so the tree is split all the way down to single points.
+ BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
+
+ // Each node's count must equal the sum of its children's counts, down to
+ // leaves of one point each.
+ BOOST_REQUIRE(rootNode.Count() == 6);
+ BOOST_REQUIRE(rootNode.Left()->Count() == 3);
+ BOOST_REQUIRE(rootNode.Left()->Left()->Count() == 2);
+ BOOST_REQUIRE(rootNode.Left()->Left()->Left()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Left()->Left()->Right()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Left()->Right()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Right()->Count() == 3);
+ BOOST_REQUIRE(rootNode.Right()->Left()->Count() == 2);
+ BOOST_REQUIRE(rootNode.Right()->Left()->Left()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Right()->Left()->Right()->Count() == 1);
+ BOOST_REQUIRE(rootNode.Right()->Right()->Count() == 1);
+}
+
+// Ensure that every node's Parent() pointer refers to the node that created
+// it, and that the root's parent is NULL.
+BOOST_AUTO_TEST_CASE(CheckParents)
+{
+ arma::mat dataset = "2.0 5.0 9.0 4.0 8.0 7.0;"
+ "3.0 4.0 6.0 7.0 1.0 2.0 ";
+
+ // Leaf size of 1.
+ BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
+
+ // The root has no parent.
+ BOOST_REQUIRE_EQUAL(rootNode.Parent(),
+ (BinarySpaceTree<HRectBound<2> >*) NULL);
+ BOOST_REQUIRE_EQUAL(&rootNode, rootNode.Left()->Parent());
+ BOOST_REQUIRE_EQUAL(&rootNode, rootNode.Right()->Parent());
+ BOOST_REQUIRE_EQUAL(rootNode.Left(), rootNode.Left()->Left()->Parent());
+ BOOST_REQUIRE_EQUAL(rootNode.Left(), rootNode.Left()->Right()->Parent());
+ BOOST_REQUIRE_EQUAL(rootNode.Left()->Left(),
+ rootNode.Left()->Left()->Left()->Parent());
+ BOOST_REQUIRE_EQUAL(rootNode.Left()->Left(),
+ rootNode.Left()->Left()->Right()->Parent());
+ BOOST_REQUIRE_EQUAL(rootNode.Right(), rootNode.Right()->Left()->Parent());
+ BOOST_REQUIRE_EQUAL(rootNode.Right(), rootNode.Right()->Right()->Parent());
+ BOOST_REQUIRE_EQUAL(rootNode.Right()->Left(),
+ rootNode.Right()->Left()->Left()->Parent());
+ BOOST_REQUIRE_EQUAL(rootNode.Right()->Left(),
+ rootNode.Right()->Left()->Right()->Parent());
+}
+
+// Ensure that every node in the tree refers to the same dataset matrix that
+// the tree was built on (no copies are made).
+BOOST_AUTO_TEST_CASE(CheckDataset)
+{
+ arma::mat dataset = "2.0 5.0 9.0 4.0 8.0 7.0;"
+ "3.0 4.0 6.0 7.0 1.0 2.0 ";
+
+ // Leaf size of 1.
+ BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
+
+ // Compare addresses: each node's Dataset() must be the original matrix.
+ BOOST_REQUIRE_EQUAL(&rootNode.Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Right()->Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Right()->Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Left()->Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Left()->Left()->Right()->Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Left()->Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(&rootNode.Right()->Left()->Right()->Dataset(), &dataset);
+}
+
+// Ensure FurthestDescendantDistance() works.
+BOOST_AUTO_TEST_CASE(FurthestDescendantDistanceTest)
+{
+ arma::mat dataset = "1; 3"; // One point.
+
+ // A single-point node has no descendant farther than itself.
+ BinarySpaceTree<HRectBound<2> > rootNode(dataset, 1);
+
+ BOOST_REQUIRE_SMALL(rootNode.FurthestDescendantDistance(), 1e-5);
+
+ dataset = "1 -1; 1 -1"; // Square of size [2, 2].
+
+ // Both points are contained in the one node.
+ // The points are (1, 1) and (-1, -1); the expected value sqrt(2) is the
+ // distance from the node's center (0, 0) to either point -- NOTE(review):
+ // this assumes the distance is measured from the centroid; confirm against
+ // the BinarySpaceTree implementation.
+ BinarySpaceTree<HRectBound<2> > twoPoint(dataset);
+ BOOST_REQUIRE_CLOSE(twoPoint.FurthestDescendantDistance(), sqrt(2.0), 1e-5);
+}
+
+// Forward declaration of methods we need for the next test.
+// Used by KdTreeTest to verify that each point is contained inside all of the
+// bounds of its parent nodes (definition appears later in this file).
+template<typename TreeType, typename MatType>
+bool CheckPointBounds(TreeType* node, const MatType& data);
+
+// Used by KdTreeTest to collect the tree's nodes into v for the peer-overlap
+// check (definition appears later in this file).
+template<typename TreeType>
+void GenerateVectorOfTree(TreeType* node,
+ size_t depth,
+ std::vector<TreeType*>& v);
+
+// Used by KdTreeTest; presumably reports whether the two bounds overlap
+// (definition appears later in this file -- confirm there).
+template<int t_pow>
+bool DoBoundsIntersect(HRectBound<t_pow>& a,
+ HRectBound<t_pow>& b,
+ size_t ia,
+ size_t ib);
+
+/**
+ * Exhaustive kd-tree test based on #125.
+ *
+ * - Generate a random dataset of a random size.
+ * - Build a tree on that dataset.
+ * - Ensure all the permutation indices map back to the correct points.
+ * - Verify that each point is contained inside all of the bounds of its parent
+ * nodes.
+ * - Verify that each bound at a particular level of the tree does not overlap
+ * with any other bounds at that level.
+ *
+ * Then, we do that whole process a handful of times.
+ */
+BOOST_AUTO_TEST_CASE(KdTreeTest)
+{
+ typedef BinarySpaceTree<HRectBound<2> > TreeType;
+
+ size_t maxRuns = 10; // Ten total tests.
+ size_t pointIncrements = 1000; // Range is from 1000 points to 10000.
+
+ // We use the default leaf size of 20.
+ for(size_t run = 0; run < maxRuns; run++)
+ {
+ size_t dimensions = run + 2;
+ size_t maxPoints = (run + 1) * pointIncrements;
+
+ size_t size = maxPoints;
+ arma::mat dataset = arma::mat(dimensions, size);
+ arma::mat datacopy; // Used to test mappings.
+
+ // Mappings for post-sort verification of data.
+ std::vector<size_t> newToOld;
+ std::vector<size_t> oldToNew;
+
+ // Generate data.
+ dataset.randu();
+ datacopy = dataset; // Save a copy.
+
+ // Build the tree itself.
+ TreeType root(dataset, newToOld, oldToNew);
+
+ // Ensure the size of the tree is correct.
+ BOOST_REQUIRE_EQUAL(root.Count(), size);
+
+ // Check the forward and backward mappings for correctness.
+ for(size_t i = 0; i < size; i++)
+ {
+ for(size_t j = 0; j < dimensions; j++)
+ {
+ BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
+ BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+ }
+ }
+
+ // Now check that each point is contained inside of all bounds above it.
+ CheckPointBounds(&root, dataset);
+
+ // Now check that no peers overlap.
+ std::vector<TreeType*> v;
+ GenerateVectorOfTree(&root, 1, v);
+
+ // Start with the first pair.
+ size_t depth = 2;
+ // Compare each peer against every other peer.
+ while (depth < v.size())
+ {
+ for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
+ for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
+ if (v[i] != NULL && v[j] != NULL)
+ BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
+ i, j));
+
+ depth *= 2;
+ }
+ }
+
+ arma::mat dataset = arma::mat(25, 1000);
+ for (size_t col = 0; col < dataset.n_cols; ++col)
+ for (size_t row = 0; row < dataset.n_rows; ++row)
+ dataset(row, col) = row + col;
+
+ TreeType root(dataset);
+ // Check the tree size.
+ BOOST_REQUIRE_EQUAL(root.TreeSize(), 127);
+ // Check the tree depth.
+ BOOST_REQUIRE_EQUAL(root.TreeDepth(), 7);
+}
+
+// Recursively checks that each node contains all points that it claims to have.
+template<typename TreeType, typename MatType>
+bool CheckPointBounds(TreeType* node, const MatType& data)
+{
+ if (node == NULL) // We have passed a leaf node.
+ return true;
+
+ TreeType* left = node->Left();
+ TreeType* right = node->Right();
+
+ size_t begin = node->Begin();
+ size_t count = node->Count();
+
+ // Check that each point which this tree claims is actually inside the tree.
+ for (size_t index = begin; index < begin + count; index++)
+ if (!node->Bound().Contains(data.col(index)))
+ return false;
+
+ return CheckPointBounds(left, data) && CheckPointBounds(right, data);
+}
+
+template<int t_pow>
+bool DoBoundsIntersect(HRectBound<t_pow>& a,
+ HRectBound<t_pow>& b,
+ size_t /* ia */,
+ size_t /* ib */)
+{
+ size_t dimensionality = a.Dim();
+
+ Range r_a;
+ Range r_b;
+
+ for (size_t i = 0; i < dimensionality; i++)
+ {
+ r_a = a[i];
+ r_b = b[i];
+ if (r_a < r_b || r_a > r_b) // If a does not overlap b at all.
+ return false;
+ }
+
+ return true;
+}
+
+template<typename TreeType>
+void GenerateVectorOfTree(TreeType* node,
+ size_t depth,
+ std::vector<TreeType*>& v)
+{
+ if (node == NULL)
+ return;
+
+ if (depth >= v.size())
+ v.resize(2 * depth + 1, NULL); // Resize to right size; fill with NULL.
+
+ v[depth] = node;
+
+ // Recurse to the left and right children.
+ GenerateVectorOfTree(node->Left(), depth * 2, v);
+ GenerateVectorOfTree(node->Right(), depth * 2 + 1, v);
+
+ return;
+}
+
+#ifdef ARMA_HAS_SPMAT
+// Only run sparse tree tests if we are using Armadillo 3.6. Armadillo 3.4 has
+// some bugs that cause the kd-tree splitting procedure to fail terribly. Soon,
+// that version will be obsolete, though.
+#if !((ARMA_VERSION_MAJOR == 3) && (ARMA_VERSION_MINOR == 4))
+
+/**
+ * Exhaustive sparse kd-tree test based on #125.
+ *
+ * - Generate a random dataset of a random size.
+ * - Build a tree on that dataset.
+ * - Ensure all the permutation indices map back to the correct points.
+ * - Verify that each point is contained inside all of the bounds of its parent
+ * nodes.
+ * - Verify that each bound at a particular level of the tree does not overlap
+ * with any other bounds at that level.
+ *
+ * Then, we do that whole process a handful of times.
+ */
+BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
+{
+ typedef BinarySpaceTree<HRectBound<2>, EmptyStatistic, arma::SpMat<double> >
+ TreeType;
+
+ size_t maxRuns = 2; // Two total tests.
+ size_t pointIncrements = 200; // Range is from 200 points to 400.
+
+ // We use the default leaf size of 20.
+ for(size_t run = 0; run < maxRuns; run++)
+ {
+ size_t dimensions = run + 2;
+ size_t maxPoints = (run + 1) * pointIncrements;
+
+ size_t size = maxPoints;
+ arma::SpMat<double> dataset = arma::SpMat<double>(dimensions, size);
+ arma::SpMat<double> datacopy; // Used to test mappings.
+
+ // Mappings for post-sort verification of data.
+ std::vector<size_t> newToOld;
+ std::vector<size_t> oldToNew;
+
+ // Generate data.
+ dataset.sprandu(dimensions, size, 0.1);
+ datacopy = dataset; // Save a copy.
+
+ // Build the tree itself.
+ TreeType root(dataset, newToOld, oldToNew);
+
+ // Ensure the size of the tree is correct.
+ BOOST_REQUIRE_EQUAL(root.Count(), size);
+
+ // Check the forward and backward mappings for correctness.
+ for(size_t i = 0; i < size; i++)
+ {
+ for(size_t j = 0; j < dimensions; j++)
+ {
+ BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
+ BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+ }
+ }
+
+ // Now check that each point is contained inside of all bounds above it.
+ CheckPointBounds(&root, dataset);
+
+ // Now check that no peers overlap.
+ std::vector<TreeType*> v;
+ GenerateVectorOfTree(&root, 1, v);
+
+ // Start with the first pair.
+ size_t depth = 2;
+ // Compare each peer against every other peer.
+ while (depth < v.size())
+ {
+ for (size_t i = depth; i < 2 * depth && i < v.size(); i++)
+ for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++)
+ if (v[i] != NULL && v[j] != NULL)
+ BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound(),
+ i, j));
+
+ depth *= 2;
+ }
+ }
+
+ arma::SpMat<double> dataset(25, 1000);
+ for (size_t col = 0; col < dataset.n_cols; ++col)
+ for (size_t row = 0; row < dataset.n_rows; ++row)
+ dataset(row, col) = row + col;
+
+ TreeType root(dataset);
+ // Check the tree size.
+ BOOST_REQUIRE_EQUAL(root.TreeSize(), 127);
+ // Check the tree depth.
+ BOOST_REQUIRE_EQUAL(root.TreeDepth(), 7);
+}
+
+#endif // Using Armadillo 3.4.
+#endif // ARMA_HAS_SPMAT
+
+template<typename TreeType>
+void RecurseTreeCountLeaves(const TreeType& node, arma::vec& counts)
+{
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ if (node.Child(i).NumChildren() == 0)
+ counts[node.Child(i).Point()]++;
+ else
+ RecurseTreeCountLeaves<TreeType>(node.Child(i), counts);
+ }
+}
+
+template<typename TreeType>
+void CheckSelfChild(const TreeType& node)
+{
+ if (node.NumChildren() == 0)
+ return; // No self-child applicable here.
+
+ bool found = false;
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ if (node.Child(i).Point() == node.Point())
+ found = true;
+
+ // Recursively check the children.
+ CheckSelfChild(node.Child(i));
+ }
+
+ // Ensure this has its own self-child.
+ BOOST_REQUIRE_EQUAL(found, true);
+}
+
+template<typename TreeType, typename MetricType>
+void CheckCovering(const TreeType& node)
+{
+ // Return if a leaf. No checking necessary.
+ if (node.NumChildren() == 0)
+ return;
+
+ const arma::mat& dataset = node.Dataset();
+ const size_t nodePoint = node.Point();
+
+ // To ensure that this node satisfies the covering principle, we must ensure
+ // that the distance to each child is less than pow(base, scale).
+ double maxDistance = pow(node.Base(), node.Scale());
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ const size_t childPoint = node.Child(i).Point();
+
+ double distance = MetricType::Evaluate(dataset.col(nodePoint),
+ dataset.col(childPoint));
+
+ BOOST_REQUIRE_LE(distance, maxDistance);
+
+ // Check the child.
+ CheckCovering<TreeType, MetricType>(node.Child(i));
+ }
+}
+
+template<typename TreeType, typename MetricType>
+void CheckIndividualSeparation(const TreeType& constantNode,
+ const TreeType& node)
+{
+ // Don't check points at a lower scale.
+ if (node.Scale() < constantNode.Scale())
+ return;
+
+ // If at a higher scale, recurse.
+ if (node.Scale() > constantNode.Scale())
+ {
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ // Don't recurse into leaves.
+ if (node.Child(i).NumChildren() > 0)
+ CheckIndividualSeparation<TreeType, MetricType>(constantNode,
+ node.Child(i));
+ }
+
+ return;
+ }
+
+ // Don't compare the same point against itself.
+ if (node.Point() == constantNode.Point())
+ return;
+
+ // Now we know we are at the same scale, so make the comparison.
+ const arma::mat& dataset = constantNode.Dataset();
+ const size_t constantPoint = constantNode.Point();
+ const size_t nodePoint = node.Point();
+
+ // Make sure the distance is at least the following value (in accordance with
+ // the separation principle of cover trees).
+ double minDistance = pow(constantNode.Base(),
+ constantNode.Scale());
+
+ double distance = MetricType::Evaluate(dataset.col(constantPoint),
+ dataset.col(nodePoint));
+
+ BOOST_REQUIRE_GE(distance, minDistance);
+}
+
+template<typename TreeType, typename MetricType>
+void CheckSeparation(const TreeType& node, const TreeType& root)
+{
+ // Check the separation between this point and all other points on this scale.
+ CheckIndividualSeparation<TreeType, MetricType>(node, root);
+
+ // Check the children, but only if they are not leaves. Leaves don't need to
+ // be checked.
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ if (node.Child(i).NumChildren() > 0)
+ CheckSeparation<TreeType, MetricType>(node.Child(i), root);
+}
+
+
+/**
+ * Create a simple cover tree and then make sure it is valid.
+ */
+BOOST_AUTO_TEST_CASE(SimpleCoverTreeConstructionTest)
+{
+ // 20-point dataset.
+ arma::mat data = arma::trans(arma::mat("0.0 0.0;"
+ "1.0 0.0;"
+ "0.5 0.5;"
+ "2.0 2.0;"
+ "-1.0 2.0;"
+ "3.0 0.0;"
+ "1.5 5.5;"
+ "-2.0 -2.0;"
+ "-1.5 1.5;"
+ "0.0 4.0;"
+ "2.0 1.0;"
+ "2.0 1.2;"
+ "-3.0 -2.5;"
+ "-5.0 -5.0;"
+ "3.5 1.5;"
+ "2.0 2.5;"
+ "-1.0 -1.0;"
+ "-3.5 1.5;"
+ "3.5 -1.5;"
+ "2.0 1.0;"));
+
+ // The root point will be the first point, (0, 0).
+ CoverTree<> tree(data); // Expansion constant of 2.0.
+
+ // The furthest point from the root will be (-5, -5), with a distance
+ // of sqrt(50). This means the scale of the root node should be 3 (because
+ // 2^3 = 8).
+ BOOST_REQUIRE_EQUAL(tree.Scale(), 3);
+
+ // Now loop through the tree and ensure that each leaf is only created once.
+ arma::vec counts;
+ counts.zeros(20);
+ RecurseTreeCountLeaves(tree, counts);
+
+ // Each point should only have one leaf node representing it.
+ for (size_t i = 0; i < 20; ++i)
+ BOOST_REQUIRE_EQUAL(counts[i], 1);
+
+ // Each non-leaf should have a self-child.
+ CheckSelfChild<CoverTree<> >(tree);
+
+ // Each node must satisfy the covering principle (its children must be less
+ // than or equal to a certain distance apart).
+ CheckCovering<CoverTree<>, LMetric<2, true> >(tree);
+
+ // Each node's children must be separated by at least a certain value.
+ CheckSeparation<CoverTree<>, LMetric<2, true> >(tree, tree);
+}
+
+/**
+ * Create a large cover tree and make sure it's accurate.
+ */
+BOOST_AUTO_TEST_CASE(CoverTreeConstructionTest)
+{
+ arma::mat dataset;
+ // 50-dimensional, 1000 points.
+ dataset.randu(50, 1000);
+
+ CoverTree<> tree(dataset);
+
+ // Ensure each leaf is only created once.
+ arma::vec counts;
+ counts.zeros(1000);
+ RecurseTreeCountLeaves(tree, counts);
+
+ for (size_t i = 0; i < 1000; ++i)
+ BOOST_REQUIRE_EQUAL(counts[i], 1);
+
+ // Each non-leaf should have a self-child.
+ CheckSelfChild<CoverTree<> >(tree);
+
+ // Each node must satisfy the covering principle (its children must be less
+ // than or equal to a certain distance apart).
+ CheckCovering<CoverTree<>, LMetric<2, true> >(tree);
+
+ // Each node's children must be separated by at least a certain value.
+ CheckSeparation<CoverTree<>, LMetric<2, true> >(tree, tree);
+}
+
+/**
+ * Test the manual constructor.
+ */
+BOOST_AUTO_TEST_CASE(CoverTreeManualConstructorTest)
+{
+ arma::mat dataset;
+ dataset.zeros(10, 10);
+
+ CoverTree<> node(dataset, 1.3, 3, 2, NULL, 1.5, 2.75);
+
+ BOOST_REQUIRE_EQUAL(&node.Dataset(), &dataset);
+ BOOST_REQUIRE_EQUAL(node.Base(), 1.3);
+ BOOST_REQUIRE_EQUAL(node.Point(), 3);
+ BOOST_REQUIRE_EQUAL(node.Scale(), 2);
+ BOOST_REQUIRE_EQUAL(node.Parent(), (CoverTree<>*) NULL);
+ BOOST_REQUIRE_EQUAL(node.ParentDistance(), 1.5);
+ BOOST_REQUIRE_EQUAL(node.FurthestDescendantDistance(), 2.75);
+}
+
+/**
+ * Make sure cover trees work in different metric spaces.
+ */
+BOOST_AUTO_TEST_CASE(CoverTreeAlternateMetricTest)
+{
+ arma::mat dataset;
+ // 5-dimensional, 300-point dataset.
+ dataset.randu(5, 300);
+
+ CoverTree<LMetric<1, true> > tree(dataset);
+
+ // Ensure each leaf is only created once.
+ arma::vec counts;
+ counts.zeros(300);
+ RecurseTreeCountLeaves<CoverTree<LMetric<1, true> > >(tree, counts);
+
+ for (size_t i = 0; i < 300; ++i)
+ BOOST_REQUIRE_EQUAL(counts[i], 1);
+
+ // Each non-leaf should have a self-child.
+ CheckSelfChild<CoverTree<LMetric<1, true> > >(tree);
+
+ // Each node must satisfy the covering principle (its children must be less
+ // than or equal to a certain distance apart).
+ CheckCovering<CoverTree<LMetric<1, true> >, LMetric<1, true> >(tree);
+
+ // Each node's children must be separated by at least a certain value.
+ CheckSeparation<CoverTree<LMetric<1, true> >, LMetric<1, true> >(tree, tree);
+}
+
+/**
+ * Make sure copy constructor works for the cover tree.
+ */
+BOOST_AUTO_TEST_CASE(CoverTreeCopyConstructor)
+{
+ arma::mat dataset;
+ dataset.randu(10, 10); // dataset is irrelevant.
+ CoverTree<> c(dataset, 1.3, 0, 5, NULL, 1.45, 5.2); // Random parameters.
+ c.Children().push_back(new CoverTree<>(dataset, 1.3, 1, 4, &c, 1.3, 2.45));
+ c.Children().push_back(new CoverTree<>(dataset, 1.5, 2, 3, &c, 1.2, 5.67));
+
+ CoverTree<> d = c;
+
+ // Check that everything is the same.
+ BOOST_REQUIRE_EQUAL(c.Dataset().memptr(), d.Dataset().memptr());
+ BOOST_REQUIRE_CLOSE(c.Base(), d.Base(), 1e-50);
+ BOOST_REQUIRE_EQUAL(c.Point(), d.Point());
+ BOOST_REQUIRE_EQUAL(c.Scale(), d.Scale());
+ BOOST_REQUIRE_EQUAL(c.Parent(), d.Parent());
+ BOOST_REQUIRE_EQUAL(c.ParentDistance(), d.ParentDistance());
+ BOOST_REQUIRE_EQUAL(c.FurthestDescendantDistance(),
+ d.FurthestDescendantDistance());
+ BOOST_REQUIRE_EQUAL(c.NumChildren(), d.NumChildren());
+ BOOST_REQUIRE_NE(&c.Child(0), &d.Child(0));
+ BOOST_REQUIRE_NE(&c.Child(1), &d.Child(1));
+
+ BOOST_REQUIRE_EQUAL(c.Child(0).Parent(), &c);
+ BOOST_REQUIRE_EQUAL(c.Child(1).Parent(), &c);
+ BOOST_REQUIRE_EQUAL(d.Child(0).Parent(), &d);
+ BOOST_REQUIRE_EQUAL(d.Child(1).Parent(), &d);
+
+ // Check that the children are okay.
+ BOOST_REQUIRE_EQUAL(c.Child(0).Dataset().memptr(),
+ d.Child(0).Dataset().memptr());
+ BOOST_REQUIRE_CLOSE(c.Child(0).Base(), d.Child(0).Base(), 1e-50);
+ BOOST_REQUIRE_EQUAL(c.Child(0).Point(), d.Child(0).Point());
+ BOOST_REQUIRE_EQUAL(c.Child(0).Scale(), d.Child(0).Scale());
+ BOOST_REQUIRE_EQUAL(c.Child(0).ParentDistance(), d.Child(0).ParentDistance());
+ BOOST_REQUIRE_EQUAL(c.Child(0).FurthestDescendantDistance(),
+ d.Child(0).FurthestDescendantDistance());
+ BOOST_REQUIRE_EQUAL(c.Child(0).NumChildren(), d.Child(0).NumChildren());
+
+ BOOST_REQUIRE_EQUAL(c.Child(1).Dataset().memptr(),
+ d.Child(1).Dataset().memptr());
+ BOOST_REQUIRE_CLOSE(c.Child(1).Base(), d.Child(1).Base(), 1e-50);
+ BOOST_REQUIRE_EQUAL(c.Child(1).Point(), d.Child(1).Point());
+ BOOST_REQUIRE_EQUAL(c.Child(1).Scale(), d.Child(1).Scale());
+ BOOST_REQUIRE_EQUAL(c.Child(1).ParentDistance(), d.Child(1).ParentDistance());
+ BOOST_REQUIRE_EQUAL(c.Child(1).FurthestDescendantDistance(),
+ d.Child(1).FurthestDescendantDistance());
+ BOOST_REQUIRE_EQUAL(c.Child(1).NumChildren(), d.Child(1).NumChildren());
+}
+
+/**
+ * Make sure copy constructor works right for the binary space tree.
+ */
+BOOST_AUTO_TEST_CASE(BinarySpaceTreeCopyConstructor)
+{
+ arma::mat data("1");
+ BinarySpaceTree<HRectBound<2> > b(data);
+ b.Begin() = 10;
+ b.Count() = 50;
+ b.Left() = new BinarySpaceTree<HRectBound<2> >(data);
+ b.Left()->Begin() = 10;
+ b.Left()->Count() = 30;
+ b.Right() = new BinarySpaceTree<HRectBound<2> >(data);
+ b.Right()->Begin() = 40;
+ b.Right()->Count() = 20;
+
+ // Copy the tree.
+ BinarySpaceTree<HRectBound<2> > c(b);
+
+ // Ensure everything copied correctly.
+ BOOST_REQUIRE_EQUAL(b.Begin(), c.Begin());
+ BOOST_REQUIRE_EQUAL(b.Count(), c.Count());
+ BOOST_REQUIRE_NE(b.Left(), c.Left());
+ BOOST_REQUIRE_NE(b.Right(), c.Right());
+
+ // Check the children.
+ BOOST_REQUIRE_EQUAL(b.Left()->Begin(), c.Left()->Begin());
+ BOOST_REQUIRE_EQUAL(b.Left()->Count(), c.Left()->Count());
+ BOOST_REQUIRE_EQUAL(b.Left()->Left(),
+ (BinarySpaceTree<HRectBound<2> >*) NULL);
+ BOOST_REQUIRE_EQUAL(b.Left()->Left(), c.Left()->Left());
+ BOOST_REQUIRE_EQUAL(b.Left()->Right(),
+ (BinarySpaceTree<HRectBound<2> >*) NULL);
+ BOOST_REQUIRE_EQUAL(b.Left()->Right(), c.Left()->Right());
+
+ BOOST_REQUIRE_EQUAL(b.Right()->Begin(), c.Right()->Begin());
+ BOOST_REQUIRE_EQUAL(b.Right()->Count(), c.Right()->Count());
+ BOOST_REQUIRE_EQUAL(b.Right()->Left(),
+ (BinarySpaceTree<HRectBound<2> >*) NULL);
+ BOOST_REQUIRE_EQUAL(b.Right()->Left(), c.Right()->Left());
+ BOOST_REQUIRE_EQUAL(b.Right()->Right(),
+ (BinarySpaceTree<HRectBound<2> >*) NULL);
+ BOOST_REQUIRE_EQUAL(b.Right()->Right(), c.Right()->Right());
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_traits_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/tree_traits_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_traits_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,79 +0,0 @@
-/**
- * @file tree_traits_test.cpp
- * @author Ryan Curtin
- *
- * Tests for the TreeTraits class. These could all be known at compile-time,
- * but realistically the function is to be sure that nobody changes tree traits
- * without breaking something. Thus, people must be certain when they make a
- * change like that (because they have to change the test too). That's the
- * hope, at least...
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/core.hpp>
-#include <mlpack/core/tree/tree_traits.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
-#include <mlpack/core/tree/cover_tree.hpp>
-
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::tree;
-using namespace mlpack::metric;
-
-BOOST_AUTO_TEST_SUITE(TreeTraitsTest);
-
-// Be careful! When writing new tests, always get the boolean value and store
-// it in a temporary, because the Boost unit test macros do weird things and
-// will cause bizarre problems.
-
-// Test the defaults.
-BOOST_AUTO_TEST_CASE(DefaultsTraitsTest)
-{
- // An irrelevant non-tree type class is used here so that the default
- // implementation of TreeTraits is chosen.
- bool b = TreeTraits<int>::HasParentDistance;
- BOOST_REQUIRE_EQUAL(b, false);
- b = TreeTraits<int>::HasOverlappingChildren;
- BOOST_REQUIRE_EQUAL(b, true);
-}
-
-// Test the binary space tree traits.
-BOOST_AUTO_TEST_CASE(BinarySpaceTreeTraitsTest)
-{
- // ParentDistance() is not available.
- bool b = TreeTraits<BinarySpaceTree<LMetric<2, false> > >::HasParentDistance;
- BOOST_REQUIRE_EQUAL(b, false);
-
- // Children are non-overlapping.
- b = TreeTraits<BinarySpaceTree<LMetric<2, false> > >::HasOverlappingChildren;
- BOOST_REQUIRE_EQUAL(b, false);
-}
-
-// Test the cover tree traits.
-BOOST_AUTO_TEST_CASE(CoverTreeTraitsTest)
-{
- // ParentDistance() is available.
- bool b = TreeTraits<CoverTree<> >::HasParentDistance;
- BOOST_REQUIRE_EQUAL(b, true);
-
- // Children may be overlapping.
- b = TreeTraits<CoverTree<> >::HasOverlappingChildren;
- BOOST_REQUIRE_EQUAL(b, true);
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_traits_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/tree_traits_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_traits_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/tree_traits_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,79 @@
+/**
+ * @file tree_traits_test.cpp
+ * @author Ryan Curtin
+ *
+ * Tests for the TreeTraits class. These could all be known at compile-time,
+ * but realistically the function is to be sure that nobody changes tree traits
+ * without breaking something. Thus, people must be certain when they make a
+ * change like that (because they have to change the test too). That's the
+ * hope, at least...
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/tree_traits.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::tree;
+using namespace mlpack::metric;
+
+BOOST_AUTO_TEST_SUITE(TreeTraitsTest);
+
+// Be careful! When writing new tests, always get the boolean value and store
+// it in a temporary, because the Boost unit test macros do weird things and
+// will cause bizarre problems.
+
+// Test the defaults.
+BOOST_AUTO_TEST_CASE(DefaultsTraitsTest)
+{
+ // An irrelevant non-tree type class is used here so that the default
+ // implementation of TreeTraits is chosen.
+ bool b = TreeTraits<int>::HasParentDistance;
+ BOOST_REQUIRE_EQUAL(b, false);
+ b = TreeTraits<int>::HasOverlappingChildren;
+ BOOST_REQUIRE_EQUAL(b, true);
+}
+
+// Test the binary space tree traits.
+BOOST_AUTO_TEST_CASE(BinarySpaceTreeTraitsTest)
+{
+ // ParentDistance() is not available.
+ bool b = TreeTraits<BinarySpaceTree<LMetric<2, false> > >::HasParentDistance;
+ BOOST_REQUIRE_EQUAL(b, false);
+
+ // Children are non-overlapping.
+ b = TreeTraits<BinarySpaceTree<LMetric<2, false> > >::HasOverlappingChildren;
+ BOOST_REQUIRE_EQUAL(b, false);
+}
+
+// Test the cover tree traits.
+BOOST_AUTO_TEST_CASE(CoverTreeTraitsTest)
+{
+ // ParentDistance() is available.
+ bool b = TreeTraits<CoverTree<> >::HasParentDistance;
+ BOOST_REQUIRE_EQUAL(b, true);
+
+ // Children may be overlapping.
+ b = TreeTraits<CoverTree<> >::HasOverlappingChildren;
+ BOOST_REQUIRE_EQUAL(b, true);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Deleted: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/union_find_test.cpp
===================================================================
--- mlpack/branches/mlpack-1.x/src/mlpack/tests/union_find_test.cpp 2013-05-02 03:09:17 UTC (rev 14996)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/union_find_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -1,64 +0,0 @@
-/**
- * @file union_find_test.cpp
- * @author Bill March (march at gatech.edu)
- *
- * Unit tests for the Union-Find data structure.
- *
- * This file is part of MLPACK 1.0.4.
- *
- * MLPACK is free software: you can redistribute it and/or modify it under the
- * terms of the GNU Lesser General Public License as published by the Free
- * Software Foundation, either version 3 of the License, or (at your option) any
- * later version.
- *
- * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
- * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- * details (LICENSE.txt).
- *
- * You should have received a copy of the GNU General Public License along with
- * MLPACK. If not, see <http://www.gnu.org/licenses/>.
- */
-#include <mlpack/methods/emst/union_find.hpp>
-
-#include <mlpack/core.hpp>
-#include <boost/test/unit_test.hpp>
-#include "old_boost_test_definitions.hpp"
-
-using namespace mlpack;
-using namespace mlpack::emst;
-
-BOOST_AUTO_TEST_SUITE(UnionFindTest);
-
-BOOST_AUTO_TEST_CASE(TestFind)
-{
- static const size_t testSize_ = 10;
- UnionFind testUnionFind_(testSize_);
-
- for (size_t i = 0; i < testSize_; i++)
- BOOST_REQUIRE(testUnionFind_.Find(i) == i);
-
- testUnionFind_.Union(0, 1);
- testUnionFind_.Union(1, 2);
-
- BOOST_REQUIRE(testUnionFind_.Find(2) == testUnionFind_.Find(0));
-}
-
-BOOST_AUTO_TEST_CASE(TestUnion)
-{
- static const size_t testSize_ = 10;
- UnionFind testUnionFind_(testSize_);
-
- testUnionFind_.Union(0, 1);
- testUnionFind_.Union(2, 3);
- testUnionFind_.Union(0, 2);
- testUnionFind_.Union(5, 0);
- testUnionFind_.Union(0, 6);
-
- BOOST_REQUIRE(testUnionFind_.Find(0) == testUnionFind_.Find(1));
- BOOST_REQUIRE(testUnionFind_.Find(2) == testUnionFind_.Find(3));
- BOOST_REQUIRE(testUnionFind_.Find(1) == testUnionFind_.Find(5));
- BOOST_REQUIRE(testUnionFind_.Find(6) == testUnionFind_.Find(3));
-}
-
-BOOST_AUTO_TEST_SUITE_END();
Copied: mlpack/tags/mlpack-1.0.5/src/mlpack/tests/union_find_test.cpp (from rev 14998, mlpack/branches/mlpack-1.x/src/mlpack/tests/union_find_test.cpp)
===================================================================
--- mlpack/tags/mlpack-1.0.5/src/mlpack/tests/union_find_test.cpp (rev 0)
+++ mlpack/tags/mlpack-1.0.5/src/mlpack/tests/union_find_test.cpp 2013-05-02 04:20:13 UTC (rev 15001)
@@ -0,0 +1,64 @@
+/**
+ * @file union_find_test.cpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Unit tests for the Union-Find data structure.
+ *
+ * This file is part of MLPACK 1.0.5.
+ *
+ * MLPACK is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU Lesser General Public License as published by the Free
+ * Software Foundation, either version 3 of the License, or (at your option) any
+ * later version.
+ *
+ * MLPACK is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+ * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+ * details (LICENSE.txt).
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * MLPACK. If not, see <http://www.gnu.org/licenses/>.
+ */
+#include <mlpack/methods/emst/union_find.hpp>
+
+#include <mlpack/core.hpp>
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::emst;
+
+BOOST_AUTO_TEST_SUITE(UnionFindTest);
+
+BOOST_AUTO_TEST_CASE(TestFind)
+{
+ static const size_t testSize_ = 10;
+ UnionFind testUnionFind_(testSize_);
+
+ for (size_t i = 0; i < testSize_; i++)
+ BOOST_REQUIRE(testUnionFind_.Find(i) == i);
+
+ testUnionFind_.Union(0, 1);
+ testUnionFind_.Union(1, 2);
+
+ BOOST_REQUIRE(testUnionFind_.Find(2) == testUnionFind_.Find(0));
+}
+
+BOOST_AUTO_TEST_CASE(TestUnion)
+{
+ static const size_t testSize_ = 10;
+ UnionFind testUnionFind_(testSize_);
+
+ testUnionFind_.Union(0, 1);
+ testUnionFind_.Union(2, 3);
+ testUnionFind_.Union(0, 2);
+ testUnionFind_.Union(5, 0);
+ testUnionFind_.Union(0, 6);
+
+ BOOST_REQUIRE(testUnionFind_.Find(0) == testUnionFind_.Find(1));
+ BOOST_REQUIRE(testUnionFind_.Find(2) == testUnionFind_.Find(3));
+ BOOST_REQUIRE(testUnionFind_.Find(1) == testUnionFind_.Find(5));
+ BOOST_REQUIRE(testUnionFind_.Find(6) == testUnionFind_.Find(3));
+}
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list