[mlpack-git] master: Refactor and add Serialize(). Don't hold a reference to the kernel, hold a pointer and use it internally. (40b8a61)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 18:59:32 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0
>---------------------------------------------------------------
commit 40b8a61cb4a02d46ed0db9aa4e8cd639b90aecde
Author: Ryan Curtin <ryan at ratml.org>
Date: Sat Apr 18 01:32:22 2015 +0000
Refactor and add Serialize().
Don't hold a reference to the kernel, hold a pointer and use it internally.
>---------------------------------------------------------------
40b8a61cb4a02d46ed0db9aa4e8cd639b90aecde
src/mlpack/core/metrics/ip_metric.hpp | 16 +++++++++------
src/mlpack/core/metrics/ip_metric_impl.hpp | 33 ++++++++++++++++++++++--------
2 files changed, 34 insertions(+), 15 deletions(-)
diff --git a/src/mlpack/core/metrics/ip_metric.hpp b/src/mlpack/core/metrics/ip_metric.hpp
index c8d97df..faf1822 100644
--- a/src/mlpack/core/metrics/ip_metric.hpp
+++ b/src/mlpack/core/metrics/ip_metric.hpp
@@ -49,18 +49,22 @@ class IPMetric
double Evaluate(const VecTypeA& a, const VecTypeB& b);
//! Get the kernel.
- const KernelType& Kernel() const { return kernel; }
+ const KernelType& Kernel() const { return *kernel; }
//! Modify the kernel.
- KernelType& Kernel() { return kernel; }
+ KernelType& Kernel() { return *kernel; }
+
+ //! Serialize the metric, which amounts to serializing the kernel it uses.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int version);
//! Returns a string representation of this object.
std::string ToString() const;
private:
- //! The locally stored kernel, if it is necessary.
- KernelType* localKernel;
- //! The reference to the kernel that is being used.
- KernelType& kernel;
+ //! The kernel we are using. Points either to a kernel this object allocated
+ //! itself or to one supplied by the caller; kernelOwner says which.
+ KernelType* kernel;
+ //! If true, we are responsible for deleting the kernel.
+ //! NOTE(review): with an owning raw pointer, an implicitly generated copy
+ //! constructor or copy assignment would copy both the pointer and this flag,
+ //! so two objects could delete the same kernel. Confirm the full class
+ //! declares its own copy operations (they are not visible in this hunk).
+ bool kernelOwner;
};
}; // namespace metric
diff --git a/src/mlpack/core/metrics/ip_metric_impl.hpp b/src/mlpack/core/metrics/ip_metric_impl.hpp
index 5b46756..186aed7 100644
--- a/src/mlpack/core/metrics/ip_metric_impl.hpp
+++ b/src/mlpack/core/metrics/ip_metric_impl.hpp
@@ -19,8 +19,8 @@ namespace metric {
// Constructor with no instantiated kernel.
template<typename KernelType>
IPMetric<KernelType>::IPMetric() :
- localKernel(new KernelType()),
- kernel(*localKernel)
+ // Default-construct a kernel on the heap; this object owns it, so
+ // ~IPMetric() is responsible for deleting it.
+ kernel(new KernelType()),
+ kernelOwner(true)
{
// Nothing to do.
}
@@ -28,8 +28,8 @@ IPMetric<KernelType>::IPMetric() :
// Constructor with instantiated kernel.
template<typename KernelType>
IPMetric<KernelType>::IPMetric(KernelType& kernel) :
- localKernel(NULL),
- kernel(kernel)
+ // Borrow the caller's kernel: in `kernel(&kernel)` the argument names the
+ // constructor parameter, which shadows the member. We never delete it, so
+ // it must outlive this metric (Evaluate() and ToString() dereference it).
+ kernel(&kernel),
+ kernelOwner(false)
{
// Nothing to do.
}
@@ -38,8 +38,8 @@ IPMetric<KernelType>::IPMetric(KernelType& kernel) :
template<typename KernelType>
IPMetric<KernelType>::~IPMetric()
{
- if (localKernel != NULL)
- delete localKernel;
+ // Free the kernel only when kernelOwner says we own it; a kernel passed to
+ // the constructor belongs to the caller.
+ if (kernelOwner)
+ delete kernel;
}
template<typename KernelType>
@@ -49,8 +49,23 @@ inline double IPMetric<KernelType>::Evaluate(const Vec1Type& a,
{
// This is the metric induced by the kernel function.
// Maybe we can do better by caching some of this?
- return sqrt(kernel.Evaluate(a, a) + kernel.Evaluate(b, b) -
- 2 * kernel.Evaluate(a, b));
+ // d(a, b) = sqrt(K(a, a) + K(b, b) - 2 K(a, b)).
+ // NOTE(review): when a and b are nearly equal (or the kernel is not positive
+ // semi-definite), rounding can push the argument slightly below zero, and
+ // sqrt() then returns NaN; consider clamping the argument at zero.
+ return sqrt(kernel->Evaluate(a, a) + kernel->Evaluate(b, b) -
+ 2 * kernel->Evaluate(a, b));
+}
+
+// Serialize the kernel.
+//
+// Saving writes the kernel and leaves the metric untouched. Loading replaces
+// the kernel with a freshly allocated one that this metric owns afterwards.
+template<typename KernelType>
+template<typename Archive>
+void IPMetric<KernelType>::Serialize(Archive& ar,
+                                     const unsigned int /* version */)
+{
+  // Only loading may change which kernel we use and who owns it. Setting
+  // kernelOwner while saving would make ~IPMetric() delete a kernel that
+  // belongs to the caller, hence the braces.
+  if (Archive::is_loading::value)
+  {
+    // Allocate before freeing, so that a throwing allocation leaves the
+    // metric in its previous, consistent state.
+    // NOTE(review): if data::CreateNVP() on a pointer makes the archive
+    // allocate its own object (as Boost pointer serialization does), this
+    // pre-allocation is unnecessary and leaks; confirm and drop it if so.
+    KernelType* newKernel = new KernelType();
+
+    // Free the kernel we currently own before replacing it; otherwise it leaks.
+    if (kernelOwner)
+      delete kernel;
+
+    kernel = newKernel;
+    kernelOwner = true;
+  }
+
+  ar & data::CreateNVP(kernel, "kernel");
}
// Convert object to string.
@@ -60,7 +75,7 @@ std::string IPMetric<KernelType>::ToString() const
std::ostringstream convert;
convert << "IPMetric [" << this << "]" << std::endl;
convert << " Kernel: " << std::endl;
- convert << util::Indent(kernel.ToString(), 2);
+ // Append the kernel's own description, nested via util::Indent(..., 2).
+ convert << util::Indent(kernel->ToString(), 2);
return convert.str();
}
More information about the mlpack-git
mailing list