Commit bb7f083e authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Serializer Pass (#1050)

* serializer pass
parent 2f776ef0
......@@ -125,6 +125,7 @@ set (SRC
pass/validate_graph.cpp
pass/visualize_tree.cpp
pass/core_fusion.cpp
pass/serialize.cpp
pass/zero_dim_tensor_elimination.cpp
pattern/matcher.cpp
runtime/aligned_buffer.cpp
......
......@@ -25,6 +25,7 @@
#include "ngraph/op/reduce.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/serialize.hpp"
#include "ngraph/pass/visualize_tree.hpp"
using namespace std;
......@@ -37,6 +38,11 @@ ngraph::pass::Manager::Manager()
{
m_visualize = true;
}
static const auto nest = std::getenv("NGRAPH_ENABLE_SERIALIZE_TRACING");
if (nest)
{
m_serialize = true;
}
}
ngraph::pass::Manager::~Manager()
......@@ -93,18 +99,28 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
}
}
if (m_visualize)
if (m_visualize || m_serialize)
{
//visualizations will be named after the outermost function
//visualizations and serializations will be named after the outermost function
const size_t num_digits_in_pass_index = 3;
std::string index_str = std::to_string(index);
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
auto fname = fs.at(0)->get_name() + std::string("_") + index_str + std::string("_") +
m_pass_names.at(index) + std::string(".") +
pass::VisualizeTree::get_file_ext();
pass::VisualizeTree vt(fname);
auto base_filename = fs.at(0)->get_name() + std::string("_") + index_str +
std::string("_") + m_pass_names.at(index) + std::string(".");
if (m_visualize)
{
pass::VisualizeTree vt(base_filename + pass::VisualizeTree::get_file_ext());
vt.run_on_module(fs);
}
if (m_serialize)
{
//no "." in the extension
pass::Serialization st(base_filename + "json");
st.run_on_module(fs);
}
}
index++;
}
}
......
......@@ -48,7 +48,7 @@ public:
auto pass = std::make_shared<T>(args...);
auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base);
if (m_visualize)
if (m_visualize || m_serialize)
{
m_pass_names.push_back(typeid(T).name());
}
......@@ -58,9 +58,11 @@ public:
ManagerState& get_state();
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
void set_pass_serialization(bool new_state) { m_serialize = new_state; }
private:
std::vector<std::string> m_pass_names;
std::vector<std::shared_ptr<PassBase>> m_pass_list;
ManagerState m_state;
bool m_visualize = false;
bool m_serialize = false;
};
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <fstream>
#include "ngraph/file_util.hpp"
#include "ngraph/pass/serialize.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
using namespace std;
using namespace ngraph;
pass::Serialization::Serialization(const string& name)
: m_name{name}
{
}
bool pass::Serialization::run_on_module(vector<shared_ptr<Function>>& functions)
{
//serializing the outermost functions
//also implicitly serializes any inner functions
serialize(m_name, functions.at(0), 4);
return false;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <string>
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class Serialization;
}
}
class ngraph::pass::Serialization : public ModulePass
{
public:
Serialization(const std::string& name);
virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
private:
const std::string m_name;
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment