Commit fc6cd9ac authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

Merge branch 'master' into maxpool

parents e2d175db 54c0a66b
......@@ -19,6 +19,7 @@
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/asin.hpp"
#include "ngraph/ops/atan.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
......@@ -47,12 +48,16 @@
#include "ngraph/ops/not.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/pad.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reduce_window.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/select_and_scatter.hpp"
#include "ngraph/ops/sign.hpp"
#include "ngraph/ops/sin.hpp"
#include "ngraph/ops/sinh.hpp"
......@@ -301,19 +306,12 @@ static shared_ptr<ngraph::Function>
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
shared_ptr<Node> node;
shared_ptr<Function> function_ptr = nullptr;
vector<shared_ptr<Node>> args;
for (const string& name : node_inputs)
{
args.push_back(node_map.at(name));
}
vector<string> known_nodes;
for (auto x : node_map)
{
known_nodes.push_back(x.first);
}
if (node_op == "Abs")
{
node = make_shared<op::Abs>(args[0]);
......@@ -334,6 +332,16 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::Atan>(args[0]);
}
else if (node_op == "AvgPool")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
node = make_shared<op::AvgPool>(
args[0], window_shape, window_movement_strides, padding_below, padding_above);
}
else if (node_op == "Broadcast")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
......@@ -371,12 +379,15 @@ static shared_ptr<ngraph::Function>
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto image_dilation_strides =
node_js.at("image_dilation_strides").get<vector<size_t>>();
node = make_shared<op::Convolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above);
padding_above,
image_dilation_strides);
}
else if (node_op == "Cos")
{
......@@ -391,9 +402,19 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Divide>(args[0], args[1]);
}
else if (node_op == "Dot")
{
// For backwards compatibility, reduction_axes_count is optional.
auto obj = node_js["reduction_axes_count"];
if (obj.empty())
{
node = make_shared<op::Dot>(args[0], args[1]);
}
else
{
size_t reduction_axes_count = obj.get<size_t>();
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
}
}
else if (node_op == "Equal")
{
node = make_shared<op::Equal>(args[0], args[1]);
......@@ -473,6 +494,14 @@ static shared_ptr<ngraph::Function>
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
}
else if (node_op == "Pad")
{
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto padding_interior = node_js.at("padding_interior").get<vector<size_t>>();
node = make_shared<op::Pad>(
args[0], args[1], padding_below, padding_above, padding_interior);
}
else if (node_op == "Parameter")
{
auto type_node_js =
......@@ -488,7 +517,19 @@ static shared_ptr<ngraph::Function>
else if (node_op == "Reduce")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Reduce>(args[0], args[1], function_ptr, reduction_axes);
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::Reduce>(args[0], args[1], f_ptr, reduction_axes);
}
else if (node_op == "ReduceWindow")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::ReduceWindow>(
args[0], args[1], f_ptr, window_shape, window_movement_strides);
}
else if (node_op == "Remainder")
{
......@@ -508,10 +549,34 @@ static shared_ptr<ngraph::Function>
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::Reshape>(args[0], input_order, output_shape);
}
else if (node_op == "Reverse")
{
auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
node = make_shared<op::Reverse>(args[0], reversed_axes);
}
else if (node_op == "Select")
{
node = make_shared<op::Select>(args[0], args[1], args[2]);
}
else if (node_op == "SelectAndScatter")
{
string selection_function_name = node_js.at("selection_function").get<string>();
shared_ptr<Function> selection_f_ptr = function_map.at(selection_function_name);
string scatter_function_name = node_js.at("scatter_function").get<string>();
shared_ptr<Function> scatter_f_ptr = function_map.at(scatter_function_name);
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
node = make_shared<op::SelectAndScatter>(args[0],
args[1],
args[2],
selection_f_ptr,
scatter_f_ptr,
window_shape,
window_movement_strides);
}
else if (node_op == "Sign")
{
node = make_shared<op::Sign>(args[0]);
......@@ -612,6 +677,14 @@ static json write(const Node& n)
else if (node_op == "Atan")
{
}
else if (node_op == "AvgPool")
{
auto tmp = dynamic_cast<const op::AvgPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
}
else if (node_op == "Broadcast")
{
auto tmp = dynamic_cast<const op::Broadcast*>(&n);
......@@ -645,6 +718,7 @@ static json write(const Node& n)
node["window_dilation_strides"] = tmp->get_window_dilation_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["image_dilation_strides"] = tmp->get_image_dilation_strides();
}
else if (node_op == "Cos")
{
......@@ -657,6 +731,8 @@ static json write(const Node& n)
}
else if (node_op == "Dot")
{
auto tmp = dynamic_cast<const op::Dot*>(&n);
node["reduction_axes_count"] = tmp->get_reduction_axes_count();
}
else if (node_op == "Equal")
{
......@@ -719,6 +795,13 @@ static json write(const Node& n)
node["shape"] = tmp->get_shape();
node["one_hot_axis"] = tmp->get_one_hot_axis();
}
else if (node_op == "Pad")
{
auto tmp = dynamic_cast<const op::Pad*>(&n);
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["padding_interior"] = tmp->get_padding_interior();
}
else if (node_op == "Parameter")
{
auto tmp = dynamic_cast<const op::Parameter*>(&n);
......@@ -734,6 +817,13 @@ static json write(const Node& n)
node["function"] = tmp->get_functions()[0]->get_name();
node["reduction_axes"] = tmp->get_reduction_axes();
}
else if (node_op == "ReduceWindow")
{
auto tmp = dynamic_cast<const op::ReduceWindow*>(&n);
node["function"] = tmp->get_functions()[0]->get_name();
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
}
else if (node_op == "Remainder")
{
}
......@@ -750,9 +840,22 @@ static json write(const Node& n)
node["input_order"] = tmp->get_input_order();
node["output_shape"] = tmp->get_output_shape();
}
else if (node_op == "Reverse")
{
auto tmp = dynamic_cast<const op::Reverse*>(&n);
node["reversed_axes"] = tmp->get_reversed_axes();
}
else if (node_op == "Select")
{
}
else if (node_op == "SelectAndScatter")
{
auto tmp = dynamic_cast<const op::SelectAndScatter*>(&n);
node["selection_function"] = tmp->get_functions()[0]->get_name();
node["scatter_function"] = tmp->get_functions()[1]->get_name();
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
}
else if (node_op == "Sign")
{
}
......
......@@ -99,7 +99,22 @@ TEST(serialize, existing_models)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, model);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> f = ngraph::deserialize(ss);
shared_ptr<Function> f = ngraph::deserialize(json_string);
}
}
TEST(benchmark, serialize)
{
stopwatch timer;
string model = "mxnet/LSTM_backward.json";
const string json_path = file_util::path_join(SERIALIZED_ZOO, model);
timer.start();
const string json_string = file_util::read_file_to_string(json_path);
timer.stop();
cout << "file read took " << timer.get_milliseconds() << "ms\n";
timer.start();
shared_ptr<Function> f = ngraph::deserialize(json_string);
timer.stop();
cout << "deserialize took " << timer.get_milliseconds() << "ms\n";
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment