Commit b2fdb1f8 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

Merge branch 'master' into cpu_layout2

parents 4356b2cd 54c0a66b
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
#include "ngraph/ops/asin.hpp" #include "ngraph/ops/asin.hpp"
#include "ngraph/ops/atan.hpp" #include "ngraph/ops/atan.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/broadcast.hpp" #include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp" #include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp" #include "ngraph/ops/concatenate.hpp"
...@@ -47,12 +48,16 @@ ...@@ -47,12 +48,16 @@
#include "ngraph/ops/not.hpp" #include "ngraph/ops/not.hpp"
#include "ngraph/ops/not_equal.hpp" #include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/one_hot.hpp" #include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/pad.hpp"
#include "ngraph/ops/power.hpp" #include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reduce_window.hpp"
#include "ngraph/ops/remainder.hpp" #include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/replace_slice.hpp" #include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp"
#include "ngraph/ops/select.hpp" #include "ngraph/ops/select.hpp"
#include "ngraph/ops/select_and_scatter.hpp"
#include "ngraph/ops/sign.hpp" #include "ngraph/ops/sign.hpp"
#include "ngraph/ops/sin.hpp" #include "ngraph/ops/sin.hpp"
#include "ngraph/ops/sinh.hpp" #include "ngraph/ops/sinh.hpp"
...@@ -301,19 +306,12 @@ static shared_ptr<ngraph::Function> ...@@ -301,19 +306,12 @@ static shared_ptr<ngraph::Function>
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>(); vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>(); vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
shared_ptr<Node> node; shared_ptr<Node> node;
shared_ptr<Function> function_ptr = nullptr;
vector<shared_ptr<Node>> args; vector<shared_ptr<Node>> args;
for (const string& name : node_inputs) for (const string& name : node_inputs)
{ {
args.push_back(node_map.at(name)); args.push_back(node_map.at(name));
} }
vector<string> known_nodes;
for (auto x : node_map)
{
known_nodes.push_back(x.first);
}
if (node_op == "Abs") if (node_op == "Abs")
{ {
node = make_shared<op::Abs>(args[0]); node = make_shared<op::Abs>(args[0]);
...@@ -334,6 +332,16 @@ static shared_ptr<ngraph::Function> ...@@ -334,6 +332,16 @@ static shared_ptr<ngraph::Function>
{ {
node = make_shared<op::Atan>(args[0]); node = make_shared<op::Atan>(args[0]);
} }
else if (node_op == "AvgPool")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
node = make_shared<op::AvgPool>(
args[0], window_shape, window_movement_strides, padding_below, padding_above);
}
else if (node_op == "Broadcast") else if (node_op == "Broadcast")
{ {
auto shape = node_js.at("shape").get<vector<size_t>>(); auto shape = node_js.at("shape").get<vector<size_t>>();
...@@ -371,12 +379,15 @@ static shared_ptr<ngraph::Function> ...@@ -371,12 +379,15 @@ static shared_ptr<ngraph::Function>
node_js.at("window_dilation_strides").get<vector<size_t>>(); node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>(); auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>(); auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto image_dilation_strides =
node_js.at("image_dilation_strides").get<vector<size_t>>();
node = make_shared<op::Convolution>(args[0], node = make_shared<op::Convolution>(args[0],
args[1], args[1],
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
padding_above); padding_above,
image_dilation_strides);
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
...@@ -391,9 +402,19 @@ static shared_ptr<ngraph::Function> ...@@ -391,9 +402,19 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Divide>(args[0], args[1]); node = make_shared<op::Divide>(args[0], args[1]);
} }
else if (node_op == "Dot") else if (node_op == "Dot")
{
// For backwards compatibility, reduction_axes_count is optional.
auto obj = node_js["reduction_axes_count"];
if (obj.empty())
{ {
node = make_shared<op::Dot>(args[0], args[1]); node = make_shared<op::Dot>(args[0], args[1]);
} }
else
{
size_t reduction_axes_count = obj.get<size_t>();
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
}
}
else if (node_op == "Equal") else if (node_op == "Equal")
{ {
node = make_shared<op::Equal>(args[0], args[1]); node = make_shared<op::Equal>(args[0], args[1]);
...@@ -473,6 +494,14 @@ static shared_ptr<ngraph::Function> ...@@ -473,6 +494,14 @@ static shared_ptr<ngraph::Function>
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>(); auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], shape, one_hot_axis); node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
} }
else if (node_op == "Pad")
{
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto padding_interior = node_js.at("padding_interior").get<vector<size_t>>();
node = make_shared<op::Pad>(
args[0], args[1], padding_below, padding_above, padding_interior);
}
else if (node_op == "Parameter") else if (node_op == "Parameter")
{ {
auto type_node_js = auto type_node_js =
...@@ -488,7 +517,19 @@ static shared_ptr<ngraph::Function> ...@@ -488,7 +517,19 @@ static shared_ptr<ngraph::Function>
else if (node_op == "Reduce") else if (node_op == "Reduce")
{ {
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>(); auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Reduce>(args[0], args[1], function_ptr, reduction_axes); string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::Reduce>(args[0], args[1], f_ptr, reduction_axes);
}
else if (node_op == "ReduceWindow")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::ReduceWindow>(
args[0], args[1], f_ptr, window_shape, window_movement_strides);
} }
else if (node_op == "Remainder") else if (node_op == "Remainder")
{ {
...@@ -508,10 +549,34 @@ static shared_ptr<ngraph::Function> ...@@ -508,10 +549,34 @@ static shared_ptr<ngraph::Function>
auto output_shape = node_js.at("output_shape").get<vector<size_t>>(); auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::Reshape>(args[0], input_order, output_shape); node = make_shared<op::Reshape>(args[0], input_order, output_shape);
} }
else if (node_op == "Reverse")
{
auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
node = make_shared<op::Reverse>(args[0], reversed_axes);
}
else if (node_op == "Select") else if (node_op == "Select")
{ {
node = make_shared<op::Select>(args[0], args[1], args[2]); node = make_shared<op::Select>(args[0], args[1], args[2]);
} }
else if (node_op == "SelectAndScatter")
{
string selection_function_name = node_js.at("selection_function").get<string>();
shared_ptr<Function> selection_f_ptr = function_map.at(selection_function_name);
string scatter_function_name = node_js.at("scatter_function").get<string>();
shared_ptr<Function> scatter_f_ptr = function_map.at(scatter_function_name);
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
node = make_shared<op::SelectAndScatter>(args[0],
args[1],
args[2],
selection_f_ptr,
scatter_f_ptr,
window_shape,
window_movement_strides);
}
else if (node_op == "Sign") else if (node_op == "Sign")
{ {
node = make_shared<op::Sign>(args[0]); node = make_shared<op::Sign>(args[0]);
...@@ -612,6 +677,14 @@ static json write(const Node& n) ...@@ -612,6 +677,14 @@ static json write(const Node& n)
else if (node_op == "Atan") else if (node_op == "Atan")
{ {
} }
else if (node_op == "AvgPool")
{
auto tmp = dynamic_cast<const op::AvgPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
}
else if (node_op == "Broadcast") else if (node_op == "Broadcast")
{ {
auto tmp = dynamic_cast<const op::Broadcast*>(&n); auto tmp = dynamic_cast<const op::Broadcast*>(&n);
...@@ -645,6 +718,7 @@ static json write(const Node& n) ...@@ -645,6 +718,7 @@ static json write(const Node& n)
node["window_dilation_strides"] = tmp->get_window_dilation_strides(); node["window_dilation_strides"] = tmp->get_window_dilation_strides();
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["image_dilation_strides"] = tmp->get_image_dilation_strides();
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
...@@ -657,6 +731,8 @@ static json write(const Node& n) ...@@ -657,6 +731,8 @@ static json write(const Node& n)
} }
else if (node_op == "Dot") else if (node_op == "Dot")
{ {
auto tmp = dynamic_cast<const op::Dot*>(&n);
node["reduction_axes_count"] = tmp->get_reduction_axes_count();
} }
else if (node_op == "Equal") else if (node_op == "Equal")
{ {
...@@ -719,6 +795,13 @@ static json write(const Node& n) ...@@ -719,6 +795,13 @@ static json write(const Node& n)
node["shape"] = tmp->get_shape(); node["shape"] = tmp->get_shape();
node["one_hot_axis"] = tmp->get_one_hot_axis(); node["one_hot_axis"] = tmp->get_one_hot_axis();
} }
else if (node_op == "Pad")
{
auto tmp = dynamic_cast<const op::Pad*>(&n);
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["padding_interior"] = tmp->get_padding_interior();
}
else if (node_op == "Parameter") else if (node_op == "Parameter")
{ {
auto tmp = dynamic_cast<const op::Parameter*>(&n); auto tmp = dynamic_cast<const op::Parameter*>(&n);
...@@ -734,6 +817,13 @@ static json write(const Node& n) ...@@ -734,6 +817,13 @@ static json write(const Node& n)
node["function"] = tmp->get_functions()[0]->get_name(); node["function"] = tmp->get_functions()[0]->get_name();
node["reduction_axes"] = tmp->get_reduction_axes(); node["reduction_axes"] = tmp->get_reduction_axes();
} }
else if (node_op == "ReduceWindow")
{
auto tmp = dynamic_cast<const op::ReduceWindow*>(&n);
node["function"] = tmp->get_functions()[0]->get_name();
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
}
else if (node_op == "Remainder") else if (node_op == "Remainder")
{ {
} }
...@@ -750,9 +840,22 @@ static json write(const Node& n) ...@@ -750,9 +840,22 @@ static json write(const Node& n)
node["input_order"] = tmp->get_input_order(); node["input_order"] = tmp->get_input_order();
node["output_shape"] = tmp->get_output_shape(); node["output_shape"] = tmp->get_output_shape();
} }
else if (node_op == "Reverse")
{
auto tmp = dynamic_cast<const op::Reverse*>(&n);
node["reversed_axes"] = tmp->get_reversed_axes();
}
else if (node_op == "Select") else if (node_op == "Select")
{ {
} }
else if (node_op == "SelectAndScatter")
{
auto tmp = dynamic_cast<const op::SelectAndScatter*>(&n);
node["selection_function"] = tmp->get_functions()[0]->get_name();
node["scatter_function"] = tmp->get_functions()[1]->get_name();
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
}
else if (node_op == "Sign") else if (node_op == "Sign")
{ {
} }
......
...@@ -99,7 +99,22 @@ TEST(serialize, existing_models) ...@@ -99,7 +99,22 @@ TEST(serialize, existing_models)
{ {
const string json_path = file_util::path_join(SERIALIZED_ZOO, model); const string json_path = file_util::path_join(SERIALIZED_ZOO, model);
const string json_string = file_util::read_file_to_string(json_path); const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string); shared_ptr<Function> f = ngraph::deserialize(json_string);
shared_ptr<Function> f = ngraph::deserialize(ss);
} }
} }
TEST(benchmark, serialize)
{
    // Benchmark the two phases of loading a serialized model separately:
    // (1) reading the JSON text from disk, (2) deserializing it into a Function.
    const string model_name = "mxnet/LSTM_backward.json";
    const string json_path = file_util::path_join(SERIALIZED_ZOO, model_name);

    stopwatch t;

    t.start();
    const string json_string = file_util::read_file_to_string(json_path);
    t.stop();
    cout << "file read took " << t.get_milliseconds() << "ms\n";

    t.start();
    // NOTE(review): result is intentionally kept alive so construction cost is
    // included in the timed region.
    shared_ptr<Function> func = ngraph::deserialize(json_string);
    t.stop();
    cout << "deserialize took " << t.get_milliseconds() << "ms\n";
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment