[mlpack-git] master: Hierarchical model support (0f95ab8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:56:09 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 0f95ab87af47998adc830b5709c0ea63f279206b
Author: michaelfox99 <michaelfox99 at gmail.com>
Date:   Tue Aug 5 13:18:31 2014 +0000

    Hierarchical model support


>---------------------------------------------------------------

0f95ab87af47998adc830b5709c0ea63f279206b
 src/mlpack/core/util/save_restore_utility.cpp | 73 +++++++++++++++++----------
 1 file changed, 47 insertions(+), 26 deletions(-)

diff --git a/src/mlpack/core/util/save_restore_utility.cpp b/src/mlpack/core/util/save_restore_utility.cpp
index 55c531c..0d7abd0 100644
--- a/src/mlpack/core/util/save_restore_utility.cpp
+++ b/src/mlpack/core/util/save_restore_utility.cpp
@@ -1,6 +1,7 @@
 /**
  * @file save_restore_utility.cpp
  * @author Neil Slagle
+ * @author Michael Fox
  *
  * The SaveRestoreUtility provides helper functions in saving and
  *   restoring models.  The current output file type is XML.
@@ -19,25 +20,30 @@ bool SaveRestoreUtility::ReadFile(const std::string& filename)
   }
 
   xmlNodePtr root = xmlDocGetRootElement(xmlDocTree);
-  parameters.clear();
-
-  RecurseOnNodes(root->children);
+  ReadFile(root->children);
   xmlFreeDoc(xmlDocTree);
   return true;
 }
 
-void SaveRestoreUtility::RecurseOnNodes(xmlNode* n)
+void SaveRestoreUtility::ReadFile(xmlNode* n)
 {
+  parameters.clear();
   xmlNodePtr current = NULL;
   for (current = n; current; current = current->next)
   {
     if (current->type == XML_ELEMENT_NODE)
     {
       xmlChar* content = xmlNodeGetContent(current);
-      parameters[(const char*) current->name] = (const char*) content;
+      if(xmlChildElementCount(current) == 0)
+      {
+        parameters[(const char*) current->name] = (const char*) content;
+      }
+      else
+      {
+        children[(const char*) current->name].ReadFile(current->children);
+      }
       xmlFree(content);
     }
-    RecurseOnNodes(current->children);
   }
 }
 
@@ -46,30 +52,38 @@ bool SaveRestoreUtility::WriteFile(const std::string& filename)
   bool success = false;
   xmlDocPtr xmlDocTree = xmlNewDoc(BAD_CAST "1.0");
   xmlNodePtr root = xmlNewNode(NULL, BAD_CAST "root");
-
   xmlDocSetRootElement(xmlDocTree, root);
-
-  for (std::map<std::string, std::string>::iterator it = parameters.begin();
-       it != parameters.end();
-       ++it)
-  {
-    xmlNewChild(root, NULL, BAD_CAST(*it).first.c_str(),
-                            BAD_CAST(*it).second.c_str());
-    /* TODO: perhaps we'll add more later?
-     * xmlNewProp(child, BAD_CAST "attr", BAD_CAST "add more addibutes?"); */
-  }
+  WriteFile(root);
 
   // Actually save the file.
-  success =
-      (xmlSaveFormatFileEnc(filename.c_str(), xmlDocTree, "UTF-8", 1) != -1);
+  success = (xmlSaveFormatFileEnc(filename.c_str(), xmlDocTree, "UTF-8", 1) !=
+             -1);
   xmlFreeDoc(xmlDocTree);
   return success;
 }
 
+// this should be private member function
+void SaveRestoreUtility::WriteFile(xmlNode* n)
+{
+  for (std::map<std::string, std::string>::reverse_iterator it =
+	    parameters.rbegin(); it != parameters.rend(); ++it)
+  {
+    xmlNewChild(n, NULL, BAD_CAST(*it).first.c_str(),
+        BAD_CAST(*it).second.c_str());
+  }
+  xmlNodePtr child;
+  for (std::map<std::string, SaveRestoreUtility>::iterator it =
+       children.begin(); it != children.end(); ++it)
+  {
+    child = xmlNewChild(n, NULL, BAD_CAST(*it).first.c_str(), NULL);
+    it->second.WriteFile(child);
+  }
+}
+
 arma::mat& SaveRestoreUtility::LoadParameter(arma::mat& matrix,
-                                             const std::string& name)
+                                             const std::string& name) const
 {
-  std::map<std::string, std::string>::iterator it = parameters.find(name);
+  std::map<std::string, std::string>::const_iterator it = parameters.find(name);
   if (it != parameters.end())
   {
     std::string value = (*it).second;
@@ -126,9 +140,9 @@ arma::mat& SaveRestoreUtility::LoadParameter(arma::mat& matrix,
 }
 
 std::string SaveRestoreUtility::LoadParameter(std::string& str,
-                                              const std::string& name)
+                                              const std::string& name) const
 {
-  std::map<std::string, std::string>::iterator it = parameters.find(name);
+  std::map<std::string, std::string>::const_iterator it = parameters.find(name);
   if (it != parameters.end())
   {
     return str = (*it).second;
@@ -140,9 +154,9 @@ std::string SaveRestoreUtility::LoadParameter(std::string& str,
   return "";
 }
 
-char SaveRestoreUtility::LoadParameter(char c, const std::string& name)
+char SaveRestoreUtility::LoadParameter(char c, const std::string& name) const
 {
-  std::map<std::string, std::string>::iterator it = parameters.find(name);
+  std::map<std::string, std::string>::const_iterator it = parameters.find(name);
   if (it != parameters.end())
   {
     int temp;
@@ -190,7 +204,7 @@ namespace util {
 
 template<>
 arma::vec& SaveRestoreUtility::LoadParameter(arma::vec& t,
-                                             const std::string& name)
+                                             const std::string& name) const
 {
   return (arma::vec&) LoadParameter((arma::mat&) t, name);
 }
@@ -202,5 +216,12 @@ void SaveRestoreUtility::SaveParameter(const arma::vec& t,
   SaveParameter((const arma::mat&) t, name);
 }
     
+void SaveRestoreUtility::AddChild(SaveRestoreUtility& mn, const std::string&
+    name)
+{
+  children[name] = mn;
+}
+
+  
 }; // namespace util
 }; // namespace mlpack



More information about the mlpack-git mailing list