Commit e6cc7d8b authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Bob/serialize2 (#267)

* add json lib 2.1.1

* add json serialization of graph

* cleanup

* enhance unit test

* remove FunctionProvider class and replace with virtual get_function() in Node

* remove json code from element_type

* move serialize to be directly in the ngraph namespace. cleanup header file.

* add cname check to element::operator==

* add using json = nlohmann::json
parent 4630c37d
...@@ -88,15 +88,29 @@ set (SRC ...@@ -88,15 +88,29 @@ set (SRC
runtime/tensor_view.cpp runtime/tensor_view.cpp
runtime/tuple.cpp runtime/tuple.cpp
runtime/utils.cpp runtime/utils.cpp
serializer.cpp
shape.cpp shape.cpp
types/element_type.cpp types/element_type.cpp
types/type.cpp types/type.cpp
util.cpp util.cpp
) )
message(STATUS ${CMAKE_CURRENT_SOURCE_DIR}/ops)
file(GLOB_RECURSE OPS "${CMAKE_CURRENT_SOURCE_DIR}/ops" "${CMAKE_CURRENT_SOURCE_DIR}/ops/*.hpp")
foreach(OP ${OPS})
file(STRINGS ${OP} OP_CLASS REGEX "class [A-Za-z0-9_]+ :")
foreach(LINE ${OP_CLASS})
string(REGEX REPLACE ".*class ([A-Za-z0-9_]+) : public ([A-Za-z0-9_]+).*" "\\1:\\2" CLASS_FOUND ${LINE})
set(OP_CLASS_LIST ${OP_CLASS_LIST} ${CLASS_FOUND})
endforeach(LINE ${OP_CLASS})
endforeach()
message(STATUS "${CMAKE_CURRENT_BINARY_DIR}/ops_list.txt")
string(REPLACE ";" "\n" OP_CLASS_LINES "${OP_CLASS_LIST}")
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/ops_list.txt" "${OP_CLASS_LINES}")
# find_program (GRAPHVIZ dot) # find_program (GRAPHVIZ dot)
# message (STATUS "graphviz '${GRAPHVIZ}'") # message (STATUS "graphviz '${GRAPHVIZ}'")
find_package(Graphviz) find_package(Graphviz QUIET)
if (GRAPHVIZ_FOUND) if (GRAPHVIZ_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DGRAPHVIZ_FOUND") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DGRAPHVIZ_FOUND")
endif() endif()
......
...@@ -40,6 +40,7 @@ namespace ngraph ...@@ -40,6 +40,7 @@ namespace ngraph
const std::string& name = ""); const std::string& name = "");
std::shared_ptr<Node> get_result() { return m_result; } std::shared_ptr<Node> get_result() { return m_result; }
std::shared_ptr<const Node> get_result() const { return m_result; }
const std::vector<std::shared_ptr<op::Parameter>>& get_parameters() const const std::vector<std::shared_ptr<op::Parameter>>& get_parameters() const
{ {
return m_parameters; return m_parameters;
......
// clang-format off // clang-format off
#pragma clang diagnostic ignored "-Weverything"
/* /*
__ _____ _____ _____ __ _____ _____ _____
__| | __| | | | JSON for Modern C++ __| | __| | | | JSON for Modern C++
......
...@@ -181,6 +181,11 @@ std::shared_ptr<Node> Node::backprop_node(const std::shared_ptr<Node>& x, ...@@ -181,6 +181,11 @@ std::shared_ptr<Node> Node::backprop_node(const std::shared_ptr<Node>& x,
return adjoints_it->second.get(x); return adjoints_it->second.get(x);
} }
std::shared_ptr<Function> Node::get_function() const
{
return nullptr;
}
namespace ngraph namespace ngraph
{ {
ostream& operator<<(ostream& out, const Node& node) ostream& operator<<(ostream& out, const Node& node)
......
...@@ -111,6 +111,8 @@ namespace ngraph ...@@ -111,6 +111,8 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const std::vector<std::shared_ptr<Node>>& new_args) const = 0; copy_with_new_args(const std::vector<std::shared_ptr<Node>>& new_args) const = 0;
virtual std::shared_ptr<Function> get_function() const;
protected: protected:
std::string m_node_type; std::string m_node_type;
Nodes m_arguments; Nodes m_arguments;
......
...@@ -45,7 +45,7 @@ namespace ngraph ...@@ -45,7 +45,7 @@ namespace ngraph
/// | Backend | Status | /// | Backend | Status |
/// | ------- | ------------------ | /// | ------- | ------------------ |
/// | NGVM | Fully implemented. | /// | NGVM | Fully implemented. |
class FunctionCall : public ngraph::Node class FunctionCall : public Node
{ {
public: public:
/// \brief Constructs a function call operation. /// \brief Constructs a function call operation.
...@@ -62,7 +62,7 @@ namespace ngraph ...@@ -62,7 +62,7 @@ namespace ngraph
} }
/// \return The function to be called. /// \return The function to be called.
std::shared_ptr<Function> get_function() const { return m_function; } std::shared_ptr<Function> get_function() const override { return m_function; }
protected: protected:
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
}; };
......
...@@ -20,8 +20,6 @@ namespace ngraph ...@@ -20,8 +20,6 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Node;
/// \brief Operation to get an element from a tuple. /// \brief Operation to get an element from a tuple.
/// ///
/// ## Parameters /// ## Parameters
...@@ -47,7 +45,7 @@ namespace ngraph ...@@ -47,7 +45,7 @@ namespace ngraph
/// | Backend | Status | /// | Backend | Status |
/// | ------- | ------------------ | /// | ------- | ------------------ |
/// | NGVM | Fully implemented. | /// | NGVM | Fully implemented. |
class GetTupleElement : public ngraph::Node class GetTupleElement : public Node
{ {
public: public:
/// \brief Constructs a get-tuple-element operation. /// \brief Constructs a get-tuple-element operation.
......
...@@ -111,10 +111,7 @@ namespace ngraph ...@@ -111,10 +111,7 @@ namespace ngraph
} }
/// \return The function to use for reduction. /// \return The function to use for reduction.
std::shared_ptr<Function> get_reduction_function() const std::shared_ptr<Function> get_function() const override { return m_reduction_function; }
{
return m_reduction_function;
}
/// \return The axis positions (0-based) to be eliminated through reduction. /// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; } const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
protected: protected:
......
...@@ -39,7 +39,7 @@ namespace ngraph ...@@ -39,7 +39,7 @@ namespace ngraph
/// | Backend | Status | /// | Backend | Status |
/// | ------- | ------------------ | /// | ------- | ------------------ |
/// | NGVM | Fully implemented. | /// | NGVM | Fully implemented. |
class Tuple : public ngraph::Node class Tuple : public Node
{ {
public: public:
/// \brief Constructs a tuple construction operation. /// \brief Constructs a tuple construction operation.
......
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
#include <memory> #include <memory>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/function_call.hpp" #include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -38,30 +38,11 @@ void ngraph::pass::Manager::initialize_default_passes() ...@@ -38,30 +38,11 @@ void ngraph::pass::Manager::initialize_default_passes()
{ {
} }
static void find_functions(shared_ptr<Function> f, set<shared_ptr<Function>>& funcs)
{
funcs.insert(f);
for (shared_ptr<Node> node : f->get_ops())
{
shared_ptr<op::FunctionCall> fc = dynamic_pointer_cast<op::FunctionCall>(node);
if (fc)
{
find_functions(fc->get_function(), funcs);
}
shared_ptr<op::Reduce> reduce = dynamic_pointer_cast<op::Reduce>(node);
if (reduce)
{
find_functions(reduce->get_reduction_function(), funcs);
}
}
}
void ngraph::pass::Manager::run_passes(shared_ptr<Function> func) void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
{ {
// find all functions // find all functions
set<shared_ptr<Function>> tfs; set<shared_ptr<Function>> tfs;
find_functions(func, tfs); traverse_functions(func, [&](shared_ptr<Function> f) { tfs.insert(f); });
get_state().set_functions(tfs); get_state().set_functions(tfs);
vector<shared_ptr<Function>> fs; vector<shared_ptr<Function>> fs;
......
...@@ -1001,7 +1001,7 @@ void Emitter::EmitReduce(const ngraph::Node* n, ...@@ -1001,7 +1001,7 @@ void Emitter::EmitReduce(const ngraph::Node* n,
const std::vector<TensorViewInfo>& outputs) const std::vector<TensorViewInfo>& outputs)
{ {
auto reduce = static_cast<const op::Reduce*>(n); auto reduce = static_cast<const op::Reduce*>(n);
auto reduction_function = reduce->get_reduction_function(); auto reduction_function = reduce->get_function();
auto reductee_type = reduce->get_arguments().at(0)->get_value_type(); auto reductee_type = reduce->get_arguments().at(0)->get_value_type();
auto reductee_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(reductee_type); auto reductee_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(reductee_type);
......
...@@ -248,9 +248,9 @@ using namespace ngraph::runtime::cpu::eigen; ...@@ -248,9 +248,9 @@ using namespace ngraph::runtime::cpu::eigen;
{ {
for (descriptor::Tensor* tensor : node->liveness_new_list) for (descriptor::Tensor* tensor : node->liveness_new_list)
{ {
TU << tensor->get_element_type() << "* " << tensor->get_name() << " = (" TU << tensor->get_element_type().c_type_string() << "* " << tensor->get_name()
<< tensor->get_element_type() << "*)(memory_handler.get_ptr(" << " = (" << tensor->get_element_type().c_type_string()
<< tensor->get_pool_offset() << "));\n"; << "*)(memory_handler.get_ptr(" << tensor->get_pool_offset() << "));\n";
} }
} }
TU << "\n"; TU << "\n";
......
...@@ -740,7 +740,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -740,7 +740,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_TO_OP_MAP(op::Reduce) REGISTER_TO_OP_MAP(op::Reduce)
{ {
auto reduce = static_cast<const op::Reduce*>(n); auto reduce = static_cast<const op::Reduce*>(n);
auto reduction_function = reduce->get_reduction_function(); auto reduction_function = reduce->get_function();
std::shared_ptr<ExternalFunction> external; std::shared_ptr<ExternalFunction> external;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/serializer.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/acos.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/asin.hpp"
#include "ngraph/ops/atan.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/cos.hpp"
#include "ngraph/ops/cosh.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/equal.hpp"
#include "ngraph/ops/exp.hpp"
#include "ngraph/ops/floor.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/greater.hpp"
#include "ngraph/ops/greater_eq.hpp"
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/less_eq.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/sign.hpp"
#include "ngraph/ops/sin.hpp"
#include "ngraph/ops/sinh.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/ops/tan.hpp"
#include "ngraph/ops/tanh.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
using namespace std;
using json = nlohmann::json;
std::shared_ptr<ngraph::Function>
read_function(const json&, std::unordered_map<std::string, std::shared_ptr<Function>>&);
json write(const ngraph::Function&);
json write(const ngraph::Node&);
json write(const ngraph::element::Type&);
// This stupidity is caused by the fact that we do not pass element types
// by value but by reference even though they can be compared. There is no reason to pass
// them by reference EVERYWERE but here we are...
const element::Type& to_ref(const element::Type& t)
{
if (t == element::boolean)
{
return element::boolean;
}
if (t == element::f32)
{
return element::f32;
}
if (t == element::f64)
{
return element::f64;
}
if (t == element::i8)
{
return element::i8;
}
if (t == element::i16)
{
return element::i16;
}
if (t == element::i32)
{
return element::i32;
}
if (t == element::i64)
{
return element::i64;
}
if (t == element::u8)
{
return element::u8;
}
if (t == element::u16)
{
return element::u16;
}
if (t == element::u32)
{
return element::u32;
}
if (t == element::u64)
{
return element::u64;
}
throw runtime_error("type not valid");
}
static json write_element_type(const ngraph::element::Type& n)
{
json j;
j["bitwidth"] = n.bitwidth();
j["is_real"] = n.is_real();
j["is_signed"] = n.is_signed();
j["c_type_string"] = n.c_type_string();
return j;
}
static const element::Type& read_element_type(const json& j)
{
size_t bitwidth = j.at("bitwidth").get<size_t>();
bool is_real = j.at("is_real").get<bool>();
bool is_signed = j.at("is_signed").get<bool>();
string c_type_string = j.at("c_type_string").get<string>();
return to_ref(element::Type(bitwidth, is_real, is_signed, c_type_string));
}
string ngraph::serialize(shared_ptr<ngraph::Function> func)
{
json j;
vector<json> functions;
traverse_functions(func,
[&](shared_ptr<ngraph::Function> f) { functions.push_back(write(*f)); });
for (auto it = functions.rbegin(); it != functions.rend(); it++)
{
j.push_back(*it);
}
return j.dump();
}
shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
{
json js = json::array();
shared_ptr<Function> rc;
in >> js;
unordered_map<string, shared_ptr<Function>> function_map;
for (json func : js)
{
shared_ptr<Function> f = read_function(func, function_map);
if (rc == nullptr)
{
rc = f;
}
}
return rc;
}
json write(const Function& f)
{
json function;
function["name"] = f.get_name();
function["result_type"] = write_element_type(f.get_result_type()->get_element_type());
function["result_shape"] = f.get_result_type()->get_shape();
for (auto param : f.get_parameters())
{
function["parameters"].push_back(param->get_name());
}
function["result"].push_back(f.get_result()->get_name());
list<shared_ptr<Node>> result_list;
{
deque<Node*> independent_nodes;
unordered_map<const Node*, size_t> node_depencency_count;
unordered_map<Node*, shared_ptr<Node>> node_map;
traverse_nodes(const_cast<Function*>(&f), [&](shared_ptr<Node> node) {
node_map[node.get()] = node;
node_depencency_count[node.get()] = node->get_arguments().size();
if (node->get_arguments().size() == 0)
{
independent_nodes.push_back(node.get());
}
});
while (independent_nodes.size() > 0)
{
auto independent_node = independent_nodes.front();
result_list.push_back(node_map[independent_node]);
independent_nodes.pop_front();
for (auto user : independent_node->users())
{
node_depencency_count[user] -= 1;
size_t count = node_depencency_count[user];
if (count == 0)
{
independent_nodes.push_back(user);
}
}
}
}
json nodes;
for (shared_ptr<Node> node : result_list)
{
nodes.push_back(write(*node));
}
function["ops"] = nodes;
return function;
}
shared_ptr<ngraph::Function>
read_function(const json& func_js, unordered_map<string, shared_ptr<Function>>& function_map)
{
shared_ptr<ngraph::Function> rc;
string func_name = func_js.at("name").get<string>();
vector<string> func_result = func_js.at("result").get<vector<string>>();
vector<string> func_parameters = func_js.at("parameters").get<vector<string>>();
const element::Type& result_type = read_element_type(func_js.at("result_type"));
vector<size_t> result_shape = func_js.at("result_shape").get<vector<size_t>>();
unordered_map<string, shared_ptr<Node>> node_map;
for (json node_js : func_js.at("ops"))
{
string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>();
const element::Type& node_etype = read_element_type(node_js.at("element_type"));
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
shared_ptr<Node> node;
shared_ptr<Function> function_ptr = nullptr;
vector<shared_ptr<Node>> args;
for (const string& name : node_inputs)
{
args.push_back(node_map.at(name));
}
vector<string> known_nodes;
for (auto x : node_map)
{
known_nodes.push_back(x.first);
}
if (node_op == "Abs")
{
node = make_shared<op::Abs>(args[0]);
}
else if (node_op == "Acos")
{
node = make_shared<op::Acos>(args[0]);
}
else if (node_op == "Add")
{
node = make_shared<op::Add>(args[0], args[1]);
}
else if (node_op == "Asin")
{
node = make_shared<op::Asin>(args[0]);
}
else if (node_op == "Atan")
{
node = make_shared<op::Atan>(args[0]);
}
else if (node_op == "Broadcast")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = node_js.at("axes").get<set<size_t>>();
node = make_shared<op::Broadcast>(args[0], shape, axes);
}
else if (node_op == "Ceiling")
{
node = make_shared<op::Ceiling>(args[0]);
}
else if (node_op == "Concat")
{
auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Concat>(args, axis);
}
else if (node_op == "Constant")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto value = node_js.at("value").get<vector<string>>();
node = make_shared<op::Constant>(node_etype, shape, value);
}
else if (node_op == "Convert")
{
auto target_type = read_element_type(node_js.at("target_type"));
node = make_shared<op::Convert>(args[0], target_type);
}
else if (node_op == "Cos")
{
node = make_shared<op::Cos>(args[0]);
}
else if (node_op == "Cosh")
{
node = make_shared<op::Cosh>(args[0]);
}
else if (node_op == "Divide")
{
node = make_shared<op::Divide>(args[0], args[1]);
}
else if (node_op == "Dot")
{
node = make_shared<op::Dot>(args[0], args[1]);
}
else if (node_op == "Equal")
{
node = make_shared<op::Equal>(args[0], args[1]);
}
else if (node_op == "Exp")
{
node = make_shared<op::Exp>(args[0]);
}
else if (node_op == "Floor")
{
node = make_shared<op::Floor>(args[0]);
}
else if (node_op == "FunctionCall")
{
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::FunctionCall>(f_ptr, args);
}
// else if (node_op == "GetTupleElement")
// {
// node = make_shared<op::GetTupleElement>(args[0]);
// }
else if (node_op == "Greater")
{
node = make_shared<op::Greater>(args[0], args[1]);
}
else if (node_op == "GreaterEq")
{
node = make_shared<op::GreaterEq>(args[0], args[1]);
}
else if (node_op == "Less")
{
node = make_shared<op::Less>(args[0], args[1]);
}
else if (node_op == "LessEq")
{
node = make_shared<op::LessEq>(args[0], args[1]);
}
else if (node_op == "Log")
{
node = make_shared<op::Log>(args[0]);
}
else if (node_op == "Maximum")
{
node = make_shared<op::Maximum>(args[0], args[1]);
}
else if (node_op == "Minimum")
{
node = make_shared<op::Minimum>(args[0], args[1]);
}
else if (node_op == "Multiply")
{
node = make_shared<op::Multiply>(args[0], args[1]);
}
else if (node_op == "Negative")
{
node = make_shared<op::Negative>(args[0]);
}
else if (node_op == "NotEqual")
{
node = make_shared<op::NotEqual>(args[0], args[1]);
}
else if (node_op == "Parameter")
{
auto shape = node_js.at("shape");
node = make_shared<op::Parameter>(node_etype, shape);
}
else if (node_op == "Power")
{
node = make_shared<op::Power>(args[0], args[1]);
}
else if (node_op == "Reduce")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Reduce>(args[0], args[1], function_ptr, reduction_axes);
}
else if (node_op == "Remainder")
{
node = make_shared<op::Remainder>(args[0], args[1]);
}
else if (node_op == "Reshape")
{
auto input_order = node_js.at("input_order").get<vector<size_t>>();
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::Reshape>(args[0], input_order, output_shape);
}
else if (node_op == "Select")
{
node = make_shared<op::Select>(args[0], args[1], args[2]);
}
else if (node_op == "Sign")
{
node = make_shared<op::Sign>(args[0]);
}
else if (node_op == "Sin")
{
node = make_shared<op::Sin>(args[0]);
}
else if (node_op == "Sinh")
{
node = make_shared<op::Sinh>(args[0]);
}
else if (node_op == "Slice")
{
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto step = node_js.at("step").get<vector<size_t>>();
node = make_shared<op::Slice>(args[0], lower_bounds, upper_bounds, step);
}
else if (node_op == "Subtract")
{
node = make_shared<op::Subtract>(args[0], args[1]);
}
else if (node_op == "Sum")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Sum>(args[0], reduction_axes);
}
else if (node_op == "Tan")
{
node = make_shared<op::Tan>(args[0]);
}
else if (node_op == "Tanh")
{
node = make_shared<op::Tanh>(args[0]);
}
else if (node_op == "Tuple")
{
node = make_shared<op::Tuple>(args);
}
else
{
stringstream ss;
ss << "unsupported op " << node_op;
throw runtime_error(ss.str());
}
node_map[node_name] = node;
}
auto result = node_map.at(func_result[0]);
std::vector<std::shared_ptr<op::Parameter>> params;
for (auto param_name : func_parameters)
{
params.push_back(dynamic_pointer_cast<op::Parameter>(node_map.at(param_name)));
}
auto rt = make_shared<TensorViewType>(result_type, result_shape);
rc = make_shared<Function>(result, rt, params, func_name);
function_map[func_name] = rc;
return rc;
}
json write(const Node& n)
{
json node;
node["name"] = n.get_name();
node["op"] = n.description();
node["element_type"] = write_element_type(n.get_element_type());
json inputs = json::array();
json outputs = json::array();
for (const descriptor::Input& input : n.get_inputs())
{
inputs.push_back(input.get_output().get_node()->get_name());
}
for (const descriptor::Output& output : n.get_outputs())
{
outputs.push_back(output.get_node()->get_name());
}
node["inputs"] = inputs;
node["outputs"] = outputs;
string node_op = n.description();
if (node_op == "Abs")
{
}
else if (node_op == "Acos")
{
}
else if (node_op == "Add")
{
}
else if (node_op == "Asin")
{
}
else if (node_op == "Atan")
{
}
else if (node_op == "Broadcast")
{
auto tmp = dynamic_cast<const op::Broadcast*>(&n);
node["axes"] = tmp->get_broadcast_axes();
node["shape"] = tmp->get_broadcast_shape();
}
else if (node_op == "Ceiling")
{
}
else if (node_op == "Concat")
{
auto tmp = dynamic_cast<const op::Concat*>(&n);
node["axis"] = tmp->get_concatenation_axis();
}
else if (node_op == "Constant")
{
auto tmp = dynamic_cast<const op::Constant*>(&n);
node["value"] = tmp->get_value_strings();
node["shape"] = tmp->get_shape();
}
else if (node_op == "Convert")
{
auto tmp = dynamic_cast<const op::Convert*>(&n);
node["target_type"] = write_element_type(tmp->get_convert_element_type());
}
else if (node_op == "Cos")
{
}
else if (node_op == "Cosh")
{
}
else if (node_op == "Divide")
{
}
else if (node_op == "Dot")
{
}
else if (node_op == "Equal")
{
}
else if (node_op == "Exp")
{
}
else if (node_op == "Floor")
{
}
else if (node_op == "FunctionCall")
{
node["function"] = n.get_function()->get_name();
}
else if (node_op == "GetTupleElement")
{
}
else if (node_op == "Greater")
{
}
else if (node_op == "GreaterEq")
{
}
else if (node_op == "Less")
{
}
else if (node_op == "LessEq")
{
}
else if (node_op == "Log")
{
}
else if (node_op == "Maximum")
{
}
else if (node_op == "Minimum")
{
}
else if (node_op == "Multiply")
{
}
else if (node_op == "Negative")
{
}
else if (node_op == "NotEqual")
{
}
else if (node_op == "Parameter")
{
auto tmp = dynamic_cast<const op::Parameter*>(&n);
node["shape"] = tmp->get_shape();
}
else if (node_op == "Power")
{
}
else if (node_op == "Reduce")
{
auto tmp = dynamic_cast<const op::Reduce*>(&n);
node["function"] = tmp->get_function()->get_name();
node["reduction_axes"] = tmp->get_reduction_axes();
}
else if (node_op == "Remainder")
{
}
else if (node_op == "Reshape")
{
auto tmp = dynamic_cast<const op::Reshape*>(&n);
node["input_order"] = tmp->get_input_order();
node["output_shape"] = tmp->get_output_shape();
}
else if (node_op == "Select")
{
}
else if (node_op == "Sign")
{
}
else if (node_op == "Sin")
{
}
else if (node_op == "Sinh")
{
}
else if (node_op == "Slice")
{
auto tmp = dynamic_cast<const op::Slice*>(&n);
node["lower_bounds"] = tmp->get_lower_bounds();
node["upper_bounds"] = tmp->get_upper_bounds();
node["step"] = tmp->get_step();
}
else if (node_op == "Subtract")
{
}
else if (node_op == "Sum")
{
auto tmp = dynamic_cast<const op::Sum*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
}
else if (node_op == "Tan")
{
}
else if (node_op == "Tanh")
{
}
else if (node_op == "Tuple")
{
}
return node;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/json.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
std::string serialize(std::shared_ptr<ngraph::Function>);
std::shared_ptr<ngraph::Function> deserialize(std::istream&);
}
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
using namespace ngraph; using namespace ngraph;
const element::Type element::boolean(8, false, false, "bool"); const element::Type element::boolean(8, false, false, "char");
const element::Type element::f32(32, true, true, "float"); const element::Type element::f32(32, true, true, "float");
const element::Type element::f64(64, true, true, "double"); const element::Type element::f64(64, true, true, "double");
const element::Type element::i8(8, false, true, "int8_t"); const element::Type element::i8(8, false, true, "int8_t");
...@@ -39,7 +39,6 @@ element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::st ...@@ -39,7 +39,6 @@ element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::st
, m_is_signed{is_signed} , m_is_signed{is_signed}
, m_cname{cname} , m_cname{cname}
{ {
assert(m_bitwidth % 8 == 0);
} }
const std::string& element::Type::c_type_string() const const std::string& element::Type::c_type_string() const
...@@ -53,13 +52,35 @@ bool element::Type::operator==(const element::Type& other) const ...@@ -53,13 +52,35 @@ bool element::Type::operator==(const element::Type& other) const
m_is_signed == other.m_is_signed && m_cname == other.m_cname; m_is_signed == other.m_is_signed && m_cname == other.m_cname;
} }
bool element::Type::operator<(const Type& other) const
{
size_t v1 = m_bitwidth << 2;
v1 |= (m_is_real ? 2 : 0);
v1 |= (m_is_signed ? 1 : 0);
size_t v2 = other.m_bitwidth << 2;
v2 |= (other.m_is_real ? 2 : 0);
v2 |= (other.m_is_signed ? 1 : 0);
return v1 < v2;
}
size_t element::Type::size() const size_t element::Type::size() const
{ {
return std::ceil(static_cast<float>(m_bitwidth) / 8.0f); return std::ceil(static_cast<float>(m_bitwidth) / 8.0f);
} }
size_t element::Type::hash() const
{
size_t h1 = std::hash<size_t>{}(m_bitwidth);
size_t h2 = std::hash<bool>{}(m_is_real);
size_t h3 = std::hash<bool>{}(m_is_signed);
return h1 ^ ((h2 ^ (h3 << 1)) << 1);
}
std::ostream& element::operator<<(std::ostream& out, const element::Type& obj) std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
{ {
out << obj.m_cname; out << "element::Type(" << obj.m_bitwidth << ", " << obj.m_is_real << ", " << obj.m_is_signed
<< ")";
return out; return out;
} }
...@@ -47,23 +47,20 @@ namespace ngraph ...@@ -47,23 +47,20 @@ namespace ngraph
class Type class Type
{ {
Type(const Type&) = delete;
Type& operator=(const Type&) = delete;
public: public:
virtual ~Type() {} Type() = delete;
Type(const Type&) = default;
Type(size_t bitwidth, bool is_real, bool is_signed, const std::string& cname); Type(size_t bitwidth, bool is_real, bool is_signed, const std::string& cname);
virtual ~Type() {}
const std::string& c_type_string() const; const std::string& c_type_string() const;
size_t size() const; size_t size() const;
size_t hash() const size_t hash() const;
{ bool is_real() const { return m_is_real; }
std::hash<std::string> h; bool is_signed() const { return m_is_signed; }
return h(m_cname); size_t bitwidth() const { return m_bitwidth; }
}
bool operator==(const Type& other) const; bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); } bool operator!=(const Type& other) const { return !(*this == other); }
bool operator<(const Type& other) const;
friend std::ostream& operator<<(std::ostream&, const Type&); friend std::ostream& operator<<(std::ostream&, const Type&);
private: private:
......
...@@ -145,7 +145,6 @@ void ngraph::traverse_nodes(std::shared_ptr<ngraph::Function> p, ...@@ -145,7 +145,6 @@ void ngraph::traverse_nodes(std::shared_ptr<ngraph::Function> p,
} }
void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<Node>)> f) void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<Node>)> f)
{ {
std::unordered_set<shared_ptr<Node>> instances_seen; std::unordered_set<shared_ptr<Node>> instances_seen;
deque<shared_ptr<Node>> stack; deque<shared_ptr<Node>> stack;
...@@ -172,6 +171,34 @@ void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<N ...@@ -172,6 +171,34 @@ void ngraph::traverse_nodes(ngraph::Function* p, std::function<void(shared_ptr<N
} }
} }
void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
std::function<void(shared_ptr<Function>)> f)
{
std::unordered_set<shared_ptr<Function>> instances_seen;
deque<shared_ptr<Function>> stack;
stack.push_front(p);
while (stack.size() > 0)
{
shared_ptr<Function> func = stack.front();
if (instances_seen.find(func) == instances_seen.end())
{
instances_seen.insert(func);
f(func);
}
stack.pop_front();
for (shared_ptr<Node> op : func->get_ops())
{
shared_ptr<Function> fp = op->get_function();
if (fp)
{
stack.push_front(fp);
}
}
}
}
void ngraph::free_nodes(shared_ptr<Function> p) void ngraph::free_nodes(shared_ptr<Function> p)
{ {
std::deque<Node*> sorted_list; std::deque<Node*> sorted_list;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <chrono> #include <chrono>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <list>
#include <map> #include <map>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
...@@ -239,8 +240,9 @@ namespace ngraph ...@@ -239,8 +240,9 @@ namespace ngraph
} }
void traverse_nodes(Function* p, std::function<void(std::shared_ptr<Node>)> f); void traverse_nodes(Function* p, std::function<void(std::shared_ptr<Node>)> f);
void traverse_nodes(std::shared_ptr<Function> p, std::function<void(std::shared_ptr<Node>)> f); void traverse_nodes(std::shared_ptr<Function> p, std::function<void(std::shared_ptr<Node>)> f);
void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f);
void free_nodes(std::shared_ptr<Function>); void free_nodes(std::shared_ptr<Function>);
} // end namespace ngraph } // end namespace ngraph
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#include <algorithm> #include <algorithm>
#include <cinttypes> #include <cinttypes>
#include <cmath>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include <cmath>
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
......
...@@ -358,7 +358,7 @@ TEST(copy, reduce) ...@@ -358,7 +358,7 @@ TEST(copy, reduce)
ASSERT_TRUE(nullptr != new_node); ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_arguments()); ASSERT_TRUE(new_args == new_node->get_arguments());
ASSERT_TRUE(f == node_cast->get_reduction_function()); ASSERT_TRUE(f == node_cast->get_function());
ASSERT_TRUE(axes == node_cast->get_reduction_axes()); ASSERT_TRUE(axes == node_cast->get_reduction_axes());
} }
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <map>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
...@@ -33,3 +35,50 @@ TEST(element_type, from) ...@@ -33,3 +35,50 @@ TEST(element_type, from)
EXPECT_EQ(element::from<uint32_t>(), element::u32); EXPECT_EQ(element::from<uint32_t>(), element::u32);
EXPECT_EQ(element::from<uint64_t>(), element::u64); EXPECT_EQ(element::from<uint64_t>(), element::u64);
} }
TEST(element_type, mapable)
{
std::map<element::Type, std::string> test_map;
test_map.insert({element::f32, "float"});
}
TEST(element_type, size)
{
{
element::Type t1{1, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{2, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{3, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{4, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{5, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{6, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{7, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{2, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{9, false, false, ""};
EXPECT_EQ(2, t1.size());
}
}
...@@ -12,4 +12,88 @@ ...@@ -12,4 +12,88 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <fstream>
#include <sstream>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/json.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
template <typename T>
static void copy_data(shared_ptr<runtime::TensorView> tv, const vector<T>& data)
{
size_t data_size = data.size() * sizeof(T);
tv->write(data.data(), 0, data_size);
}
TEST(serialize, main)
{
// First create "f(A,B,C) = (A+B)*C".
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_f = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>((A + B) * C, rt_f, op::Parameters{A, B, C}, "f");
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto X = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_g = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto g = make_shared<Function>(make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}) +
make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}),
rt_g,
op::Parameters{X, Y, Z},
"g");
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_h = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto h = make_shared<Function>(make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}) +
make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}),
rt_h,
op::Parameters{X1, Y1, Z1},
"h");
string js = serialize(h);
{
ofstream f("serialize_function.js");
f << js;
}
istringstream in(js);
shared_ptr<Function> sfunc = deserialize(in);
// Now call g on some test vectors.
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(sfunc);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
auto x = backend->make_primary_tensor_view(element::Float32::element_type(), shape);
copy_data(x, vector<float>{1, 2, 3, 4});
auto y = backend->make_primary_tensor_view(element::Float32::element_type(), shape);
copy_data(y, vector<float>{5, 6, 7, 8});
auto z = backend->make_primary_tensor_view(element::Float32::element_type(), shape);
copy_data(z, vector<float>{9, 10, 11, 12});
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape);
cf->call({x, y, z}, {result});
EXPECT_EQ((vector<float>{54, 80, 110, 144}), result->get_vector<float>());
cf->call({y, x, z}, {result});
EXPECT_EQ((vector<float>{54, 80, 110, 144}), result->get_vector<float>());
cf->call({x, z, y}, {result});
EXPECT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector<float>());
}
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "util/all_close.hpp" #include "util/all_close.hpp"
...@@ -202,3 +203,40 @@ TEST(util, all_close) ...@@ -202,3 +203,40 @@ TEST(util, all_close)
EXPECT_FALSE(ngraph::test::all_close<float>(c, a, .05f, 0)); EXPECT_FALSE(ngraph::test::all_close<float>(c, a, .05f, 0));
EXPECT_TRUE(ngraph::test::all_close<float>(c, a, .11f, 0)); EXPECT_TRUE(ngraph::test::all_close<float>(c, a, .11f, 0));
} }
TEST(util, traverse_functions)
{
// First create "f(A,B,C) = (A+B)*C".
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto C = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_f = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto f = make_shared<Function>((A + B) * C, rt_f, op::Parameters{A, B, C}, "f");
// Now make "g(X,Y,Z) = f(X,Y,Z) + f(X,Y,Z)"
auto X = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_g = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto g = make_shared<Function>(make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}) +
make_shared<op::FunctionCall>(f, Nodes{X, Y, Z}),
rt_g,
op::Parameters{X, Y, Z},
"g");
// Now make "h(X,Y,Z) = g(X,Y,Z) + g(X,Y,Z)"
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Y1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto Z1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt_h = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto h = make_shared<Function>(make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}) +
make_shared<op::FunctionCall>(g, Nodes{X1, Y1, Z1}),
rt_h,
op::Parameters{X1, Y1, Z1},
"h");
vector<Function*> functions;
traverse_functions(h, [&](shared_ptr<Function> fp) { functions.push_back(fp.get()); });
ASSERT_EQ(3, functions.size());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment