Unverified Commit 24c715f4 authored by Scott Cyphers, committed by GitHub

Merge branch 'master' into mlir

parents 3b5bfdab 5e19c25c
@@ -37,6 +37,7 @@ NodeInput = Union[Node, NumericData]
ngraph_to_numpy_types_map = [
(NgraphType.boolean, np.bool),
(NgraphType.f16, np.float16),
(NgraphType.f32, np.float32),
(NgraphType.f64, np.float64),
(NgraphType.i8, np.int8),
......
@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
py::class_<ngraph::element::Type, std::shared_ptr<ngraph::element::Type>> type(m, "Type");
type.doc() = "ngraph.impl.Type wraps ngraph::element::Type";
type.attr("boolean") = ngraph::element::boolean;
type.attr("f16") = ngraph::element::f16;
type.attr("f32") = ngraph::element::f32;
type.attr("f64") = ngraph::element::f64;
type.attr("i8") = ngraph::element::i8;
......
@@ -85,6 +85,8 @@ add_library(onnx_import STATIC
op/equal.hpp
op/erf.hpp
op/exp.hpp
op/eye_like.cpp
op/eye_like.hpp
op/flatten.cpp
op/flatten.hpp
op/floor.hpp
@@ -191,6 +193,7 @@ add_library(onnx_import STATIC
op/xor.hpp
ops_bridge.cpp
ops_bridge.hpp
utils/common.cpp
utils/common.hpp
utils/convpool.cpp
utils/convpool.hpp
......
@@ -52,6 +52,8 @@ namespace ngraph
const std::string& output(int index) const;
std::size_t get_outputs_size() const;
bool has_attribute(const std::string& name) const;
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const;
@@ -87,6 +89,15 @@ namespace ngraph
}
std::size_t Node::Impl::get_outputs_size() const { return m_output_names.size(); }
bool Node::Impl::has_attribute(const std::string& name) const
{
auto it = std::find_if(
std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) {
return attribute.get_name() == name;
});
return it != std::end(m_attributes);
}
template <typename T>
T Node::Impl::get_attribute_value(const std::string& name, T default_value) const
{
@@ -185,6 +196,11 @@ namespace ngraph
const std::string& Node::output(int index) const { return m_pimpl->output(index); }
std::size_t Node::get_outputs_size() const { return m_pimpl->get_outputs_size(); }
bool Node::has_attribute(const std::string& name) const
{
return m_pimpl->has_attribute(name);
}
template <>
float Node::get_attribute_value(const std::string& name, float default_value) const
{
......
@@ -78,6 +78,8 @@ namespace ngraph
const std::string& output(int index) const;
std::size_t get_outputs_size() const;
bool has_attribute(const std::string& name) const;
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const;
......
@@ -24,6 +24,7 @@
#include "ngraph/type/element_type.hpp"
#include "node.hpp"
#include "tensor.hpp"
#include "utils/common.hpp"
#include "weight.hpp"
namespace ngraph
@@ -41,17 +42,8 @@ namespace ngraph
{
}
};
struct unsupported_element_type : ngraph_error } // namespace value_info
{ } // namespace error
explicit unsupported_element_type(TensorProto_DataType type)
: ngraph_error{"unsupported value info element type: " +
onnx::TensorProto_DataType_Name(
static_cast<onnx::TensorProto_DataType>(type))}
{
}
};
}
}
class ValueInfo
{
@@ -83,24 +75,8 @@ namespace ngraph
{
throw error::value_info::unspecified_element_type{};
}
switch (m_value_info_proto->type().tensor_type().elem_type()) return common::get_ngraph_element_type(
{ m_value_info_proto->type().tensor_type().elem_type());
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return element::boolean;
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: return element::f32;
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: return element::f64;
case onnx::TensorProto_DataType::TensorProto_DataType_INT8: return element::i8;
case onnx::TensorProto_DataType::TensorProto_DataType_INT16: return element::i16;
case onnx::TensorProto_DataType::TensorProto_DataType_INT32: return element::i32;
case onnx::TensorProto_DataType::TensorProto_DataType_INT64: return element::i64;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: return element::u8;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: return element::u16;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: return element::u32;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64;
default:
throw error::value_info::unsupported_element_type{
m_value_info_proto->type().tensor_type().elem_type()};
}
}
std::shared_ptr<ngraph::Node>
......
@@ -13,14 +13,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <onnx-ml.pb.h>
#include "cast.hpp"
#include "exceptions.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/type/element_type.hpp"
#include "utils/common.hpp"
namespace ngraph
{
@@ -34,25 +33,7 @@ namespace ngraph
{
auto data = node.get_ng_inputs().at(0);
int64_t target_type = node.get_attribute_value<int64_t>("to");
element::Type elem_type; element::Type elem_type = common::get_ngraph_element_type(target_type);
switch (target_type)
{
case onnx::TensorProto_DataType_BOOL: elem_type = element::boolean; break;
case onnx::TensorProto_DataType_DOUBLE: elem_type = element::f64; break;
case onnx::TensorProto_DataType_FLOAT16:
case onnx::TensorProto_DataType_FLOAT: elem_type = element::f32; break;
case onnx::TensorProto_DataType_INT8: elem_type = element::i8; break;
case onnx::TensorProto_DataType_INT16: elem_type = element::i16; break;
case onnx::TensorProto_DataType_INT32: elem_type = element::i32; break;
case onnx::TensorProto_DataType_INT64: elem_type = element::i64; break;
case onnx::TensorProto_DataType_UINT8: elem_type = element::u8; break;
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::dynamic; break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
}
return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
}
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "eye_like.hpp"
#include "exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/common.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector eye_like(const Node& node)
{
const auto input = node.get_ng_inputs().at(0);
const auto& input_shape = input->get_shape();
std::int64_t dtype;
element::Type target_type;
std::int64_t shift = node.get_attribute_value<std::int64_t>("k", 0);
if (node.has_attribute("dtype"))
{
dtype = node.get_attribute_value<std::int64_t>("dtype");
target_type = common::get_ngraph_element_type(dtype);
}
else
{
target_type = input->get_element_type();
}
ASSERT_VALID_ARGUMENT(node, input_shape.size() == 2)
<< "The provided shape rank: " << input_shape.size()
<< " is unsupported, only 2D shapes are supported";
std::shared_ptr<ngraph::Node> eye_like_matrix =
common::shifted_square_identity(input_shape, target_type, shift);
return {eye_like_matrix};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
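For reference, a small worked example of what the importer above produces (illustrative only, not part of this change): for a 3x4 input with k = 1 and no dtype attribute, the element type is copied from the input and the constant built by shifted_square_identity holds the following values.

// EyeLike, input shape {3, 4}, shift k = 1, dtype omitted -> element type taken from input
// 0 1 0 0
// 0 0 1 0
// 0 0 0 1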
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector eye_like(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
@@ -53,6 +53,7 @@
#include "op/equal.hpp"
#include "op/erf.hpp"
#include "op/exp.hpp"
#include "op/eye_like.hpp"
#include "op/flatten.hpp"
#include "op/floor.hpp"
#include "op/gather.hpp"
@@ -260,6 +261,7 @@ namespace ngraph
REGISTER_OPERATOR("Equal", 1, equal);
REGISTER_OPERATOR("Erf", 1, erf);
REGISTER_OPERATOR("Exp", 1, exp);
REGISTER_OPERATOR("EyeLike", 1, eye_like);
REGISTER_OPERATOR("Flatten", 1, flatten);
REGISTER_OPERATOR("Floor", 1, floor);
REGISTER_OPERATOR("Gather", 1, gather);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <onnx-ml.pb.h> // onnx types
#include "common.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace common
{
const ngraph::element::Type& get_ngraph_element_type(int64_t onnx_type)
{
switch (onnx_type)
{
case onnx::TensorProto_DataType_BOOL: return element::boolean;
case onnx::TensorProto_DataType_DOUBLE: return element::f64;
case onnx::TensorProto_DataType_FLOAT16: return element::f16;
case onnx::TensorProto_DataType_FLOAT: return element::f32;
case onnx::TensorProto_DataType_INT8: return element::i8;
case onnx::TensorProto_DataType_INT16: return element::i16;
case onnx::TensorProto_DataType_INT32: return element::i32;
case onnx::TensorProto_DataType_INT64: return element::i64;
case onnx::TensorProto_DataType_UINT8: return element::u8;
case onnx::TensorProto_DataType_UINT16: return element::u16;
case onnx::TensorProto_DataType_UINT32: return element::u32;
case onnx::TensorProto_DataType_UINT64: return element::u64;
case onnx::TensorProto_DataType_UNDEFINED: return element::dynamic;
}
throw ngraph_error("unsupported element type: " +
onnx::TensorProto_DataType_Name(
static_cast<onnx::TensorProto_DataType>(onnx_type)));
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph
@@ -19,6 +19,7 @@
#include <algorithm> // std::generate
#include <cmath> // std::floor, std::min
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <iterator> // std::begin, std::end
#include <memory> // std::shared_ptr, std::make_shared
#include <type_traits> // std::enable_if
@@ -27,6 +28,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
@@ -34,6 +36,8 @@ namespace ngraph
{
namespace common
{
const ngraph::element::Type& get_ngraph_element_type(std::int64_t onnx_type);
/// \brief Return a monotonic sequence.
///
/// \note Limitations: this function may not work for very large integer values
@@ -86,6 +90,43 @@ namespace ngraph
}
}
/// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that the elements equal to 1
///       do not have to lie on the main diagonal: the shift adds an offset that moves
///       the diagonal up or down.
///
/// \param[in] output_shape Shape of the resulting matrix.
/// \param[in] output_type Element type of the resulting matrix.
/// \param[in] shift Offset of the diagonal (a positive value shifts it above the main diagonal).
///
/// \return A Constant node representing the shifted identity matrix.
template <typename T = double>
std::shared_ptr<ngraph::op::Constant>
shifted_square_identity(const Shape output_shape,
const element::Type& output_type,
const std::int64_t shift)
{
std::vector<T> identity_matrix(shape_size(output_shape), T{0});
std::int64_t rows = output_shape[0];
std::int64_t cols = output_shape[1];
for (std::int64_t row = 0; row < rows; ++row)
{
const std::int64_t diagonal_element_idx = (row * cols) + row + shift;
if (row + shift < 0)
{
continue;
}
else if (row + shift >= cols)
{
break;
}
identity_matrix.at(diagonal_element_idx) = T{1};
}
return std::make_shared<ngraph::op::Constant>(
output_type, output_shape, identity_matrix);
}
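A minimal usage sketch of the helper above (illustrative only; it assumes the headers already pulled in by common.hpp and op::Constant::get_vector<T>(), and is not part of the change itself):

// Sketch: a 3x3 identity shifted one position below the main diagonal.
auto eye = common::shifted_square_identity(Shape{3, 3}, element::f32, -1);
std::vector<float> values = eye->get_vector<float>();
// values, read row by row:
// 0 0 0
// 1 0 0
// 0 1 0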
/// \brief Creates a square identity matrix. /// \brief Creates a square identity matrix.
/// ///
/// \param[in] n Order of the resulting matrix. /// \param[in] n Order of the resulting matrix.
...@@ -95,16 +136,9 @@ namespace ngraph ...@@ -95,16 +136,9 @@ namespace ngraph
std::shared_ptr<ngraph::op::Constant> square_identity(const size_t n, std::shared_ptr<ngraph::op::Constant> square_identity(const size_t n,
const element::Type& type) const element::Type& type)
{ {
std::vector<T> identity_matrix(n * n, T{0}); return shifted_square_identity(Shape{n, n}, type, 0);
for (size_t row = 0; row < n; ++row)
{
const size_t diagonal_element = (n * row) + row;
identity_matrix.at(diagonal_element) = T{1};
} }
return std::make_shared<ngraph::op::Constant>(type, Shape{{n, n}}, identity_matrix);
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph
@@ -216,6 +216,7 @@ namespace ngraph
virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; }
virtual bool is_dynamic() const;
virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
virtual std::ostream& write_short_description(std::ostream&) const;
......
@@ -72,6 +72,8 @@ namespace ngraph
/// \brief Returns the seed value supplied to a random generator
uint64_t get_seed() const { return m_seed; }
bool get_use_seed() const { return m_use_seed; }
/// GenerateMask has state.
bool has_state() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override
......
...@@ -24,8 +24,8 @@ NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op) ...@@ -24,8 +24,8 @@ NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op) NGRAPH_OP(DepthToSpace, ngraph::op)
NGRAPH_OP(Elu, ngraph::op) NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op) NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op) NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op) NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op) NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op) NGRAPH_OP(HardSigmoid, ngraph::op)
...@@ -35,9 +35,9 @@ NGRAPH_OP(MVN, ngraph::op) ...@@ -35,9 +35,9 @@ NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Normalize, ngraph::op) NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op) NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op) NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op) NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op) NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op) NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op) NGRAPH_OP(Unsqueeze, ngraph::op)
...@@ -81,11 +81,12 @@ NGRAPH_OP(Cos, ngraph::op) ...@@ -81,11 +81,12 @@ NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op) NGRAPH_OP(Cosh, ngraph::op)
NGRAPH_OP(Dequantize, ngraph::op) NGRAPH_OP(Dequantize, ngraph::op)
NGRAPH_OP(Divide, ngraph::op) NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP(DynBroadcast, ngraph::op)
NGRAPH_OP(Dot, ngraph::op) NGRAPH_OP(Dot, ngraph::op)
NGRAPH_OP(DynBroadcast, ngraph::op)
NGRAPH_OP(DynPad, ngraph::op) NGRAPH_OP(DynPad, ngraph::op)
NGRAPH_OP(DynReshape, ngraph::op) NGRAPH_OP(DynReshape, ngraph::op)
NGRAPH_OP(DynSlice, ngraph::op) NGRAPH_OP(DynSlice, ngraph::op)
NGRAPH_OP(EmbeddingLookup, ngraph::op)
NGRAPH_OP(Equal, ngraph::op) NGRAPH_OP(Equal, ngraph::op)
NGRAPH_OP(Erf, ngraph::op) NGRAPH_OP(Erf, ngraph::op)
NGRAPH_OP(Exp, ngraph::op) NGRAPH_OP(Exp, ngraph::op)
...@@ -119,13 +120,13 @@ NGRAPH_OP(Power, ngraph::op) ...@@ -119,13 +120,13 @@ NGRAPH_OP(Power, ngraph::op)
NGRAPH_OP(Product, ngraph::op) NGRAPH_OP(Product, ngraph::op)
NGRAPH_OP(Quantize, ngraph::op) NGRAPH_OP(Quantize, ngraph::op)
NGRAPH_OP(QuantizedAvgPool, ngraph::op) NGRAPH_OP(QuantizedAvgPool, ngraph::op)
NGRAPH_OP(QuantizedConvolution, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBias, ngraph::op) NGRAPH_OP(QuantizedConvolutionBias, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBiasAdd, ngraph::op) NGRAPH_OP(QuantizedConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op) NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op) NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedConvolution, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op) NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(QuantizedMaxPool, ngraph::op) NGRAPH_OP(QuantizedMaxPool, ngraph::op)
NGRAPH_OP(Range, ngraph::op) NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Relu, ngraph::op) NGRAPH_OP(Relu, ngraph::op)
...@@ -153,7 +154,6 @@ NGRAPH_OP(Subtract, ngraph::op) ...@@ -153,7 +154,6 @@ NGRAPH_OP(Subtract, ngraph::op)
NGRAPH_OP(Sum, ngraph::op) NGRAPH_OP(Sum, ngraph::op)
NGRAPH_OP(Tan, ngraph::op) NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op) NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(TopK, ngraph::op)
NGRAPH_OP(Tile, ngraph::op) NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op)
NGRAPH_OP(Transpose, ngraph::op) NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(EmbeddingLookup, ngraph::op)
...@@ -13,36 +13,51 @@ ...@@ -13,36 +13,51 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "ngraph/pass/fused_op_decomposition.hpp" #include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/fused_op.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Node> node) pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
: m_has_direct_support{callback}
{
}
bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
{ {
bool modified = false; bool modified = false;
if (auto fused_op = std::dynamic_pointer_cast<ngraph::op::util::FusedOp>(node)) if (auto fused_op = dynamic_pointer_cast<op::util::FusedOp>(node))
{ {
if (m_callback && m_callback(*node)) if (m_has_direct_support && m_has_direct_support(*node))
{ {
// Op supported by backend. Do not decompose // Op supported by backend. Do not decompose
return modified; return modified;
} }
auto subgraph_outputs = fused_op->decompose_op(); auto subgraph_outputs = fused_op->decompose_op();
// Run recursively until no more fused ops remain
auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments());
for (auto subgraph_node : subgraph)
{
if (auto nested_fused_op = dynamic_pointer_cast<op::util::FusedOp>(subgraph_node))
{
if (!(m_has_direct_support && m_has_direct_support(*nested_fused_op)))
{
run_on_node(nested_fused_op);
}
}
}
size_t i = 0; size_t i = 0;
for (auto output_node : subgraph_outputs) for (auto output_node : subgraph_outputs)
{ {
for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++) for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
{ {
// TODO: Provenance // TODO: Provenance
std::set<ngraph::descriptor::Input*> fop_users{ set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
begin(fused_op->get_outputs().at(i).get_inputs()),
end(fused_op->get_outputs().at(i).get_inputs())}; end(fused_op->get_outputs().at(i).get_inputs())};
for (auto fop_user : fop_users) for (auto fop_user : fop_users)
{ {
...@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod ...@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
if (goe->get_n() == i && !goe->get_output_inputs(0).empty()) if (goe->get_n() == i && !goe->get_output_inputs(0).empty())
{ {
// Replace GOE users // Replace GOE users
std::set<ngraph::descriptor::Input*> goe_users{ set<descriptor::Input*> goe_users{
begin(goe->get_outputs().at(0).get_inputs()), begin(goe->get_outputs().at(0).get_inputs()),
end(goe->get_outputs().at(0).get_inputs())}; end(goe->get_outputs().at(0).get_inputs())};
for (auto goe_user : goe_users) for (auto goe_user : goe_users)
...@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod ...@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
return modified; return modified;
} }
pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
: m_callback{callback}
{
}
...@@ -16,6 +16,9 @@ ...@@ -16,6 +16,9 @@
#pragma once #pragma once
#include <memory>
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +28,24 @@ namespace ngraph ...@@ -25,13 +28,24 @@ namespace ngraph
class FusedOpDecomposition : public NodePass class FusedOpDecomposition : public NodePass
{ {
public: public:
/// \brief Function signature type for the callback used to check whether the provided
///        node is supported by the backend.
using op_query_t = std::function<bool(const Node& node)>;
///
/// \brief Constructor of the fused operation decomposition pass.
///
/// \param[in] callback The function object used to determine whether the current backend
///                     provides direct support for the passed node. Should have the
///                     signature: bool fn(const Node&)
///
FusedOpDecomposition(op_query_t callback = nullptr);
bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
private:
op_query_t m_callback = nullptr; /// \brief A function returning whether provided Node is supported by current backend.
/// The returned bool value controls whether or not the operator is decomposed.
op_query_t m_has_direct_support = nullptr;
}; };
} }
} }
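A hedged usage sketch of the pass (illustrative only; pass::Manager is the usual way nGraph passes are run, and backend_supports_directly is a hypothetical backend query, not an API introduced by this change):

ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
    [](const ngraph::Node& node) -> bool {
        // Returning true means "the backend executes this fused op directly, do not decompose it".
        return backend_supports_directly(node);
    });
pass_manager.run_passes(function); // 'function' is a std::shared_ptr<ngraph::Function>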
...@@ -924,7 +924,8 @@ using namespace ngraph::runtime; ...@@ -924,7 +924,8 @@ using namespace ngraph::runtime;
// Always enable nodes computing output tensors or nodes whose outputs might get // Always enable nodes computing output tensors or nodes whose outputs might get
// overwritten due to inplace kernels // overwritten due to inplace kernels
// TODO (jbobba) - Do we need to handle cacheability // TODO (jbobba) - Do we need to handle cacheability
if (computes_result(node.get()) || possibly_overwritten(node.get())) if (computes_result(node.get()) || possibly_overwritten(node.get()) ||
node->has_state())
{ {
writer << " || 1"; writer << " || 1";
} }
...@@ -1187,7 +1188,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes( ...@@ -1187,7 +1188,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
{ {
REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass); REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
} }
REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass); REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass); REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUCollapseDims, true, runtime::cpu::pass); REGISTER_KNOBBED_PASS(CPUCollapseDims, true, runtime::cpu::pass);
...@@ -1437,7 +1437,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co ...@@ -1437,7 +1437,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
bool disable_caching = bool disable_caching =
(reuse_memory && (reuse_memory &&
!cacheable) // Check cacheability only if we are reusing intermediate tensors !cacheable) // Check cacheability only if we are reusing intermediate tensors
|| computes_result(node.get()) || possibly_overwritten(node.get()); || computes_result(node.get()) || possibly_overwritten(node.get()) || node->has_state();
vector<reference_wrapper<bool>> in_stale, out_stale; vector<reference_wrapper<bool>> in_stale, out_stale;
for (const auto& name : in_names) for (const auto& name : in_names)
......
...@@ -192,10 +192,15 @@ static OP_TYPEID get_typeid(const string& s) ...@@ -192,10 +192,15 @@ static OP_TYPEID get_typeid(const string& s)
return rc; return rc;
} }
bool has_key(json j, const std::string& key)
{
return j.count(key) != 0;
}
template <typename T> template <typename T>
T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value) T get_or_default(json j, const std::string& key, const T& default_value)
{ {
return j.count(key) != 0 ? j.at(key).get<T>() : default_value; return has_key(j, key) ? j.at(key).get<T>() : default_value;
} }
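A small illustration of the two helpers above on a hand-written JSON value (illustrative only; json is the nlohmann::json alias used throughout this file):

json node_js = json::parse(R"({"name": "n0", "pythondiv": false})");
bool pythondiv = get_or_default<bool>(node_js, "pythondiv", true);  // false: key present
bool ceil_mode = get_or_default<bool>(node_js, "ceil_mode", false); // false: key absent, default used
bool has_autob = has_key(node_js, "autob");                         // false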
class JSONSerializer class JSONSerializer
...@@ -214,8 +219,11 @@ public: ...@@ -214,8 +219,11 @@ public:
json serialize_function(const Function& function); json serialize_function(const Function& function);
json serialize_output(const Output<Node>& output); json serialize_output(const Output<Node>& output);
json serialize_parameter_vector(const ParameterVector& parameters);
json serialize_output_vector(const OutputVector& output_vector);
json serialize_node_reference(const Node& node); json serialize_node_reference(const Node& node);
json serialize_node(const Node& node); json serialize_node(const Node& node);
json serialize_axis_set(const AxisSet& axis_set);
protected: protected:
size_t m_indent{0}; size_t m_indent{0};
...@@ -234,10 +242,13 @@ public: ...@@ -234,10 +242,13 @@ public:
m_const_data_callback = const_data_callback; m_const_data_callback = const_data_callback;
} }
shared_ptr<Function> deserialize_function(json& j); shared_ptr<Function> deserialize_function(json j);
Output<Node> deserialize_output(json& j); Output<Node> deserialize_output(json j);
shared_ptr<Node> deserialize_node_reference(json& j); OutputVector deserialize_output_vector(json j);
shared_ptr<Node> deserialize_node(json& j); ParameterVector deserialize_parameter_vector(json j);
shared_ptr<Node> deserialize_node_reference(json j);
shared_ptr<Node> deserialize_node(json j);
AxisSet deserialize_axis_set(json j);
protected: protected:
unordered_map<string, shared_ptr<Node>> m_node_map; unordered_map<string, shared_ptr<Node>> m_node_map;
...@@ -260,7 +271,7 @@ static json write_dimension(Dimension d) ...@@ -260,7 +271,7 @@ static json write_dimension(Dimension d)
} }
} }
static Dimension read_dimension(const json& j) static Dimension read_dimension(json j)
{ {
if (j.is_null()) if (j.is_null())
{ {
...@@ -289,7 +300,7 @@ static json write_partial_shape(const PartialShape& s) ...@@ -289,7 +300,7 @@ static json write_partial_shape(const PartialShape& s)
} }
} }
static PartialShape read_partial_shape(const json& j) static PartialShape read_partial_shape(json j)
{ {
if (j.is_null()) if (j.is_null())
{ {
...@@ -314,19 +325,32 @@ static json write_auto_broadcast(const op::AutoBroadcastSpec& autob) ...@@ -314,19 +325,32 @@ static json write_auto_broadcast(const op::AutoBroadcastSpec& autob)
return j; return j;
} }
static op::AutoBroadcastSpec read_auto_broadcast(const json& j) static op::AutoBroadcastSpec read_auto_broadcast(json js_node, const std::string& attr)
{ {
if (!j.is_object()) if (has_key(js_node, attr))
{ {
return op::AutoBroadcastSpec(); json j = js_node[attr];
return op::AutoBroadcastSpec(static_cast<op::AutoBroadcastType>(j.at("type")),
j.at("axis").get<size_t>());
} }
else else
{ {
return op::AutoBroadcastSpec(static_cast<op::AutoBroadcastType>(j.at("type")), return op::AutoBroadcastSpec();
j.at("axis").get<size_t>());
} }
} }
static op::PadType read_pad_type(json node_js)
{
return has_key(node_js, "pad_type") ? static_cast<op::PadType>(node_js.at("pad_type"))
: op::PadType::EXPLICIT;
}
static op::PadMode read_pad_mode(json node_js)
{
return has_key(node_js, "pad_mode") ? static_cast<op::PadMode>(node_js.at("pad_mode"))
: op::PadMode::CONSTANT;
}
static json write_element_type(const ngraph::element::Type& n) static json write_element_type(const ngraph::element::Type& n)
{ {
json j; json j;
...@@ -334,7 +358,7 @@ static json write_element_type(const ngraph::element::Type& n) ...@@ -334,7 +358,7 @@ static json write_element_type(const ngraph::element::Type& n)
return j; return j;
} }
static element::Type read_element_type(const json& j) static element::Type read_element_type(json j)
{ {
size_t bitwidth = 0; size_t bitwidth = 0;
bool is_real = false; bool is_real = false;
...@@ -494,21 +518,24 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s) ...@@ -494,21 +518,24 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
rc = deserializer.deserialize_function(func); rc = deserializer.deserialize_function(func);
} }
} }
return rc; return rc;
} }
json JSONSerializer::serialize_parameter_vector(const ParameterVector& parameters)
{
json json_parameters = json::array();
for (auto param : parameters)
{
json_parameters.push_back(serialize_node_reference(*param));
}
return json_parameters;
}
json JSONSerializer::serialize_function(const Function& f) json JSONSerializer::serialize_function(const Function& f)
{ {
json function; json function;
function["name"] = f.get_name(); function["name"] = f.get_name();
function["parameters"] = serialize_parameter_vector(f.get_parameters());
vector<string> parameter_list;
for (auto param : f.get_parameters())
{
parameter_list.push_back(serialize_node_reference(*param));
}
function["parameters"] = parameter_list;
// TODO Functions can return multiple results // TODO Functions can return multiple results
for (size_t i = 0; i < f.get_output_size(); ++i) for (size_t i = 0; i < f.get_output_size(); ++i)
...@@ -520,7 +547,7 @@ json JSONSerializer::serialize_function(const Function& f) ...@@ -520,7 +547,7 @@ json JSONSerializer::serialize_function(const Function& f)
} }
template <typename T> template <typename T>
T get_value(nlohmann::json js, const string& key) T get_value(json js, const string& key)
{ {
T rc; T rc;
auto it = js.find(key); auto it = js.find(key);
...@@ -531,13 +558,13 @@ T get_value(nlohmann::json js, const string& key) ...@@ -531,13 +558,13 @@ T get_value(nlohmann::json js, const string& key)
return rc; return rc;
} }
shared_ptr<Node> JSONDeserializer::deserialize_node_reference(json& j) shared_ptr<Node> JSONDeserializer::deserialize_node_reference(json j)
{ {
const string& name = j; const string& name = j;
return m_node_map.at(name); return m_node_map.at(name);
} }
Output<Node> JSONDeserializer::deserialize_output(json& j) Output<Node> JSONDeserializer::deserialize_output(json j)
{ {
size_t index; size_t index;
json json_node_reference; json json_node_reference;
...@@ -558,10 +585,48 @@ Output<Node> JSONDeserializer::deserialize_output(json& j) ...@@ -558,10 +585,48 @@ Output<Node> JSONDeserializer::deserialize_output(json& j)
return Output<Node>(deserialize_node_reference(json_node_reference), index); return Output<Node>(deserialize_node_reference(json_node_reference), index);
} }
shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js) OutputVector JSONDeserializer::deserialize_output_vector(json j)
{
OutputVector result;
if (j.is_array())
{
for (json jelt : j)
{
result.push_back(deserialize_output(jelt));
}
}
return result;
}
json JSONSerializer::serialize_axis_set(const AxisSet& axis_set)
{
return static_cast<set<size_t>>(axis_set);
}
AxisSet JSONDeserializer::deserialize_axis_set(json j)
{
AxisSet result;
if (j.is_array())
{
result = j.get<set<size_t>>();
}
return result;
}
ParameterVector JSONDeserializer::deserialize_parameter_vector(json json_parameters)
{
std::vector<std::shared_ptr<op::Parameter>> params;
for (auto& param_ref : json_parameters)
{
params.push_back(
dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
}
return params;
}
shared_ptr<Function> JSONDeserializer::deserialize_function(json func_js)
{ {
string func_name = func_js.at("name").get<string>(); string func_name = func_js.at("name").get<string>();
vector<json> func_parameters = func_js.at("parameters");
vector<json> func_result = func_js.at("result"); vector<json> func_result = func_js.at("result");
for (json node_js : func_js.at("ops")) for (json node_js : func_js.at("ops"))
{ {
...@@ -593,12 +658,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js) ...@@ -593,12 +658,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
"Graph serialization is inconsistent. Some op::Results appear to be missing"); "Graph serialization is inconsistent. Some op::Results appear to be missing");
} }
std::vector<std::shared_ptr<op::Parameter>> params; ParameterVector params = deserialize_parameter_vector(func_js.at("parameters"));
for (auto& param_ref : func_parameters)
{
params.push_back(
dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
}
shared_ptr<Function> rc{make_shared<Function>(result, params, func_name)}; shared_ptr<Function> rc{make_shared<Function>(result, params, func_name)};
m_function_map[func_name] = rc; m_function_map[func_name] = rc;
...@@ -631,7 +691,12 @@ struct OutputHelper ...@@ -631,7 +691,12 @@ struct OutputHelper
// when all op constructors use the new style arguments. // when all op constructors use the new style arguments.
struct OutputVectorHelper struct OutputVectorHelper
{ {
const OutputHelper& operator[](size_t i) const { return m_vector[i]; } OutputVectorHelper(const OutputVector& output_vector)
: m_vector(output_vector)
{
}
OutputVectorHelper() = default;
OutputHelper operator[](size_t i) const { return OutputHelper(m_vector[i]); }
void push_back(const Output<Node>& output) { m_vector.push_back(output); } void push_back(const Output<Node>& output) { m_vector.push_back(output); }
size_t size() const { return m_vector.size(); } size_t size() const { return m_vector.size(); }
operator vector<shared_ptr<Node>>() const operator vector<shared_ptr<Node>>() const
...@@ -639,14 +704,15 @@ struct OutputVectorHelper ...@@ -639,14 +704,15 @@ struct OutputVectorHelper
vector<shared_ptr<Node>> result; vector<shared_ptr<Node>> result;
for (auto& o : m_vector) for (auto& o : m_vector)
{ {
result.push_back(o); result.push_back(OutputHelper(o));
} }
return result; return result;
} }
vector<OutputHelper> m_vector; operator const OutputVector&() const { return m_vector; }
OutputVector m_vector;
}; };
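The helper above exists so that the same deserialized inputs can feed both old-style (shared_ptr<Node>) and new-style (Output<Node>/OutputVector) op constructors; a short sketch of the two call patterns used later in this file:

OutputVectorHelper args(deserialize_output_vector(node_js["inputs"]));
auto concat = make_shared<op::Concat>(static_cast<OutputVector>(args), 0); // new style: OutputVector
auto add = make_shared<op::Add>(args[0], args[1]);                         // old style: via OutputHelper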
shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
{ {
shared_ptr<Node> node; shared_ptr<Node> node;
try try
...@@ -654,14 +720,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -654,14 +720,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
string node_name = node_js.at("name").get<string>(); string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>(); string node_op = node_js.at("op").get<string>();
string friendly_name = get_value<string>(node_js, "friendly_name"); string friendly_name = get_value<string>(node_js, "friendly_name");
vector<json> node_inputs = get_value<vector<json>>(node_js, "inputs");
vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps"); vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps");
vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs"); vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
OutputVectorHelper args; OutputVectorHelper args(deserialize_output_vector(node_js["inputs"]));
for (auto& node_input : node_inputs)
{
args.push_back(deserialize_output(node_input));
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8) #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch" #pragma GCC diagnostic error "-Wswitch"
...@@ -682,12 +743,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -682,12 +743,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
} }
case OP_TYPEID::Add: case OP_TYPEID::Add:
{ {
node = make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js["autob"])); node = make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break; break;
} }
case OP_TYPEID::All: case OP_TYPEID::All:
{ {
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>(); auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
node = make_shared<op::All>(args[0], reduction_axes); node = make_shared<op::All>(args[0], reduction_axes);
break; break;
} }
...@@ -698,12 +759,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -698,12 +759,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
} }
case OP_TYPEID::And: case OP_TYPEID::And:
{ {
node = make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js["autob"])); node = make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break; break;
} }
case OP_TYPEID::Any: case OP_TYPEID::Any:
{ {
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>(); auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
node = make_shared<op::Any>(args[0], reduction_axes); node = make_shared<op::Any>(args[0], reduction_axes);
break; break;
} }
...@@ -740,12 +801,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -740,12 +801,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto padding_above = node_js.at("padding_above").get<vector<size_t>>(); auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation = auto include_padding_in_avg_computation =
node_js.at("include_padding_in_avg_computation").get<bool>(); node_js.at("include_padding_in_avg_computation").get<bool>();
op::PadType pad_type = node_js["pad_type"].empty() op::PadType pad_type = read_pad_type(node_js);
? op::PadType::EXPLICIT bool ceil_mode = get_or_default<bool>(node_js, "ceil_mode", false);
: static_cast<op::PadType>(node_js.at("pad_type"));
bool ceil_mode =
node_js["ceil_mode"].empty() ? false : node_js.at("ceil_mode").get<bool>();
;
node = make_shared<op::AvgPool>(args[0], node = make_shared<op::AvgPool>(args[0],
window_shape, window_shape,
window_movement_strides, window_movement_strides,
...@@ -807,7 +864,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -807,7 +864,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Broadcast: case OP_TYPEID::Broadcast:
{ {
auto shape = node_js.at("shape").get<vector<size_t>>(); auto shape = node_js.at("shape").get<vector<size_t>>();
auto axes = node_js.at("axes").get<set<size_t>>(); auto axes = deserialize_axis_set(node_js.at("axes"));
node = make_shared<op::Broadcast>(args[0], shape, axes); node = make_shared<op::Broadcast>(args[0], shape, axes);
break; break;
} }
...@@ -818,7 +875,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -818,7 +875,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
} }
case OP_TYPEID::BroadcastLike: case OP_TYPEID::BroadcastLike:
{ {
auto initial_axes = node_js.at("initial_axes").get<set<size_t>>(); auto initial_axes = deserialize_axis_set(node_js.at("initial_axes"));
node = make_shared<op::BroadcastLike>(args[0], args[1], initial_axes); node = make_shared<op::BroadcastLike>(args[0], args[1], initial_axes);
break; break;
} }
...@@ -837,13 +894,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -837,13 +894,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Concat: case OP_TYPEID::Concat:
{ {
auto axis = node_js.at("axis").get<size_t>(); auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Concat>(args, axis); node = make_shared<op::Concat>(static_cast<OutputVector>(args), axis);
break; break;
} }
case OP_TYPEID::Constant: case OP_TYPEID::Constant:
{ {
auto type_node_js = auto type_node_js =
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js; has_key(node_js, "element_type") ? node_js : node_js.at("value_type");
auto element_type = read_element_type(type_node_js.at("element_type")); auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape"); auto shape = type_node_js.at("shape");
auto value = node_js.at("value").get<vector<string>>(); auto value = node_js.at("value").get<vector<string>>();
...@@ -867,17 +924,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -867,17 +924,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
// For backwards compatibility, we accept "image_dilation_strides" in place of // For backwards compatibility, we accept "image_dilation_strides" in place of
// "data_dilation_strides", and we also allow it to be omitted altogether. // "data_dilation_strides", and we also allow it to be omitted altogether.
auto data_dilation_strides_maybe = node_js["data_dilation_strides"]; json data_dilation_strides;
if (data_dilation_strides_maybe.empty()) if (has_key(node_js, "data_dilation_strides"))
{
data_dilation_strides = node_js["data_dilation_strides"];
}
else if (has_key(node_js, "image_dilation_strides"))
{ {
data_dilation_strides_maybe = node_js["image_dilation_strides"]; data_dilation_strides = node_js["image_dilation_strides"];
} }
op::PadType pad_type = node_js["pad_type"].empty() op::PadType pad_type = read_pad_type(node_js);
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
if (data_dilation_strides_maybe.empty()) if (data_dilation_strides.empty())
{ {
node = make_shared<op::Convolution>(args[0], node = make_shared<op::Convolution>(args[0],
args[1], args[1],
...@@ -888,14 +947,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -888,14 +947,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
} }
else else
{ {
node = make_shared<op::Convolution>( node =
args[0], make_shared<op::Convolution>(args[0],
args[1], args[1],
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides_maybe.get<std::vector<size_t>>(), data_dilation_strides.get<std::vector<size_t>>(),
pad_type); pad_type);
} }
break; break;
...@@ -1032,33 +1091,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1032,33 +1091,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Dequantize: case OP_TYPEID::Dequantize:
{ {
auto type = read_element_type(node_js.at("type")); auto type = read_element_type(node_js.at("type"));
auto axes = node_js.at("axes").get<set<size_t>>(); auto axes = deserialize_axis_set(node_js.at("axes"));
node = make_shared<op::Dequantize>(args[0], args[1], args[2], type, axes); node = make_shared<op::Dequantize>(args[0], args[1], args[2], type, axes);
break; break;
} }
case OP_TYPEID::Divide: case OP_TYPEID::Divide:
{ {
bool pythondiv = true; bool pythondiv = get_or_default(node_js, "pythondiv", true);
if (node_js["pythondiv"].is_object())
{
pythondiv = node_js.at("pythondiv").get<bool>();
}
node = make_shared<op::Divide>( node = make_shared<op::Divide>(
args[0], args[1], pythondiv, read_auto_broadcast(node_js["autob"])); args[0], args[1], pythondiv, read_auto_broadcast(node_js, "autob"));
break; break;
} }
case OP_TYPEID::Dot: case OP_TYPEID::Dot:
{ {
// For backwards compatibility, reduction_axes_count is optional. // For backwards compatibility, reduction_axes_count is optional.
auto obj = node_js["reduction_axes_count"]; if (has_key(node_js, "reduction_axes_count"))
if (obj.empty())
{ {
node = make_shared<op::Dot>(args[0], args[1]); size_t reduction_axes_count = node_js["reduction_axes_count"].get<size_t>();
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
} }
else else
{ {
size_t reduction_axes_count = obj.get<size_t>(); node = make_shared<op::Dot>(args[0], args[1]);
node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
} }
break; break;
} }
...@@ -1094,7 +1148,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1094,7 +1148,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
} }
case OP_TYPEID::Equal: case OP_TYPEID::Equal:
{ {
node = make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js["autob"])); node = make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break; break;
} }
case OP_TYPEID::Erf: case OP_TYPEID::Erf:
...@@ -1159,13 +1213,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1159,13 +1213,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::Greater: case OP_TYPEID::Greater:
{ {
node = node =
make_shared<op::Greater>(args[0], args[1], read_auto_broadcast(node_js["autob"])); make_shared<op::Greater>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break; break;
} }
case OP_TYPEID::GreaterEq: case OP_TYPEID::GreaterEq:
{ {
node = node =
make_shared<op::GreaterEq>(args[0], args[1], read_auto_broadcast(node_js["autob"])); make_shared<op::GreaterEq>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break; break;
} }
case OP_TYPEID::GRN: case OP_TYPEID::GRN:
...@@ -1192,10 +1246,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1192,10 +1246,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>(); auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
auto groups = node_js.at("groups").get<size_t>(); auto groups = node_js.at("groups").get<size_t>();
op::PadType pad_type = node_js["pad_type"].empty() op::PadType pad_type = read_pad_type(node_js);
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
node = make_shared<op::GroupConvolution>(args[0], node = make_shared<op::GroupConvolution>(args[0],
args[1], args[1],
window_movement_strides, window_movement_strides,
...@@ -1215,9 +1266,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1215,9 +1266,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto padding_end = node_js.at("padding_end").get<vector<ptrdiff_t>>(); auto padding_end = node_js.at("padding_end").get<vector<ptrdiff_t>>();
auto output_padding = node_js.at("output_padding").get<vector<ptrdiff_t>>(); auto output_padding = node_js.at("output_padding").get<vector<ptrdiff_t>>();
auto groups = node_js.at("groups").get<size_t>(); auto groups = node_js.at("groups").get<size_t>();
op::PadType pad_type = node_js["pad_type"].empty() op::PadType pad_type = read_pad_type(node_js);
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
auto output_shape = node_js.at("output_shape").get<vector<size_t>>(); auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
node = make_shared<op::GroupConvolutionTranspose>(args[0], node = make_shared<op::GroupConvolutionTranspose>(args[0],
...@@ -1239,12 +1288,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1239,12 +1288,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
} }
case OP_TYPEID::Less: case OP_TYPEID::Less:
{ {
node = make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js["autob"])); node = make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break; break;
} }
case OP_TYPEID::LessEq: case OP_TYPEID::LessEq:
{ {
node = make_shared<op::LessEq>(args[0], args[1], read_auto_broadcast(node_js["autob"])); node = make_shared<op::LessEq>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break; break;
} }
case OP_TYPEID::Log: case OP_TYPEID::Log:
...@@ -1286,7 +1335,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1286,7 +1335,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
} }
case OP_TYPEID::Max: case OP_TYPEID::Max:
{ {
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>(); auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
node = make_shared<op::Max>(args[0], reduction_axes); node = make_shared<op::Max>(args[0], reduction_axes);
break; break;
} }
...@@ -1297,11 +1346,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1297,11 +1346,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
node_js.at("window_movement_strides").get<vector<size_t>>(); node_js.at("window_movement_strides").get<vector<size_t>>();
// For backwards compatibility, both (but not just one) of the padding_ fields may be // For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted. // omitted.
auto padding_below_maybe = node_js["padding_below"]; auto padding_below_maybe = get_or_default(node_js, "padding_below", json{});
auto padding_above_maybe = node_js["padding_above"]; auto padding_above_maybe = get_or_default(node_js, "padding_above", json{});
op::PadType pad_type = node_js["pad_type"].empty() op::PadType pad_type = read_pad_type(node_js);
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
if (padding_below_maybe.empty() && !padding_above_maybe.empty()) if (padding_below_maybe.empty() && !padding_above_maybe.empty())
{ {
throw runtime_error( throw runtime_error(
@@ -1360,31 +1407,31 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
    case OP_TYPEID::Maximum:
    {
        node =
-           make_shared<op::Maximum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+           make_shared<op::Maximum>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
        break;
    }
    case OP_TYPEID::Min:
    {
-       auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+       auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
        node = make_shared<op::Min>(args[0], reduction_axes);
        break;
    }
    case OP_TYPEID::Minimum:
    {
        node =
-           make_shared<op::Minimum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+           make_shared<op::Minimum>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
        break;
    }
    case OP_TYPEID::Multiply:
    {
        node =
-           make_shared<op::Multiply>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+           make_shared<op::Multiply>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
        break;
    }
    case OP_TYPEID::MVN:
    {
        auto normalize_variance = node_js.at("normalize_variance").get<bool>();
-       auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+       auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
        auto eps = node_js.at("eps").get<double>();
        node = make_shared<op::MVN>(args[0], normalize_variance, normalize_variance, eps);
        break;
@@ -1406,7 +1453,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
    case OP_TYPEID::NotEqual:
    {
        node =
-           make_shared<op::NotEqual>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+           make_shared<op::NotEqual>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
        break;
    }
    case OP_TYPEID::Not:
@@ -1423,7 +1470,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
    }
    case OP_TYPEID::Or:
    {
-       node = make_shared<op::Or>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+       node = make_shared<op::Or>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
        break;
    }
    case OP_TYPEID::Pad:
@@ -1440,9 +1487,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
                        [](size_t s) { return s == 0; }),
            "Legacy padding_interior field must be zero everywhere.");
-       auto pad_mode = node_js.count("pad_mode") == 0
-                           ? op::PadMode::CONSTANT
-                           : static_cast<op::PadMode>(node_js.at("pad_mode"));
+       auto pad_mode = read_pad_mode(node_js);
        node = make_shared<op::Pad>(args[0], args[1], padding_below, padding_above, pad_mode);
        break;
@@ -1450,7 +1495,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
    case OP_TYPEID::Parameter:
    {
        auto type_node_js =
-           node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
+           has_key(node_js, "element_type") ? node_js : node_js.at("value_type");
        auto element_type = read_element_type(type_node_js.at("element_type"));
        auto shape = type_node_js.at("shape");
        auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
@@ -1475,7 +1520,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
    }
    case OP_TYPEID::Power:
    {
-       node = make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+       node = make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
        break;
    }
    case OP_TYPEID::PRelu:
@@ -1485,14 +1530,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
    }
    case OP_TYPEID::Product:
    {
-       auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+       auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
        node = make_shared<op::Product>(args[0], reduction_axes);
        break;
    }
    case OP_TYPEID::Quantize:
    {
        auto type = read_element_type(node_js.at("type"));
-       auto axes = node_js.at("axes").get<set<size_t>>();
+       auto axes = deserialize_axis_set(node_js.at("axes"));
        auto round_mode = node_js.at("round_mode").get<op::Quantize::RoundMode>();
        node = make_shared<op::Quantize>(args[0], args[1], args[2], type, axes, round_mode);
        break;
@@ -1551,8 +1596,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
            node_js.at("window_movement_strides").get<vector<size_t>>();
        // For backwards compatibility, both (but not just one) of the padding_ fields may be
        // omitted.
-       auto padding_below_maybe = node_js["padding_below"];
-       auto padding_above_maybe = node_js["padding_above"];
+       auto padding_below_maybe = get_or_default(node_js, "padding_below", json{});
+       auto padding_above_maybe = get_or_default(node_js, "padding_above", json{});
        auto padding_below = padding_below_maybe.get<vector<size_t>>();
        auto padding_above = padding_above_maybe.get<vector<size_t>>();
        node = make_shared<op::QuantizedMaxPool>(
@@ -1600,7 +1645,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
    }
    case OP_TYPEID::Reverse:
    {
-       auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
+       auto reversed_axes = deserialize_axis_set(node_js.at("reversed_axes"));
        node = make_shared<op::Reverse>(args[0], reversed_axes);
        break;
    }
@@ -1684,7 +1729,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
    }
    case OP_TYPEID::Softmax:
    {
-       auto softmax_axes = node_js.at("softmax_axes").get<set<size_t>>();
+       auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
        node = make_shared<op::Softmax>(args[0], softmax_axes);
        break;
    }
@@ -1719,12 +1764,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
    case OP_TYPEID::Subtract:
    {
        node =
-           make_shared<op::Subtract>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+           make_shared<op::Subtract>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
        break;
    }
    case OP_TYPEID::Sum:
    {
-       auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+       auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
        node = make_shared<op::Sum>(args[0], reduction_axes);
        break;
    }
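The rewritten call sites above lean on a small set of JSON helpers (has_key, get_or_default, read_pad_type, read_pad_mode, read_auto_broadcast, deserialize_axis_set) whose definitions live elsewhere in serializer.cpp and are not part of these hunks. As a rough sketch of the shared pattern (look an attribute up if present, otherwise fall back to a default), the two generic helpers might look like this; the names come from the call sites, the bodies are assumptions:

#include <nlohmann/json.hpp>
#include <string>

using json = nlohmann::json;

// True when the JSON object carries the given key.
static bool has_key(const json& j, const std::string& key)
{
    return j.count(key) != 0;
}

// Returns j[key] converted to T, or default_value when the key is absent.
template <typename T>
T get_or_default(const json& j, const std::string& key, const T& default_value)
{
    return has_key(j, key) ? j.at(key).get<T>() : default_value;
}

Under that reading, get_or_default(node_js, "padding_below", json{}) yields an empty json value when the legacy field is missing, which the .empty() checks in the MaxPool and QuantizedMaxPool cases then detect.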
@@ -1860,6 +1905,16 @@ json JSONSerializer::serialize_output(const Output<Node>& output)
    return result;
}

+json JSONSerializer::serialize_output_vector(const OutputVector& output_vector)
+{
+    json result;
+    for (const Output<Node>& output : output_vector)
+    {
+        result.push_back(serialize_output(output));
+    }
+    return result;
+}
+
json JSONSerializer::serialize_node(const Node& n)
{
    m_nodes_serialized.insert(&n);
@@ -1959,7 +2014,7 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::All:
    {
        auto tmp = dynamic_cast<const op::All*>(&n);
-       node["reduction_axes"] = tmp->get_reduction_axes();
+       node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
        break;
    }
    case OP_TYPEID::AllReduce: { break;
@@ -1976,7 +2031,7 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::Any:
    {
        auto tmp = dynamic_cast<const op::Any*>(&n);
-       node["reduction_axes"] = tmp->get_reduction_axes();
+       node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
        break;
    }
    case OP_TYPEID::Asin: { break;
@@ -2032,7 +2087,7 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::Broadcast:
    {
        auto tmp = dynamic_cast<const op::Broadcast*>(&n);
-       node["axes"] = tmp->get_broadcast_axes();
+       node["axes"] = serialize_axis_set(tmp->get_broadcast_axes());
        node["shape"] = tmp->get_broadcast_shape();
        break;
    }
@@ -2041,7 +2096,7 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::BroadcastLike:
    {
        auto tmp = dynamic_cast<const op::BroadcastLike*>(&n);
-       node["initial_axes"] = tmp->get_initial_broadcast_axes();
+       node["initial_axes"] = serialize_axis_set(tmp->get_initial_broadcast_axes());
        break;
    }
    case OP_TYPEID::Ceiling: { break;
@@ -2155,7 +2210,7 @@ json JSONSerializer::serialize_node(const Node& n)
    {
        auto tmp = dynamic_cast<const op::Dequantize*>(&n);
        node["type"] = write_element_type(tmp->get_element_type());
-       node["axes"] = tmp->get_axes();
+       node["axes"] = serialize_axis_set(tmp->get_axes());
        break;
    }
    case OP_TYPEID::DepthToSpace:
@@ -2348,7 +2403,7 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::Max:
    {
        auto tmp = dynamic_cast<const op::Max*>(&n);
-       node["reduction_axes"] = tmp->get_reduction_axes();
+       node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
        break;
    }
    case OP_TYPEID::MaxPool:
@@ -2382,7 +2437,7 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::Min:
    {
        auto tmp = dynamic_cast<const op::Min*>(&n);
-       node["reduction_axes"] = tmp->get_reduction_axes();
+       node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
        break;
    }
    case OP_TYPEID::Minimum:
@@ -2406,7 +2461,7 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::MVN:
    {
        auto tmp = dynamic_cast<const op::MVN*>(&n);
-       node["reduction_axes"] = tmp->get_reduction_axes();
+       node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
        node["normalize_variance"] = tmp->get_normalize_variance();
        node["eps"] = tmp->get_eps();
        break;
@@ -2486,7 +2541,7 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::Product:
    {
        auto tmp = dynamic_cast<const op::Product*>(&n);
-       node["reduction_axes"] = tmp->get_reduction_axes();
+       node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
        break;
    }
    case OP_TYPEID::Power:
@@ -2502,7 +2557,7 @@ json JSONSerializer::serialize_node(const Node& n)
    {
        auto tmp = dynamic_cast<const op::Quantize*>(&n);
        node["type"] = write_element_type(tmp->get_element_type());
-       node["axes"] = tmp->get_axes();
+       node["axes"] = serialize_axis_set(tmp->get_axes());
        node["round_mode"] = tmp->get_round_mode();
        break;
    }
@@ -2577,7 +2632,7 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::Reverse:
    {
        auto tmp = dynamic_cast<const op::Reverse*>(&n);
-       node["reversed_axes"] = tmp->get_reversed_axes();
+       node["reversed_axes"] = serialize_axis_set(tmp->get_reversed_axes());
        break;
    }
    case OP_TYPEID::ReverseSequence:
@@ -2664,13 +2719,13 @@ json JSONSerializer::serialize_node(const Node& n)
    case OP_TYPEID::Sum:
    {
        auto tmp = dynamic_cast<const op::Sum*>(&n);
-       node["reduction_axes"] = tmp->get_reduction_axes();
+       node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
        break;
    }
    case OP_TYPEID::Softmax:
    {
        auto tmp = dynamic_cast<const op::Softmax*>(&n);
-       node["softmax_axes"] = tmp->get_axes();
+       node["softmax_axes"] = serialize_axis_set(tmp->get_axes());
        break;
    }
    case OP_TYPEID::Tan: { break;
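serialize_axis_set and deserialize_axis_set are the new symmetric pair used throughout these serializer hunks; their definitions sit outside this diff. A minimal, self-contained sketch of what such a pair plausibly does, using a hypothetical AxisSet alias in place of ngraph::AxisSet (which derives from std::set<size_t>):

#include <nlohmann/json.hpp>
#include <cstddef>
#include <set>
#include <vector>

using json = nlohmann::json;

// Hypothetical stand-in for ngraph::AxisSet (a set of axis indices).
using AxisSet = std::set<std::size_t>;

// Write the axis set out as a plain JSON array of indices.
static json serialize_axis_set(const AxisSet& axis_set)
{
    return json(std::vector<std::size_t>(axis_set.begin(), axis_set.end()));
}

// Rebuild the axis set from that array.
static AxisSet deserialize_axis_set(const json& j)
{
    return AxisSet(j.get<std::set<std::size_t>>());
}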
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "EyeLike"
attribute {
name: "k"
i: -1
type: INT
}
}
name: "hardmax_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 9
}
@@ -1482,3 +1482,46 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int)
    test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_eye_like)
{
const auto eye_like_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/eye_like.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(eye_like_fn, "${BACKEND_NAME}");
test_case.add_input<float>({
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
});
test_case.add_expected_output<float>(Shape{3, 4},
{
0.f,
0.f,
0.f,
0.f,
1.f,
0.f,
0.f,
0.f,
0.f,
1.f,
0.f,
0.f,
});
test_case.run();
}
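For reference, the expected output asserted above is the 3x4 EyeLike result for k = -1: ones on the diagonal shifted one row down. A standalone sketch of that computation (not the importer's implementation), useful for sanity-checking the fixture:

#include <cstddef>
#include <vector>

// Fill a rows x cols row-major buffer with an "eye-like" pattern whose ones sit on
// the diagonal offset by k (k = -1 shifts it one row down, as in the test above).
std::vector<float> eye_like(std::size_t rows, std::size_t cols, long k)
{
    std::vector<float> out(rows * cols, 0.f);
    for (std::size_t i = 0; i < rows; ++i)
    {
        long j = static_cast<long>(i) + k;
        if (j >= 0 && j < static_cast<long>(cols))
        {
            out[i * cols + static_cast<std::size_t>(j)] = 1.f;
        }
    }
    return out;
}
// eye_like(3, 4, -1) yields {0,0,0,0, 1,0,0,0, 0,1,0,0}, the expected output in the test.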