Unverified Commit 603cbdab authored by Fenglei Tian's avatar Fenglei Tian Committed by GitHub

Merge branch 'master' into tfl/send_recv_op

parents ac3743d3 79587d93
......@@ -13,37 +13,52 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/fused_op.hpp"
using namespace std;
using namespace ngraph;
bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Node> node)
pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
: m_has_direct_support{callback}
{
}
bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
if (auto fused_op = std::dynamic_pointer_cast<ngraph::op::util::FusedOp>(node))
if (auto fused_op = dynamic_pointer_cast<op::util::FusedOp>(node))
{
if (m_callback && m_callback(*node))
if (m_has_direct_support && m_has_direct_support(*node))
{
// Op supported by backend. Do not decompose
return modified;
}
auto subgraph_outputs = fused_op->decompose_op();
// Run recursively untill no more fused ops
auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments());
for (auto subgraph_node : subgraph)
{
if (auto nested_fused_op = dynamic_pointer_cast<op::util::FusedOp>(subgraph_node))
{
if (!(m_has_direct_support && m_has_direct_support(*nested_fused_op)))
{
run_on_node(nested_fused_op);
}
}
}
size_t i = 0;
for (auto output_node : subgraph_outputs)
{
for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
{
// TODO: Provenance
std::set<ngraph::descriptor::Input*> fop_users{
begin(fused_op->get_outputs().at(i).get_inputs()),
end(fused_op->get_outputs().at(i).get_inputs())};
set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
end(fused_op->get_outputs().at(i).get_inputs())};
for (auto fop_user : fop_users)
{
if (auto goe =
......@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
if (goe->get_n() == i && !goe->get_output_inputs(0).empty())
{
// Replace GOE users
std::set<ngraph::descriptor::Input*> goe_users{
set<descriptor::Input*> goe_users{
begin(goe->get_outputs().at(0).get_inputs()),
end(goe->get_outputs().at(0).get_inputs())};
for (auto goe_user : goe_users)
......@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
return modified;
}
pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
: m_callback{callback}
{
}
......@@ -16,6 +16,9 @@
#pragma once
#include <memory>
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
......@@ -25,13 +28,24 @@ namespace ngraph
class FusedOpDecomposition : public NodePass
{
public:
/// \brief Function signature type for callback used to check whether provided node
/// is supported by backend.
using op_query_t = std::function<bool(const Node& node)>;
///
/// \brief Constructor for the Fused operation decomposition pass.
///
/// \param[in] callback The function object used to determine whether current backend
/// provide direct support for passed node. Should have signature:
/// bool fn(const Node&)
///
FusedOpDecomposition(op_query_t callback = nullptr);
bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
private:
op_query_t m_callback = nullptr;
/// \brief A function returning whether provided Node is supported by current backend.
/// The returned bool value is used to control whether decompose operator or not.
op_query_t m_has_direct_support = nullptr;
};
}
}
......@@ -1180,7 +1180,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS(RecurrentReshapeElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS_WITH_ARGS(
CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
......
......@@ -193,10 +193,15 @@ static OP_TYPEID get_typeid(const string& s)
return rc;
}
bool has_key(json j, const std::string& key)
{
return j.count(key) != 0;
}
template <typename T>
T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value)
T get_or_default(json j, const std::string& key, const T& default_value)
{
return j.count(key) != 0 ? j.at(key).get<T>() : default_value;
return has_key(j, key) ? j.at(key).get<T>() : default_value;
}
class JSONSerializer
......@@ -215,8 +220,11 @@ public:
json serialize_function(const Function& function);
json serialize_output(const Output<Node>& output);
json serialize_parameter_vector(const ParameterVector& parameters);
json serialize_output_vector(const OutputVector& output_vector);
json serialize_node_reference(const Node& node);
json serialize_node(const Node& node);
json serialize_axis_set(const AxisSet& axis_set);
protected:
size_t m_indent{0};
......@@ -235,10 +243,13 @@ public:
m_const_data_callback = const_data_callback;
}
shared_ptr<Function> deserialize_function(json& j);
Output<Node> deserialize_output(json& j);
shared_ptr<Node> deserialize_node_reference(json& j);
shared_ptr<Node> deserialize_node(json& j);
shared_ptr<Function> deserialize_function(json j);
Output<Node> deserialize_output(json j);
OutputVector deserialize_output_vector(json j);
ParameterVector deserialize_parameter_vector(json j);
shared_ptr<Node> deserialize_node_reference(json j);
shared_ptr<Node> deserialize_node(json j);
AxisSet deserialize_axis_set(json j);
protected:
unordered_map<string, shared_ptr<Node>> m_node_map;
......@@ -261,7 +272,7 @@ static json write_dimension(Dimension d)
}
}
static Dimension read_dimension(const json& j)
static Dimension read_dimension(json j)
{
if (j.is_null())
{
......@@ -290,7 +301,7 @@ static json write_partial_shape(const PartialShape& s)
}
}
static PartialShape read_partial_shape(const json& j)
static PartialShape read_partial_shape(json j)
{
if (j.is_null())
{
......@@ -315,19 +326,32 @@ static json write_auto_broadcast(const op::AutoBroadcastSpec& autob)
return j;
}
static op::AutoBroadcastSpec read_auto_broadcast(const json& j)
static op::AutoBroadcastSpec read_auto_broadcast(json js_node, const std::string& attr)
{
if (!j.is_object())
if (has_key(js_node, attr))
{
return op::AutoBroadcastSpec();
json j = js_node[attr];
return op::AutoBroadcastSpec(static_cast<op::AutoBroadcastType>(j.at("type")),
j.at("axis").get<size_t>());
}
else
{
return op::AutoBroadcastSpec(static_cast<op::AutoBroadcastType>(j.at("type")),
j.at("axis").get<size_t>());
return op::AutoBroadcastSpec();
}
}
static op::PadType read_pad_type(json node_js)
{
return has_key(node_js, "pad_type") ? static_cast<op::PadType>(node_js.at("pad_type"))
: op::PadType::EXPLICIT;
}
static op::PadMode read_pad_mode(json node_js)
{
return has_key(node_js, "pad_mode") ? static_cast<op::PadMode>(node_js.at("pad_mode"))
: op::PadMode::CONSTANT;
}
static json write_element_type(const ngraph::element::Type& n)
{
json j;
......@@ -335,7 +359,7 @@ static json write_element_type(const ngraph::element::Type& n)
return j;
}
static element::Type read_element_type(const json& j)
static element::Type read_element_type(json j)
{
size_t bitwidth = 0;
bool is_real = false;
......@@ -495,21 +519,24 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
rc = deserializer.deserialize_function(func);
}
}
return rc;
}
json JSONSerializer::serialize_parameter_vector(const ParameterVector& parameters)
{
json json_parameters = json::array();
for (auto param : parameters)
{
json_parameters.push_back(serialize_node_reference(*param));
}
return json_parameters;
}
json JSONSerializer::serialize_function(const Function& f)
{
json function;
function["name"] = f.get_name();
vector<string> parameter_list;
for (auto param : f.get_parameters())
{
parameter_list.push_back(serialize_node_reference(*param));
}
function["parameters"] = parameter_list;
function["parameters"] = serialize_parameter_vector(f.get_parameters());
// TODO Functions can return multiple results
for (size_t i = 0; i < f.get_output_size(); ++i)
......@@ -521,7 +548,7 @@ json JSONSerializer::serialize_function(const Function& f)
}
template <typename T>
T get_value(nlohmann::json js, const string& key)
T get_value(json js, const string& key)
{
T rc;
auto it = js.find(key);
......@@ -532,13 +559,13 @@ T get_value(nlohmann::json js, const string& key)
return rc;
}
shared_ptr<Node> JSONDeserializer::deserialize_node_reference(json& j)
shared_ptr<Node> JSONDeserializer::deserialize_node_reference(json j)
{
const string& name = j;
return m_node_map.at(name);
}
Output<Node> JSONDeserializer::deserialize_output(json& j)
Output<Node> JSONDeserializer::deserialize_output(json j)
{
size_t index;
json json_node_reference;
......@@ -559,10 +586,48 @@ Output<Node> JSONDeserializer::deserialize_output(json& j)
return Output<Node>(deserialize_node_reference(json_node_reference), index);
}
shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
OutputVector JSONDeserializer::deserialize_output_vector(json j)
{
OutputVector result;
if (j.is_array())
{
for (json jelt : j)
{
result.push_back(deserialize_output(jelt));
}
}
return result;
}
json JSONSerializer::serialize_axis_set(const AxisSet& axis_set)
{
return static_cast<set<size_t>>(axis_set);
}
AxisSet JSONDeserializer::deserialize_axis_set(json j)
{
AxisSet result;
if (j.is_array())
{
result = j.get<set<size_t>>();
}
return result;
}
ParameterVector JSONDeserializer::deserialize_parameter_vector(json json_parameters)
{
std::vector<std::shared_ptr<op::Parameter>> params;
for (auto& param_ref : json_parameters)
{
params.push_back(
dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
}
return params;
}
shared_ptr<Function> JSONDeserializer::deserialize_function(json func_js)
{
string func_name = func_js.at("name").get<string>();
vector<json> func_parameters = func_js.at("parameters");
vector<json> func_result = func_js.at("result");
for (json node_js : func_js.at("ops"))
{
......@@ -594,12 +659,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
"Graph serialization is inconsistent. Some op::Results appear to be missing");
}
std::vector<std::shared_ptr<op::Parameter>> params;
for (auto& param_ref : func_parameters)
{
params.push_back(
dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
}
ParameterVector params = deserialize_parameter_vector(func_js.at("parameters"));
shared_ptr<Function> rc{make_shared<Function>(result, params, func_name)};
m_function_map[func_name] = rc;
......@@ -632,7 +692,12 @@ struct OutputHelper
// when all op constructors use the new style arguments.
struct OutputVectorHelper
{
const OutputHelper& operator[](size_t i) const { return m_vector[i]; }
OutputVectorHelper(const OutputVector& output_vector)
: m_vector(output_vector)
{
}
OutputVectorHelper() = default;
OutputHelper operator[](size_t i) const { return OutputHelper(m_vector[i]); }
void push_back(const Output<Node>& output) { m_vector.push_back(output); }
size_t size() const { return m_vector.size(); }
operator vector<shared_ptr<Node>>() const
......@@ -640,14 +705,15 @@ struct OutputVectorHelper
vector<shared_ptr<Node>> result;
for (auto& o : m_vector)
{
result.push_back(o);
result.push_back(OutputHelper(o));
}
return result;
}
vector<OutputHelper> m_vector;
operator const OutputVector&() const { return m_vector; }
OutputVector m_vector;
};
shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
{
shared_ptr<Node> node;
try
......@@ -655,14 +721,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>();
string friendly_name = get_value<string>(node_js, "friendly_name");
vector<json> node_inputs = get_value<vector<json>>(node_js, "inputs");
vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps");
vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
OutputVectorHelper args;
for (auto& node_input : node_inputs)
{
args.push_back(deserialize_output(node_input));
}
OutputVectorHelper args(deserialize_output_vector(node_js["inputs"]));
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
......@@ -683,12 +744,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Add:
{
node = make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::All:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
node = make_shared<op::All>(args[0], reduction_axes);
break;
}
......@@ -699,12 +760,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::And:
{
node = make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::Any:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
node = make_shared<op::Any>(args[0], reduction_axes);
break;
}
......@@ -741,12 +802,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation =
node_js.at("include_padding_in_avg_computation").get<bool>();
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
bool ceil_mode =
node_js["ceil_mode"].empty() ? false : node_js.at("ceil_mode").get<bool>();
;
op::PadType pad_type = read_pad_type(node_js);
bool ceil_mode = get_or_default<bool>(node_js, "ceil_mode", false);
node = make_shared<op::AvgPool>(args[0],
window_shape,
window_movement_strides,
......@@ -808,7 +865,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Broadcast:
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = node_js.at("axes").get<set<size_t>>();
auto axes = deserialize_axis_set(node_js.at("axes"));
node = make_shared<op::Broadcast>(args[0], shape, axes);
break;
}
......@@ -819,7 +876,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::BroadcastLike:
{
auto initial_axes = node_js.at("initial_axes").get<set<size_t>>();
auto initial_axes = deserialize_axis_set(node_js.at("initial_axes"));
node = make_shared<op::BroadcastLike>(args[0], args[1], initial_axes);
break;
}
......@@ -838,13 +895,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Concat:
{
auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Concat>(args, axis);
node = make_shared<op::Concat>(static_cast<OutputVector>(args), axis);
break;
}
case OP_TYPEID::Constant:
{
auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
has_key(node_js, "element_type") ? node_js : node_js.at("value_type");
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
auto value = node_js.at("value").get<vector<string>>();
......@@ -868,17 +925,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
// For backwards compatibility, we accept "image_dilation_strides" in place of
// "data_dilation_strides", and we also allow it to be omitted altogether.
auto data_dilation_strides_maybe = node_js["data_dilation_strides"];
if (data_dilation_strides_maybe.empty())
json data_dilation_strides;
if (has_key(node_js, "data_dilation_strides"))
{
data_dilation_strides_maybe = node_js["image_dilation_strides"];
data_dilation_strides = node_js["data_dilation_strides"];
}
else if (has_key(node_js, "image_dilation_strides"))
{
data_dilation_strides = node_js["image_dilation_strides"];
}
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
op::PadType pad_type = read_pad_type(node_js);
if (data_dilation_strides_maybe.empty())
if (data_dilation_strides.empty())
{
node = make_shared<op::Convolution>(args[0],
args[1],
......@@ -889,15 +948,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
else
{
node = make_shared<op::Convolution>(
args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides_maybe.get<std::vector<size_t>>(),
pad_type);
node =
make_shared<op::Convolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides.get<std::vector<size_t>>(),
pad_type);
}
break;
}
......@@ -1033,33 +1092,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Dequantize:
{
auto type = read_element_type(node_js.at("type"));
auto axes = node_js.at("axes").get<set<size_t>>();
auto axes = deserialize_axis_set(node_js.at("axes"));
node = make_shared<op::Dequantize>(args[0], args[1], args[2], type, axes);
break;
}
case OP_TYPEID::Divide:
{
bool pythondiv = true;
if (node_js["pythondiv"].is_object())
{
pythondiv = node_js.at("pythondiv").get<bool>();
}
bool pythondiv = get_or_default(node_js, "pythondiv", true);
node = make_shared<op::Divide>(
args[0], args[1], pythondiv, read_auto_broadcast(node_js["autob"]));
args[0], args[1], pythondiv, read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::Dot:
{
// For backwards compatibility, reduction_axes_count is optional.
auto obj = node_js["reduction_axes_count"];
if (obj.empty())
if (has_key(node_js, "reduction_axes_count"))
{
node = make_shared<op::Dot>(args[0], args[1]);
size_t reduction_axes_count = node_js["reduction_axes_count"].get<size_t>();
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
}
else
{
size_t reduction_axes_count = obj.get<size_t>();
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
node = make_shared<op::Dot>(args[0], args[1]);
}
break;
}
......@@ -1095,7 +1149,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Equal:
{
node = make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::Erf:
......@@ -1160,13 +1214,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Greater:
{
node =
make_shared<op::Greater>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::Greater>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::GreaterEq:
{
node =
make_shared<op::GreaterEq>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::GreaterEq>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::GRN:
......@@ -1193,10 +1247,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
auto groups = node_js.at("groups").get<size_t>();
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
op::PadType pad_type = read_pad_type(node_js);
node = make_shared<op::GroupConvolution>(args[0],
args[1],
window_movement_strides,
......@@ -1216,9 +1267,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto padding_end = node_js.at("padding_end").get<vector<ptrdiff_t>>();
auto output_padding = node_js.at("output_padding").get<vector<ptrdiff_t>>();
auto groups = node_js.at("groups").get<size_t>();
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
op::PadType pad_type = read_pad_type(node_js);
auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::GroupConvolutionTranspose>(args[0],
......@@ -1240,12 +1289,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Less:
{
node = make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::LessEq:
{
node = make_shared<op::LessEq>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::LessEq>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::Log:
......@@ -1287,7 +1336,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Max:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
node = make_shared<op::Max>(args[0], reduction_axes);
break;
}
......@@ -1298,11 +1347,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
node_js.at("window_movement_strides").get<vector<size_t>>();
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted.
auto padding_below_maybe = node_js["padding_below"];
auto padding_above_maybe = node_js["padding_above"];
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
auto padding_below_maybe = get_or_default(node_js, "padding_below", json{});
auto padding_above_maybe = get_or_default(node_js, "padding_above", json{});
op::PadType pad_type = read_pad_type(node_js);
if (padding_below_maybe.empty() && !padding_above_maybe.empty())
{
throw runtime_error(
......@@ -1361,31 +1408,31 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Maximum:
{
node =
make_shared<op::Maximum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::Maximum>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::Min:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
node = make_shared<op::Min>(args[0], reduction_axes);
break;
}
case OP_TYPEID::Minimum:
{
node =
make_shared<op::Minimum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::Minimum>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::Multiply:
{
node =
make_shared<op::Multiply>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::Multiply>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::MVN:
{
auto normalize_variance = node_js.at("normalize_variance").get<bool>();
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
auto eps = node_js.at("eps").get<double>();
node = make_shared<op::MVN>(args[0], normalize_variance, normalize_variance, eps);
break;
......@@ -1407,7 +1454,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::NotEqual:
{
node =
make_shared<op::NotEqual>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::NotEqual>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::Not:
......@@ -1424,7 +1471,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Or:
{
node = make_shared<op::Or>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::Or>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::Pad:
......@@ -1441,9 +1488,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
[](size_t s) { return s == 0; }),
"Legacy padding_interior field must be zero everywhere.");
auto pad_mode = node_js.count("pad_mode") == 0
? op::PadMode::CONSTANT
: static_cast<op::PadMode>(node_js.at("pad_mode"));
auto pad_mode = read_pad_mode(node_js);
node = make_shared<op::Pad>(args[0], args[1], padding_below, padding_above, pad_mode);
break;
......@@ -1451,7 +1496,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Parameter:
{
auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
has_key(node_js, "element_type") ? node_js : node_js.at("value_type");
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
......@@ -1476,7 +1521,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Power:
{
node = make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::PRelu:
......@@ -1486,14 +1531,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Product:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
node = make_shared<op::Product>(args[0], reduction_axes);
break;
}
case OP_TYPEID::Quantize:
{
auto type = read_element_type(node_js.at("type"));
auto axes = node_js.at("axes").get<set<size_t>>();
auto axes = deserialize_axis_set(node_js.at("axes"));
auto round_mode = node_js.at("round_mode").get<op::Quantize::RoundMode>();
node = make_shared<op::Quantize>(args[0], args[1], args[2], type, axes, round_mode);
break;
......@@ -1552,8 +1597,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
node_js.at("window_movement_strides").get<vector<size_t>>();
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted.
auto padding_below_maybe = node_js["padding_below"];
auto padding_above_maybe = node_js["padding_above"];
auto padding_below_maybe = get_or_default(node_js, "padding_below", json{});
auto padding_above_maybe = get_or_default(node_js, "padding_above", json{});
auto padding_below = padding_below_maybe.get<vector<size_t>>();
auto padding_above = padding_above_maybe.get<vector<size_t>>();
node = make_shared<op::QuantizedMaxPool>(
......@@ -1607,7 +1652,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Reverse:
{
auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
auto reversed_axes = deserialize_axis_set(node_js.at("reversed_axes"));
node = make_shared<op::Reverse>(args[0], reversed_axes);
break;
}
......@@ -1697,7 +1742,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
case OP_TYPEID::Softmax:
{
auto softmax_axes = node_js.at("softmax_axes").get<set<size_t>>();
auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
node = make_shared<op::Softmax>(args[0], softmax_axes);
break;
}
......@@ -1732,12 +1777,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Subtract:
{
node =
make_shared<op::Subtract>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::Subtract>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::Sum:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
node = make_shared<op::Sum>(args[0], reduction_axes);
break;
}
......@@ -1873,6 +1918,16 @@ json JSONSerializer::serialize_output(const Output<Node>& output)
return result;
}
json JSONSerializer::serialize_output_vector(const OutputVector& output_vector)
{
json result;
for (const Output<Node>& output : output_vector)
{
result.push_back(serialize_output(output));
}
return result;
}
json JSONSerializer::serialize_node(const Node& n)
{
m_nodes_serialized.insert(&n);
......@@ -1972,7 +2027,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::All:
{
auto tmp = dynamic_cast<const op::All*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
break;
}
case OP_TYPEID::AllReduce: { break;
......@@ -1989,7 +2044,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Any:
{
auto tmp = dynamic_cast<const op::Any*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
break;
}
case OP_TYPEID::Asin: { break;
......@@ -2045,7 +2100,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Broadcast:
{
auto tmp = dynamic_cast<const op::Broadcast*>(&n);
node["axes"] = tmp->get_broadcast_axes();
node["axes"] = serialize_axis_set(tmp->get_broadcast_axes());
node["shape"] = tmp->get_broadcast_shape();
break;
}
......@@ -2054,7 +2109,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::BroadcastLike:
{
auto tmp = dynamic_cast<const op::BroadcastLike*>(&n);
node["initial_axes"] = tmp->get_initial_broadcast_axes();
node["initial_axes"] = serialize_axis_set(tmp->get_initial_broadcast_axes());
break;
}
case OP_TYPEID::Ceiling: { break;
......@@ -2168,7 +2223,7 @@ json JSONSerializer::serialize_node(const Node& n)
{
auto tmp = dynamic_cast<const op::Dequantize*>(&n);
node["type"] = write_element_type(tmp->get_element_type());
node["axes"] = tmp->get_axes();
node["axes"] = serialize_axis_set(tmp->get_axes());
break;
}
case OP_TYPEID::DepthToSpace:
......@@ -2361,7 +2416,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Max:
{
auto tmp = dynamic_cast<const op::Max*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
break;
}
case OP_TYPEID::MaxPool:
......@@ -2395,7 +2450,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Min:
{
auto tmp = dynamic_cast<const op::Min*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
break;
}
case OP_TYPEID::Minimum:
......@@ -2419,7 +2474,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::MVN:
{
auto tmp = dynamic_cast<const op::MVN*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
node["normalize_variance"] = tmp->get_normalize_variance();
node["eps"] = tmp->get_eps();
break;
......@@ -2499,7 +2554,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Product:
{
auto tmp = dynamic_cast<const op::Product*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
break;
}
case OP_TYPEID::Power:
......@@ -2515,7 +2570,7 @@ json JSONSerializer::serialize_node(const Node& n)
{
auto tmp = dynamic_cast<const op::Quantize*>(&n);
node["type"] = write_element_type(tmp->get_element_type());
node["axes"] = tmp->get_axes();
node["axes"] = serialize_axis_set(tmp->get_axes());
node["round_mode"] = tmp->get_round_mode();
break;
}
......@@ -2596,7 +2651,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Reverse:
{
auto tmp = dynamic_cast<const op::Reverse*>(&n);
node["reversed_axes"] = tmp->get_reversed_axes();
node["reversed_axes"] = serialize_axis_set(tmp->get_reversed_axes());
break;
}
case OP_TYPEID::ReverseSequence:
......@@ -2689,13 +2744,13 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Sum:
{
auto tmp = dynamic_cast<const op::Sum*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
break;
}
case OP_TYPEID::Softmax:
{
auto tmp = dynamic_cast<const op::Softmax*>(&n);
node["softmax_axes"] = tmp->get_axes();
node["softmax_axes"] = serialize_axis_set(tmp->get_axes());
break;
}
case OP_TYPEID::Tan: { break;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment