[mlpack-git] master: Refactor and add Serialize(). Don't hold a reference to the kernel, hold a pointer and use it internally. (40b8a61)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 18:59:32 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0

>---------------------------------------------------------------

commit 40b8a61cb4a02d46ed0db9aa4e8cd639b90aecde
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Apr 18 01:32:22 2015 +0000

    Refactor and add Serialize().
    Don't hold a reference to the kernel, hold a pointer and use it internally.


>---------------------------------------------------------------

40b8a61cb4a02d46ed0db9aa4e8cd639b90aecde
 src/mlpack/core/metrics/ip_metric.hpp      | 16 +++++++++------
 src/mlpack/core/metrics/ip_metric_impl.hpp | 33 ++++++++++++++++++++++--------
 2 files changed, 34 insertions(+), 15 deletions(-)

diff --git a/src/mlpack/core/metrics/ip_metric.hpp b/src/mlpack/core/metrics/ip_metric.hpp
index c8d97df..faf1822 100644
--- a/src/mlpack/core/metrics/ip_metric.hpp
+++ b/src/mlpack/core/metrics/ip_metric.hpp
@@ -49,18 +49,22 @@ class IPMetric
   double Evaluate(const VecTypeA& a, const VecTypeB& b);
 
   //! Get the kernel.
-  const KernelType& Kernel() const { return kernel; }
+  const KernelType& Kernel() const { return *kernel; }
   //! Modify the kernel.
-  KernelType& Kernel() { return kernel; }
+  KernelType& Kernel() { return *kernel; }
+
+  //! Serialize the metric.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int version);
 
   //! Returns a string representation of this object.
   std::string ToString() const;
 
  private:
-  //! The locally stored kernel, if it is necessary.
-  KernelType* localKernel;
-  //! The reference to the kernel that is being used.
-  KernelType& kernel;
+  //! The kernel we are using.
+  KernelType* kernel;
+  //! If true, we are responsible for deleting the kernel.
+  bool kernelOwner;
 };
 
 }; // namespace metric
diff --git a/src/mlpack/core/metrics/ip_metric_impl.hpp b/src/mlpack/core/metrics/ip_metric_impl.hpp
index 5b46756..186aed7 100644
--- a/src/mlpack/core/metrics/ip_metric_impl.hpp
+++ b/src/mlpack/core/metrics/ip_metric_impl.hpp
@@ -19,8 +19,8 @@ namespace metric {
 // Constructor with no instantiated kernel.
 template<typename KernelType>
 IPMetric<KernelType>::IPMetric() :
-    localKernel(new KernelType()),
-    kernel(*localKernel)
+    kernel(new KernelType()),
+    kernelOwner(true)
 {
   // Nothing to do.
 }
@@ -28,8 +28,8 @@ IPMetric<KernelType>::IPMetric() :
 // Constructor with instantiated kernel.
 template<typename KernelType>
 IPMetric<KernelType>::IPMetric(KernelType& kernel) :
-    localKernel(NULL),
-    kernel(kernel)
+    kernel(&kernel),
+    kernelOwner(false)
 {
   // Nothing to do.
 }
@@ -38,8 +38,8 @@ IPMetric<KernelType>::IPMetric(KernelType& kernel) :
 template<typename KernelType>
 IPMetric<KernelType>::~IPMetric()
 {
-  if (localKernel != NULL)
-    delete localKernel;
+  if (kernelOwner)
+    delete kernel;
 }
 
 template<typename KernelType>
@@ -49,8 +49,23 @@ inline double IPMetric<KernelType>::Evaluate(const Vec1Type& a,
 {
   // This is the metric induced by the kernel function.
   // Maybe we can do better by caching some of this?
-  return sqrt(kernel.Evaluate(a, a) + kernel.Evaluate(b, b) -
-      2 * kernel.Evaluate(a, b));
+  return sqrt(kernel->Evaluate(a, a) + kernel->Evaluate(b, b) -
+      2 * kernel->Evaluate(a, b));
+}
+
+// Serialize the metric (that is, the kernel it holds).
+template<typename KernelType>
+template<typename Archive>
+void IPMetric<KernelType>::Serialize(Archive& ar,
+                                     const unsigned int /* version */)
+{
+  // If we're loading, we need to allocate space for the kernel.  Whether
+  // loading or saving, the metric is marked as the kernel's owner afterwards.
+  if (Archive::is_loading::value)
+    kernel = new KernelType();
+  kernelOwner = true;
+
+  ar & data::CreateNVP(kernel, "kernel");
 }
 
 // Convert object to string.
@@ -60,7 +75,7 @@ std::string IPMetric<KernelType>::ToString() const
   std::ostringstream convert;
   convert << "IPMetric [" << this << "]" << std::endl;
   convert << "  Kernel: " << std::endl;
-  convert << util::Indent(kernel.ToString(), 2);
+  convert << util::Indent(kernel->ToString(), 2);
   return convert.str();
 }
 



More information about the mlpack-git mailing list