Commit fe18f19a authored by fenglei

Merge branch 'tfl/send_recv_op' of github.com:NervanaSystems/ngraph into tfl/send_recv_op

parents 8f7f2aec 72df5da0
@@ -37,6 +37,7 @@ NodeInput = Union[Node, NumericData]
 ngraph_to_numpy_types_map = [
     (NgraphType.boolean, np.bool),
+    (NgraphType.f16, np.float16),
     (NgraphType.f32, np.float32),
     (NgraphType.f64, np.float64),
     (NgraphType.i8, np.int8),
...
@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
     py::class_<ngraph::element::Type, std::shared_ptr<ngraph::element::Type>> type(m, "Type");
     type.doc() = "ngraph.impl.Type wraps ngraph::element::Type";
     type.attr("boolean") = ngraph::element::boolean;
+    type.attr("f16") = ngraph::element::f16;
     type.attr("f32") = ngraph::element::f32;
     type.attr("f64") = ngraph::element::f64;
     type.attr("i8") = ngraph::element::i8;
...
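These two hunks expose half precision end to end in the Python API: ngraph.impl.Type gains an f16 attribute, and the type table above maps it to NumPy's float16, so float16 arrays now convert in both directions.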
@@ -85,6 +85,8 @@ add_library(onnx_import STATIC
     op/equal.hpp
     op/erf.hpp
     op/exp.hpp
+    op/eye_like.cpp
+    op/eye_like.hpp
     op/flatten.cpp
     op/flatten.hpp
     op/floor.hpp
@@ -191,6 +193,7 @@ add_library(onnx_import STATIC
     op/xor.hpp
     ops_bridge.cpp
     ops_bridge.hpp
+    utils/common.cpp
     utils/common.hpp
     utils/convpool.cpp
     utils/convpool.hpp
...
@@ -52,6 +52,8 @@ namespace ngraph
             const std::string& output(int index) const;
             std::size_t get_outputs_size() const;
 
+            bool has_attribute(const std::string& name) const;
+
             template <typename T>
             T get_attribute_value(const std::string& name, T default_value) const;
@@ -87,6 +89,15 @@ namespace ngraph
         }
 
         std::size_t Node::Impl::get_outputs_size() const { return m_output_names.size(); }
+        bool Node::Impl::has_attribute(const std::string& name) const
+        {
+            auto it = std::find_if(
+                std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) {
+                    return attribute.get_name() == name;
+                });
+            return it != std::end(m_attributes);
+        }
+
         template <typename T>
         T Node::Impl::get_attribute_value(const std::string& name, T default_value) const
         {
@@ -185,6 +196,11 @@ namespace ngraph
         const std::string& Node::output(int index) const { return m_pimpl->output(index); }
         std::size_t Node::get_outputs_size() const { return m_pimpl->get_outputs_size(); }
+        bool Node::has_attribute(const std::string& name) const
+        {
+            return m_pimpl->has_attribute(name);
+        }
+
         template <>
         float Node::get_attribute_value(const std::string& name, float default_value) const
         {
...
@@ -78,6 +78,8 @@ namespace ngraph
             const std::string& output(int index) const;
             std::size_t get_outputs_size() const;
 
+            bool has_attribute(const std::string& name) const;
+
             template <typename T>
             T get_attribute_value(const std::string& name, T default_value) const;
...
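Taken together, has_attribute and the existing default-value getter give importer code a simple optional-attribute pattern. A minimal sketch of the intended usage (resolve_target_type is an illustrative helper, not part of this commit; it mirrors the flow the new eye_like importer uses below):

    // Illustrative only: prefer an explicit "dtype" attribute when present,
    // otherwise fall back to the input's element type.
    element::Type resolve_target_type(const onnx_import::Node& node,
                                      const std::shared_ptr<ngraph::Node>& input)
    {
        if (node.has_attribute("dtype"))
        {
            const auto dtype = node.get_attribute_value<std::int64_t>("dtype");
            return common::get_ngraph_element_type(dtype);
        }
        return input->get_element_type();
    }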
@@ -24,6 +24,7 @@
 #include "ngraph/type/element_type.hpp"
 #include "node.hpp"
 #include "tensor.hpp"
+#include "utils/common.hpp"
 #include "weight.hpp"
 
 namespace ngraph
@@ -41,17 +42,8 @@ namespace ngraph
                     {
                     }
                 };
-
-                struct unsupported_element_type : ngraph_error
-                {
-                    explicit unsupported_element_type(TensorProto_DataType type)
-                        : ngraph_error{"unsupported value info element type: " +
-                                       onnx::TensorProto_DataType_Name(
-                                           static_cast<onnx::TensorProto_DataType>(type))}
-                    {
-                    }
-                };
-            }
-        }
+            } // namespace value_info
+        } // namespace error
 
         class ValueInfo
         {
@@ -83,24 +75,8 @@ namespace ngraph
                 {
                     throw error::value_info::unspecified_element_type{};
                 }
-                switch (m_value_info_proto->type().tensor_type().elem_type())
-                {
-                case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return element::boolean;
-                case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
-                case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: return element::f32;
-                case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: return element::f64;
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT8: return element::i8;
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT16: return element::i16;
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT32: return element::i32;
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT64: return element::i64;
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: return element::u8;
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: return element::u16;
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: return element::u32;
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64;
-                default:
-                    throw error::value_info::unsupported_element_type{
-                        m_value_info_proto->type().tensor_type().elem_type()};
-                }
+                return common::get_ngraph_element_type(
+                    m_value_info_proto->type().tensor_type().elem_type());
             }
 
             std::shared_ptr<ngraph::Node>
...
@@ -13,14 +13,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //*****************************************************************************
 
 #include <memory>
-#include <onnx-ml.pb.h>
 
 #include "cast.hpp"
 #include "exceptions.hpp"
 #include "ngraph/op/convert.hpp"
 #include "ngraph/type/element_type.hpp"
+#include "utils/common.hpp"
 
 namespace ngraph
 {
@@ -34,25 +33,7 @@ namespace ngraph
                 {
                     auto data = node.get_ng_inputs().at(0);
                     int64_t target_type = node.get_attribute_value<int64_t>("to");
-
-                    element::Type elem_type;
-                    switch (target_type)
-                    {
-                    case onnx::TensorProto_DataType_BOOL: elem_type = element::boolean; break;
-                    case onnx::TensorProto_DataType_DOUBLE: elem_type = element::f64; break;
-                    case onnx::TensorProto_DataType_FLOAT16:
-                    case onnx::TensorProto_DataType_FLOAT: elem_type = element::f32; break;
-                    case onnx::TensorProto_DataType_INT8: elem_type = element::i8; break;
-                    case onnx::TensorProto_DataType_INT16: elem_type = element::i16; break;
-                    case onnx::TensorProto_DataType_INT32: elem_type = element::i32; break;
-                    case onnx::TensorProto_DataType_INT64: elem_type = element::i64; break;
-                    case onnx::TensorProto_DataType_UINT8: elem_type = element::u8; break;
-                    case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
-                    case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
-                    case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
-                    case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::dynamic; break;
-                    default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
-                    }
+                    element::Type elem_type = common::get_ngraph_element_type(target_type);
 
                     return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
                 }
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "eye_like.hpp"
#include "exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/common.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector eye_like(const Node& node)
                {
                    const auto input = node.get_ng_inputs().at(0);
                    const auto& input_shape = input->get_shape();

                    std::int64_t dtype;
                    element::Type target_type;
                    std::int64_t shift = node.get_attribute_value<std::int64_t>("k", 0);

                    if (node.has_attribute("dtype"))
                    {
                        dtype = node.get_attribute_value<std::int64_t>("dtype");
                        target_type = common::get_ngraph_element_type(dtype);
                    }
                    else
                    {
                        target_type = input->get_element_type();
                    }

                    ASSERT_VALID_ARGUMENT(node, input_shape.size() == 2)
                        << "The provided shape rank: " << input_shape.size()
                        << " is unsupported, only 2D shapes are supported";

                    std::shared_ptr<ngraph::Node> eye_like_matrix =
                        common::shifted_square_identity(input_shape, target_type, shift);

                    return {eye_like_matrix};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
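Worth noting the design choice: EyeLike is lowered at import time to an op::Constant built by common::shifted_square_identity (added to utils/common.hpp later in this diff) rather than to a runtime kernel. This works because the importer insists on a static 2D input shape; for a float 3x4 input with k = -1, the emitted constant is [[0,0,0,0],[1,0,0,0],[0,1,0,0]].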
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector eye_like(const Node& node);

            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
@@ -53,6 +53,7 @@
 #include "op/equal.hpp"
 #include "op/erf.hpp"
 #include "op/exp.hpp"
+#include "op/eye_like.hpp"
 #include "op/flatten.hpp"
 #include "op/floor.hpp"
 #include "op/gather.hpp"
@@ -260,6 +261,7 @@ namespace ngraph
             REGISTER_OPERATOR("Equal", 1, equal);
             REGISTER_OPERATOR("Erf", 1, erf);
             REGISTER_OPERATOR("Exp", 1, exp);
+            REGISTER_OPERATOR("EyeLike", 1, eye_like);
             REGISTER_OPERATOR("Flatten", 1, flatten);
             REGISTER_OPERATOR("Floor", 1, floor);
             REGISTER_OPERATOR("Gather", 1, gather);
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <onnx-ml.pb.h> // onnx types
#include "common.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace common
        {
            const ngraph::element::Type& get_ngraph_element_type(int64_t onnx_type)
            {
                switch (onnx_type)
                {
                case onnx::TensorProto_DataType_BOOL: return element::boolean;
                case onnx::TensorProto_DataType_DOUBLE: return element::f64;
                case onnx::TensorProto_DataType_FLOAT16: return element::f16;
                case onnx::TensorProto_DataType_FLOAT: return element::f32;
                case onnx::TensorProto_DataType_INT8: return element::i8;
                case onnx::TensorProto_DataType_INT16: return element::i16;
                case onnx::TensorProto_DataType_INT32: return element::i32;
                case onnx::TensorProto_DataType_INT64: return element::i64;
                case onnx::TensorProto_DataType_UINT8: return element::u8;
                case onnx::TensorProto_DataType_UINT16: return element::u16;
                case onnx::TensorProto_DataType_UINT32: return element::u32;
                case onnx::TensorProto_DataType_UINT64: return element::u64;
                case onnx::TensorProto_DataType_UNDEFINED: return element::dynamic;
                }
                throw ngraph_error("unsupported element type: " +
                                   onnx::TensorProto_DataType_Name(
                                       static_cast<onnx::TensorProto_DataType>(onnx_type)));
            }
        } // namespace common
    } // namespace onnx_import
} // namespace ngraph
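A minimal usage sketch for the new helper (illustrative, not part of this commit; the enum constants are the standard ONNX TensorProto codes):

    // Resolve a dtype code taken from a model attribute.
    const ngraph::element::Type& half =
        common::get_ngraph_element_type(onnx::TensorProto_DataType_FLOAT16);
    // half is element::f16; an unrecognized code throws ngraph_error naming
    // the offending ONNX enum value.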
@@ -19,6 +19,7 @@
 #include <algorithm>   // std::generate
 #include <cmath>       // std::floor, std::min
 #include <cstddef>     // std::size_t
+#include <cstdint>     // std::int64_t
 #include <iterator>    // std::begin, std::end
 #include <memory>      // std::shared_ptr, std::make_shared
 #include <type_traits> // std::enable_if
@@ -27,6 +28,7 @@
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/shape.hpp"
+#include "ngraph/type/element_type.hpp"
 
 namespace ngraph
 {
@@ -34,6 +36,8 @@ namespace ngraph
     {
         namespace common
         {
+            const ngraph::element::Type& get_ngraph_element_type(std::int64_t onnx_type);
+
             /// \brief Return a monotonic sequence.
             ///
             /// \note Limitations: this function may not work for very large integer values
@@ -86,6 +90,43 @@ namespace ngraph
                 }
             }
 
+            /// \brief Creates a shifted square identity matrix.
+            /// \note  Shifting means the elements equal to 1 need not lie on the
+            ///        main diagonal: the shift parameter offsets the diagonal up or down.
+            ///
+            /// \param[in] output_shape Shape of the resulting matrix.
+            /// \param[in] output_type  Element type of the resulting matrix.
+            /// \param[in] shift        Shift of the diagonal.
+            ///
+            /// \return A Constant node representing the shifted identity matrix.
+            template <typename T = double>
+            std::shared_ptr<ngraph::op::Constant>
+                shifted_square_identity(const Shape output_shape,
+                                        const element::Type& output_type,
+                                        const std::int64_t shift)
+            {
+                std::vector<T> identity_matrix(shape_size(output_shape), T{0});
+                std::int64_t rows = output_shape[0];
+                std::int64_t cols = output_shape[1];
+                for (std::int64_t row = 0; row < rows; ++row)
+                {
+                    const std::int64_t diagonal_element_idx = (row * cols) + row + shift;
+                    if (row + shift < 0)
+                    {
+                        continue;
+                    }
+                    else if (row + shift >= cols)
+                    {
+                        break;
+                    }
+                    identity_matrix.at(diagonal_element_idx) = T{1};
+                }
+                return std::make_shared<ngraph::op::Constant>(
+                    output_type, output_shape, identity_matrix);
+            }
+
             /// \brief Creates a square identity matrix.
             ///
             /// \param[in] n Order of the resulting matrix.
@@ -95,16 +136,9 @@ namespace ngraph
             std::shared_ptr<ngraph::op::Constant> square_identity(const size_t n,
                                                                   const element::Type& type)
             {
-                std::vector<T> identity_matrix(n * n, T{0});
-                for (size_t row = 0; row < n; ++row)
-                {
-                    const size_t diagonal_element = (n * row) + row;
-                    identity_matrix.at(diagonal_element) = T{1};
-                }
-                return std::make_shared<ngraph::op::Constant>(type, Shape{{n, n}}, identity_matrix);
+                return shifted_square_identity(Shape{n, n}, type, 0);
             }
         } // namespace common
     } // namespace onnx_import
 } // namespace ngraph
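A quick hand-trace of shifted_square_identity makes the index arithmetic concrete (a sketch; the shape and shift are chosen to match the new unit test at the bottom of this diff):

    // Shape{3, 4}, shift = -1:
    //   row 0: row + shift = -1 < 0      -> skipped
    //   row 1: idx = 1*4 + 1 + (-1) = 4  -> identity_matrix[4] = 1
    //   row 2: idx = 2*4 + 2 + (-1) = 9  -> identity_matrix[9] = 1
    // Row-major result:
    //   0 0 0 0
    //   1 0 0 0
    //   0 1 0 0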
@@ -216,6 +216,7 @@ namespace ngraph
         virtual bool is_op() const { return false; }
         virtual bool is_commutative() { return false; }
         virtual bool is_dynamic() const;
+        virtual bool has_state() const { return false; }
         size_t get_instance_id() const { return m_instance_id; }
         friend std::ostream& operator<<(std::ostream&, const Node&);
         virtual std::ostream& write_short_description(std::ostream&) const;
...
@@ -72,6 +72,8 @@ namespace ngraph
             /// \brief Returns the seed value supplied to a random generator
             uint64_t get_seed() const { return m_seed; }
             bool get_use_seed() const { return m_use_seed; }
+            /// GenerateMask has state.
+            bool has_state() const override { return true; }
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override
...
@@ -924,7 +924,8 @@ using namespace ngraph::runtime;
     // Always enable nodes computing output tensors or nodes whose outputs might get
     // overwritten due to inplace kernels
     // TODO (jbobba) - Do we need to handle cacheability
-    if (computes_result(node.get()) || possibly_overwritten(node.get()))
+    if (computes_result(node.get()) || possibly_overwritten(node.get()) ||
+        node->has_state())
     {
         writer << " || 1";
     }
@@ -1423,7 +1424,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_config)
     bool disable_caching =
         (reuse_memory &&
          !cacheable) // Check cacheability only if we are reusing intermediate tensors
-        || computes_result(node.get()) || possibly_overwritten(node.get());
+        || computes_result(node.get()) || possibly_overwritten(node.get()) || node->has_state();
 
     vector<reference_wrapper<bool>> in_stale, out_stale;
     for (const auto& name : in_names)
...
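The intent of these two hunks: a node that reports has_state(), currently only GenerateMask with its random generator, must never be skipped or served from cache, because re-executing it legitimately yields different results for identical inputs.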
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
  node {
    input: "x"
    output: "y"
    op_type: "EyeLike"
    attribute {
      name: "k"
      i: -1
      type: INT
    }
  }
  name: "hardmax_graph"
  input {
    name: "x"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
  output {
    name: "y"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
}
opset_import {
  version: 9
}
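The model above sets attribute k: -1 and omits dtype, so the importer takes the output element type from the input (elem_type: 1 is FLOAT in the ONNX enum); the new test below feeds a 3x4 zero tensor and checks for the shifted identity.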
@@ -1482,3 +1482,46 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int)
     test_case.run();
 }
 
+NGRAPH_TEST(onnx_${BACKEND_NAME}, model_eye_like)
+{
+    const auto eye_like_fn = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/eye_like.prototxt"));
+
+    auto test_case = ngraph::test::NgraphTestCase(eye_like_fn, "${BACKEND_NAME}");
+    test_case.add_input<float>(
+        {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
+    test_case.add_expected_output<float>(Shape{3, 4},
+                                         {0.f, 0.f, 0.f, 0.f,
+                                          1.f, 0.f, 0.f, 0.f,
+                                          0.f, 1.f, 0.f, 0.f});
+    test_case.run();
+}