Unverified Commit a1269685 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

add exception handing in deserializer to tell which op actually caused the exception (#756)

parent 71c2055c
......@@ -339,523 +339,542 @@ static shared_ptr<ngraph::Function>
unordered_map<string, shared_ptr<Node>> node_map;
for (json node_js : func_js.at("ops"))
{
string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>();
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
shared_ptr<Node> node;
vector<shared_ptr<Node>> args;
for (const string& name : node_inputs)
{
args.push_back(node_map.at(name));
}
try
{
string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>();
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
shared_ptr<Node> node;
vector<shared_ptr<Node>> args;
for (const string& name : node_inputs)
{
args.push_back(node_map.at(name));
}
if (node_op == "Abs")
{
node = make_shared<op::Abs>(args[0]);
}
else if (node_op == "Acos")
{
node = make_shared<op::Acos>(args[0]);
}
else if (node_op == "Add")
{
node = make_shared<op::Add>(args[0], args[1]);
}
else if (node_op == "AllReduce")
{
node = make_shared<op::AllReduce>(args[0]);
}
else if (node_op == "Asin")
{
node = make_shared<op::Asin>(args[0]);
}
else if (node_op == "Atan")
{
node = make_shared<op::Atan>(args[0]);
}
else if (node_op == "AvgPool")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation =
node_js.at("include_padding_in_avg_computation").get<bool>();
node = make_shared<op::AvgPool>(args[0],
window_shape,
window_movement_strides,
padding_below,
padding_above,
include_padding_in_avg_computation);
}
else if (node_op == "AvgPoolBackprop")
{
auto forward_arg_shape = node_js.at("forward_arg_shape").get<vector<size_t>>();
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation =
get_or_default<bool>(node_js, "include_padding_in_avg_computation", false);
node = make_shared<op::AvgPoolBackprop>(forward_arg_shape,
args[0],
window_shape,
window_movement_strides,
padding_below,
padding_above,
include_padding_in_avg_computation);
}
else if (node_op == "BatchNorm")
{
auto epsilon = node_js.at("eps").get<double>();
bool training = get_or_default<bool>(node_js, "training", true);
if (training)
if (node_op == "Abs")
{
node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2]);
node = make_shared<op::Abs>(args[0]);
}
else
else if (node_op == "Acos")
{
node = make_shared<op::BatchNorm>(
epsilon, args[0], args[1], args[2], args[3], args[4]);
node = make_shared<op::Acos>(args[0]);
}
}
else if (node_op == "BatchNormBackprop")
{
auto epsilon = node_js.at("eps").get<double>();
node = make_shared<op::BatchNormBackprop>(
epsilon, args[0], args[1], args[2], args[3], args[4], args[5]);
}
else if (node_op == "Broadcast")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = node_js.at("axes").get<set<size_t>>();
node = make_shared<op::Broadcast>(args[0], shape, axes);
}
else if (node_op == "Ceiling")
{
node = make_shared<op::Ceiling>(args[0]);
}
else if (node_op == "Concat")
{
auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Concat>(args, axis);
}
else if (node_op == "Constant")
{
auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
try
else if (node_op == "Add")
{
auto value = node_js.at("value").get<vector<string>>();
node = make_shared<op::Constant>(element_type, shape, value);
node = make_shared<op::Add>(args[0], args[1]);
}
catch (...)
else if (node_op == "AllReduce")
{
node = const_data_callback(node_name, element_type, shape);
node = make_shared<op::AllReduce>(args[0]);
}
}
else if (node_op == "Convert")
{
auto target_type = read_element_type(node_js.at("target_type"));
node = make_shared<op::Convert>(args[0], target_type);
}
else if (node_op == "Convolution")
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
// For backwards compatibility, we accept "image_dilation_strides" in place of
// "data_dilation_strides", and we also allow it to be omitted altogether.
auto data_dilation_strides_maybe = node_js["data_dilation_strides"];
if (data_dilation_strides_maybe.empty())
else if (node_op == "Asin")
{
node = make_shared<op::Asin>(args[0]);
}
else if (node_op == "Atan")
{
node = make_shared<op::Atan>(args[0]);
}
else if (node_op == "AvgPool")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation =
node_js.at("include_padding_in_avg_computation").get<bool>();
node = make_shared<op::AvgPool>(args[0],
window_shape,
window_movement_strides,
padding_below,
padding_above,
include_padding_in_avg_computation);
}
else if (node_op == "AvgPoolBackprop")
{
auto forward_arg_shape = node_js.at("forward_arg_shape").get<vector<size_t>>();
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation =
get_or_default<bool>(node_js, "include_padding_in_avg_computation", false);
node = make_shared<op::AvgPoolBackprop>(forward_arg_shape,
args[0],
window_shape,
window_movement_strides,
padding_below,
padding_above,
include_padding_in_avg_computation);
}
else if (node_op == "BatchNorm")
{
auto epsilon = node_js.at("eps").get<double>();
bool training = get_or_default<bool>(node_js, "training", true);
if (training)
{
node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2]);
}
else
{
node = make_shared<op::BatchNorm>(
epsilon, args[0], args[1], args[2], args[3], args[4]);
}
}
else if (node_op == "BatchNormBackprop")
{
auto epsilon = node_js.at("eps").get<double>();
node = make_shared<op::BatchNormBackprop>(
epsilon, args[0], args[1], args[2], args[3], args[4], args[5]);
}
else if (node_op == "Broadcast")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = node_js.at("axes").get<set<size_t>>();
node = make_shared<op::Broadcast>(args[0], shape, axes);
}
else if (node_op == "Ceiling")
{
node = make_shared<op::Ceiling>(args[0]);
}
else if (node_op == "Concat")
{
data_dilation_strides_maybe = node_js["image_dilation_strides"];
auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Concat>(args, axis);
}
else if (node_op == "Constant")
{
auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
try
{
auto value = node_js.at("value").get<vector<string>>();
node = make_shared<op::Constant>(element_type, shape, value);
}
catch (...)
{
node = const_data_callback(node_name, element_type, shape);
}
}
else if (node_op == "Convert")
{
auto target_type = read_element_type(node_js.at("target_type"));
node = make_shared<op::Convert>(args[0], target_type);
}
else if (node_op == "Convolution")
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
if (data_dilation_strides_maybe.empty())
// For backwards compatibility, we accept "image_dilation_strides" in place of
// "data_dilation_strides", and we also allow it to be omitted altogether.
auto data_dilation_strides_maybe = node_js["data_dilation_strides"];
if (data_dilation_strides_maybe.empty())
{
data_dilation_strides_maybe = node_js["image_dilation_strides"];
}
if (data_dilation_strides_maybe.empty())
{
node = make_shared<op::Convolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above);
}
else
{
node = make_shared<op::Convolution>(
args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides_maybe.get<std::vector<size_t>>());
}
}
else if (node_op == "ConvolutionBackpropData")
{
node = make_shared<op::Convolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above);
auto data_batch_shape = node_js.at("data_batch_shape").get<vector<size_t>>();
auto window_movement_strides_forward =
node_js.at("window_movement_strides_forward").get<vector<size_t>>();
auto window_dilation_strides_forward =
node_js.at("window_dilation_strides_forward").get<vector<size_t>>();
auto padding_below_forward =
node_js.at("padding_below_forward").get<vector<std::ptrdiff_t>>();
auto padding_above_forward =
node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides_forward =
node_js.at("data_dilation_strides_forward").get<vector<size_t>>();
node = make_shared<op::ConvolutionBackpropData>(data_batch_shape,
args[0],
args[1],
window_movement_strides_forward,
window_dilation_strides_forward,
padding_below_forward,
padding_above_forward,
data_dilation_strides_forward);
}
else
else if (node_op == "ConvolutionBackpropFilters")
{
node = make_shared<op::Convolution>(
args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides_maybe.get<std::vector<size_t>>());
auto filters_shape = node_js.at("filters_shape").get<vector<size_t>>();
auto window_movement_strides_forward =
node_js.at("window_movement_strides_forward").get<vector<size_t>>();
auto window_dilation_strides_forward =
node_js.at("window_dilation_strides_forward").get<vector<size_t>>();
auto padding_below_forward =
node_js.at("padding_below_forward").get<vector<std::ptrdiff_t>>();
auto padding_above_forward =
node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides_forward =
node_js.at("data_dilation_strides_forward").get<vector<size_t>>();
node = make_shared<op::ConvolutionBackpropFilters>(args[0],
filters_shape,
args[1],
window_movement_strides_forward,
window_dilation_strides_forward,
padding_below_forward,
padding_above_forward,
data_dilation_strides_forward);
}
}
else if (node_op == "ConvolutionBackpropData")
{
auto data_batch_shape = node_js.at("data_batch_shape").get<vector<size_t>>();
auto window_movement_strides_forward =
node_js.at("window_movement_strides_forward").get<vector<size_t>>();
auto window_dilation_strides_forward =
node_js.at("window_dilation_strides_forward").get<vector<size_t>>();
auto padding_below_forward =
node_js.at("padding_below_forward").get<vector<std::ptrdiff_t>>();
auto padding_above_forward =
node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides_forward =
node_js.at("data_dilation_strides_forward").get<vector<size_t>>();
node = make_shared<op::ConvolutionBackpropData>(data_batch_shape,
args[0],
args[1],
window_movement_strides_forward,
window_dilation_strides_forward,
padding_below_forward,
padding_above_forward,
data_dilation_strides_forward);
}
else if (node_op == "ConvolutionBackpropFilters")
{
auto filters_shape = node_js.at("filters_shape").get<vector<size_t>>();
auto window_movement_strides_forward =
node_js.at("window_movement_strides_forward").get<vector<size_t>>();
auto window_dilation_strides_forward =
node_js.at("window_dilation_strides_forward").get<vector<size_t>>();
auto padding_below_forward =
node_js.at("padding_below_forward").get<vector<std::ptrdiff_t>>();
auto padding_above_forward =
node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides_forward =
node_js.at("data_dilation_strides_forward").get<vector<size_t>>();
node = make_shared<op::ConvolutionBackpropFilters>(args[0],
filters_shape,
args[1],
window_movement_strides_forward,
window_dilation_strides_forward,
padding_below_forward,
padding_above_forward,
data_dilation_strides_forward);
}
else if (node_op == "Cos")
{
node = make_shared<op::Cos>(args[0]);
}
else if (node_op == "Cosh")
{
node = make_shared<op::Cosh>(args[0]);
}
else if (node_op == "Divide")
{
node = make_shared<op::Divide>(args[0], args[1]);
}
else if (node_op == "Dot")
{
// For backwards compatibility, reduction_axes_count is optional.
auto obj = node_js["reduction_axes_count"];
if (obj.empty())
else if (node_op == "Cos")
{
node = make_shared<op::Dot>(args[0], args[1]);
node = make_shared<op::Cos>(args[0]);
}
else
else if (node_op == "Cosh")
{
size_t reduction_axes_count = obj.get<size_t>();
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
node = make_shared<op::Cosh>(args[0]);
}
}
else if (node_op == "Equal")
{
node = make_shared<op::Equal>(args[0], args[1]);
}
else if (node_op == "Exp")
{
node = make_shared<op::Exp>(args[0]);
}
else if (node_op == "Floor")
{
node = make_shared<op::Floor>(args[0]);
}
else if (node_op == "FunctionCall")
{
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::FunctionCall>(f_ptr, args);
}
else if (node_op == "GetOutputElement")
{
node = make_shared<op::GetOutputElement>(args[0], node_js.at("n").get<size_t>());
}
else if (node_op == "Greater")
{
node = make_shared<op::Greater>(args[0], args[1]);
}
else if (node_op == "GreaterEq")
{
node = make_shared<op::GreaterEq>(args[0], args[1]);
}
else if (node_op == "Less")
{
node = make_shared<op::Less>(args[0], args[1]);
}
else if (node_op == "LessEq")
{
node = make_shared<op::LessEq>(args[0], args[1]);
}
else if (node_op == "Log")
{
node = make_shared<op::Log>(args[0]);
}
else if (node_op == "Max")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Max>(args[0], reduction_axes);
}
else if (node_op == "MaxPool")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted.
auto padding_below_maybe = node_js["padding_below"];
auto padding_above_maybe = node_js["padding_above"];
if (padding_below_maybe.empty() && !padding_above_maybe.empty())
else if (node_op == "Divide")
{
throw runtime_error(
"MaxPool: padding_below is absent but padding_above is present");
node = make_shared<op::Divide>(args[0], args[1]);
}
else if (!padding_below_maybe.empty() && padding_above_maybe.empty())
else if (node_op == "Dot")
{
throw runtime_error(
"MaxPool: padding_below is present but padding_above is absent");
// For backwards compatibility, reduction_axes_count is optional.
auto obj = node_js["reduction_axes_count"];
if (obj.empty())
{
node = make_shared<op::Dot>(args[0], args[1]);
}
else
{
size_t reduction_axes_count = obj.get<size_t>();
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
}
}
else if (!padding_below_maybe.empty() && !padding_above_maybe.empty())
else if (node_op == "Equal")
{
auto padding_below = padding_below_maybe.get<vector<size_t>>();
auto padding_above = padding_above_maybe.get<vector<size_t>>();
node = make_shared<op::MaxPool>(
args[0], window_shape, window_movement_strides, padding_below, padding_above);
node = make_shared<op::Equal>(args[0], args[1]);
}
else
else if (node_op == "Exp")
{
node = make_shared<op::MaxPool>(args[0], window_shape, window_movement_strides);
node = make_shared<op::Exp>(args[0]);
}
}
else if (node_op == "MaxPoolBackprop")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
node = make_shared<op::MaxPoolBackprop>(args[0],
args[1],
else if (node_op == "Floor")
{
node = make_shared<op::Floor>(args[0]);
}
else if (node_op == "FunctionCall")
{
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::FunctionCall>(f_ptr, args);
}
else if (node_op == "GetOutputElement")
{
node = make_shared<op::GetOutputElement>(args[0], node_js.at("n").get<size_t>());
}
else if (node_op == "Greater")
{
node = make_shared<op::Greater>(args[0], args[1]);
}
else if (node_op == "GreaterEq")
{
node = make_shared<op::GreaterEq>(args[0], args[1]);
}
else if (node_op == "Less")
{
node = make_shared<op::Less>(args[0], args[1]);
}
else if (node_op == "LessEq")
{
node = make_shared<op::LessEq>(args[0], args[1]);
}
else if (node_op == "Log")
{
node = make_shared<op::Log>(args[0]);
}
else if (node_op == "Max")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Max>(args[0], reduction_axes);
}
else if (node_op == "MaxPool")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted.
auto padding_below_maybe = node_js["padding_below"];
auto padding_above_maybe = node_js["padding_above"];
if (padding_below_maybe.empty() && !padding_above_maybe.empty())
{
throw runtime_error(
"MaxPool: padding_below is absent but padding_above is present");
}
else if (!padding_below_maybe.empty() && padding_above_maybe.empty())
{
throw runtime_error(
"MaxPool: padding_below is present but padding_above is absent");
}
else if (!padding_below_maybe.empty() && !padding_above_maybe.empty())
{
auto padding_below = padding_below_maybe.get<vector<size_t>>();
auto padding_above = padding_above_maybe.get<vector<size_t>>();
node = make_shared<op::MaxPool>(args[0],
window_shape,
window_movement_strides,
padding_below,
padding_above);
}
else if (node_op == "Maximum")
{
node = make_shared<op::Maximum>(args[0], args[1]);
}
else if (node_op == "Min")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Min>(args[0], reduction_axes);
}
else if (node_op == "Minimum")
{
node = make_shared<op::Minimum>(args[0], args[1]);
}
else if (node_op == "Multiply")
{
node = make_shared<op::Multiply>(args[0], args[1]);
}
else if (node_op == "Negative")
{
node = make_shared<op::Negative>(args[0]);
}
else if (node_op == "NotEqual")
{
node = make_shared<op::NotEqual>(args[0], args[1]);
}
else if (node_op == "Not")
{
node = make_shared<op::Not>(args[0]);
}
else if (node_op == "OneHot")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
}
else if (node_op == "Pad")
{
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto padding_interior = node_js.at("padding_interior").get<vector<size_t>>();
node = make_shared<op::Pad>(
args[0], args[1], padding_below, padding_above, padding_interior);
}
else if (node_op == "Parameter")
{
auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
node = make_shared<op::Parameter>(element_type, shape);
}
else if (node_op == "Power")
{
node = make_shared<op::Power>(args[0], args[1]);
}
else if (node_op == "Product")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Product>(args[0], reduction_axes);
}
else if (node_op == "Reduce")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::Reduce>(args[0], args[1], f_ptr, reduction_axes);
}
else if (node_op == "ReduceWindow")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::ReduceWindow>(
args[0], args[1], f_ptr, window_shape, window_movement_strides);
}
else if (node_op == "Remainder")
{
node = make_shared<op::Remainder>(args[0], args[1]);
}
else if (node_op == "Relu")
{
node = make_shared<op::Relu>(args[0]);
}
else if (node_op == "ReluBackprop")
{
node = make_shared<op::ReluBackprop>(args[0], args[1]);
}
else if (node_op == "ReplaceSlice")
{
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::ReplaceSlice>(
args[0], args[1], lower_bounds, upper_bounds, strides);
}
else if (node_op == "Reshape")
{
auto input_order = node_js.at("input_order").get<vector<size_t>>();
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::Reshape>(args[0], input_order, output_shape);
}
else if (node_op == "Result")
{
node = make_shared<op::Result>(args[0]);
}
else if (node_op == "Reverse")
{
auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
node = make_shared<op::Reverse>(args[0], reversed_axes);
}
else if (node_op == "Select")
{
node = make_shared<op::Select>(args[0], args[1], args[2]);
}
else if (node_op == "SelectAndScatter")
{
string selection_function_name = node_js.at("selection_function").get<string>();
shared_ptr<Function> selection_f_ptr = function_map.at(selection_function_name);
string scatter_function_name = node_js.at("scatter_function").get<string>();
shared_ptr<Function> scatter_f_ptr = function_map.at(scatter_function_name);
}
else
{
node = make_shared<op::MaxPool>(args[0], window_shape, window_movement_strides);
}
}
else if (node_op == "MaxPoolBackprop")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
node = make_shared<op::MaxPoolBackprop>(args[0],
args[1],
window_shape,
window_movement_strides,
padding_below,
padding_above);
}
else if (node_op == "Maximum")
{
node = make_shared<op::Maximum>(args[0], args[1]);
}
else if (node_op == "Min")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Min>(args[0], reduction_axes);
}
else if (node_op == "Minimum")
{
node = make_shared<op::Minimum>(args[0], args[1]);
}
else if (node_op == "Multiply")
{
node = make_shared<op::Multiply>(args[0], args[1]);
}
else if (node_op == "Negative")
{
node = make_shared<op::Negative>(args[0]);
}
else if (node_op == "NotEqual")
{
node = make_shared<op::NotEqual>(args[0], args[1]);
}
else if (node_op == "Not")
{
node = make_shared<op::Not>(args[0]);
}
else if (node_op == "OneHot")
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
}
else if (node_op == "Pad")
{
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto padding_interior = node_js.at("padding_interior").get<vector<size_t>>();
node = make_shared<op::Pad>(
args[0], args[1], padding_below, padding_above, padding_interior);
}
else if (node_op == "Parameter")
{
auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
node = make_shared<op::Parameter>(element_type, shape);
}
else if (node_op == "Power")
{
node = make_shared<op::Power>(args[0], args[1]);
}
else if (node_op == "Product")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Product>(args[0], reduction_axes);
}
else if (node_op == "Reduce")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::Reduce>(args[0], args[1], f_ptr, reduction_axes);
}
else if (node_op == "ReduceWindow")
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
string function_name = node_js.at("function").get<string>();
shared_ptr<Function> f_ptr = function_map.at(function_name);
node = make_shared<op::ReduceWindow>(
args[0], args[1], f_ptr, window_shape, window_movement_strides);
}
else if (node_op == "Remainder")
{
node = make_shared<op::Remainder>(args[0], args[1]);
}
else if (node_op == "Relu")
{
node = make_shared<op::Relu>(args[0]);
}
else if (node_op == "ReluBackprop")
{
node = make_shared<op::ReluBackprop>(args[0], args[1]);
}
else if (node_op == "ReplaceSlice")
{
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::ReplaceSlice>(
args[0], args[1], lower_bounds, upper_bounds, strides);
}
else if (node_op == "Reshape")
{
auto input_order = node_js.at("input_order").get<vector<size_t>>();
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::Reshape>(args[0], input_order, output_shape);
}
else if (node_op == "Result")
{
node = make_shared<op::Result>(args[0]);
}
else if (node_op == "Reverse")
{
auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
node = make_shared<op::Reverse>(args[0], reversed_axes);
}
else if (node_op == "Select")
{
node = make_shared<op::Select>(args[0], args[1], args[2]);
}
else if (node_op == "SelectAndScatter")
{
string selection_function_name = node_js.at("selection_function").get<string>();
shared_ptr<Function> selection_f_ptr = function_map.at(selection_function_name);
string scatter_function_name = node_js.at("scatter_function").get<string>();
shared_ptr<Function> scatter_f_ptr = function_map.at(scatter_function_name);
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
node = make_shared<op::SelectAndScatter>(args[0],
args[1],
args[2],
selection_f_ptr,
scatter_f_ptr,
window_shape,
window_movement_strides);
}
else if (node_op == "Sign")
{
node = make_shared<op::Sign>(args[0]);
}
else if (node_op == "Sin")
{
node = make_shared<op::Sin>(args[0]);
}
else if (node_op == "Sinh")
{
node = make_shared<op::Sinh>(args[0]);
}
else if (node_op == "Slice")
{
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::Slice>(args[0], lower_bounds, upper_bounds, strides);
}
else if (node_op == "Softmax")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Softmax>(args[0], reduction_axes);
}
else if (node_op == "Sqrt")
{
node = make_shared<op::Sqrt>(args[0]);
}
else if (node_op == "Subtract")
{
node = make_shared<op::Subtract>(args[0], args[1]);
}
else if (node_op == "Sum")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Sum>(args[0], reduction_axes);
}
else if (node_op == "Tan")
{
node = make_shared<op::Tan>(args[0]);
}
else if (node_op == "Tanh")
{
node = make_shared<op::Tanh>(args[0]);
node = make_shared<op::SelectAndScatter>(args[0],
args[1],
args[2],
selection_f_ptr,
scatter_f_ptr,
window_shape,
window_movement_strides);
}
else if (node_op == "Sign")
{
node = make_shared<op::Sign>(args[0]);
}
else if (node_op == "Sin")
{
node = make_shared<op::Sin>(args[0]);
}
else if (node_op == "Sinh")
{
node = make_shared<op::Sinh>(args[0]);
}
else if (node_op == "Slice")
{
auto lower_bounds = node_js.at("lower_bounds").get<vector<size_t>>();
auto upper_bounds = node_js.at("upper_bounds").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::Slice>(args[0], lower_bounds, upper_bounds, strides);
}
else if (node_op == "Softmax")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Softmax>(args[0], reduction_axes);
}
else if (node_op == "Sqrt")
{
node = make_shared<op::Sqrt>(args[0]);
}
else if (node_op == "Subtract")
{
node = make_shared<op::Subtract>(args[0], args[1]);
}
else if (node_op == "Sum")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
node = make_shared<op::Sum>(args[0], reduction_axes);
}
else if (node_op == "Tan")
{
node = make_shared<op::Tan>(args[0]);
}
else if (node_op == "Tanh")
{
node = make_shared<op::Tanh>(args[0]);
}
else
{
stringstream ss;
ss << "unsupported op " << node_op;
throw runtime_error(ss.str());
}
node_map[node_name] = node;
// Typically, it could be unsafe to change the name of a node since it may break naming
// uniqueness. However, it could sometimes be helpful to use the original name from
// the serialization for debugging.
// node->set_name(node_name);
}
else
catch (...)
{
stringstream ss;
ss << "unsupported op " << node_op;
throw runtime_error(ss.str());
string node_name;
try
{
node_name = node_js.at("name").get<string>();
}
catch (...)
{
node_name = "UNKNOWN";
}
throw runtime_error("Error parsing json at node '" + node_name + "'");
}
node_map[node_name] = node;
// Typically, it could be unsafe to change the name of a node since it may break naming
// uniqueness. However, it could sometimes be helpful to use the original name from
// the serialization for debugging.
// node->set_name(node_name);
}
//This handles both graphs w/ `op::Result` and legacy graphs w/o it
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment