Commit ad12723b authored by Artur Wojcik's avatar Artur Wojcik Committed by Scott Cyphers

onnx [3]: add 'constant' operator (#1197)

* onnx: add 'constant' operator
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: getting attribute value by name
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix code style
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix clang compilation warnings
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: incorporate review comments
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent e3ad1b31
...@@ -37,6 +37,7 @@ add_library(onnx_import STATIC ...@@ -37,6 +37,7 @@ add_library(onnx_import STATIC
model.hpp model.hpp
node.cpp node.cpp
op/add.hpp op/add.hpp
op/constant.hpp
ops_bridge.cpp ops_bridge.cpp
tensor.hpp tensor.hpp
value_info.hpp) value_info.hpp)
......
...@@ -20,6 +20,9 @@ ...@@ -20,6 +20,9 @@
#include "onnx.pb.h" #include "onnx.pb.h"
#include "tensor.hpp" #include "tensor.hpp"
#define likely(__x) __builtin_expect(!!(__x), 1)
#define unlikely(__x) __builtin_expect(!!(__x), 0)
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
...@@ -33,9 +36,9 @@ namespace ngraph ...@@ -33,9 +36,9 @@ namespace ngraph
{ {
namespace detail namespace detail
{ {
struct attribute : ngraph_error struct Attribute : ngraph_error
{ {
attribute(std::string msg, onnx::AttributeProto_AttributeType type) Attribute(std::string msg, onnx::AttributeProto_AttributeType type)
: ngraph_error{std::move(msg) + ": " + : ngraph_error{std::move(msg) + ": " +
onnx::AttributeProto_AttributeType_Name(type)} onnx::AttributeProto_AttributeType_Name(type)}
{ {
...@@ -44,18 +47,18 @@ namespace ngraph ...@@ -44,18 +47,18 @@ namespace ngraph
} // namespace detail } // namespace detail
struct invalid_data : detail::attribute struct InvalidData : detail::Attribute
{ {
explicit invalid_data(onnx::AttributeProto_AttributeType type) explicit InvalidData(onnx::AttributeProto_AttributeType type)
: attribute{"invalid attribute type", type} : Attribute{"invalid attribute type", type}
{ {
} }
}; };
struct unsupported_type : detail::attribute struct UnsupportedType : detail::Attribute
{ {
explicit unsupported_type(onnx::AttributeProto_AttributeType type) explicit UnsupportedType(onnx::AttributeProto_AttributeType type)
: attribute("unsupported attribute type", type) : Attribute{"unsupported attribute type", type}
{ {
} }
}; };
...@@ -64,6 +67,154 @@ namespace ngraph ...@@ -64,6 +67,154 @@ namespace ngraph
} // namespace error } // namespace error
namespace detail
{
namespace attribute
{
template <typename T>
inline T get_value(const onnx::AttributeProto& attribute)
{
throw error::attribute::UnsupportedType{attribute.type()};
}
template <>
inline float get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_FLOAT))
{
throw error::attribute::InvalidData{attribute.type()};
}
return attribute.f();
}
template <>
inline std::vector<float> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_FLOAT: return {attribute.f()};
case onnx::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline double get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_FLOAT))
{
throw error::attribute::InvalidData{attribute.type()};
}
return static_cast<double>(attribute.f());
}
template <>
inline std::vector<double> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_FLOAT:
return {static_cast<double>(attribute.f())};
case onnx::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline std::size_t get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_INT))
{
throw error::attribute::InvalidData{attribute.type()};
}
return static_cast<std::size_t>(attribute.i());
}
template <>
inline std::vector<std::size_t> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_INT:
return {static_cast<std::size_t>(attribute.i())};
case onnx::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline int64_t get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_INT))
{
throw error::attribute::InvalidData{attribute.type()};
}
return attribute.i();
}
template <>
inline std::vector<int64_t> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_INT: return {attribute.i()};
case onnx::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline const std::string& get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_STRING))
{
throw error::attribute::InvalidData{attribute.type()};
}
return attribute.s();
}
template <>
inline std::vector<std::string> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_STRING: return {attribute.s()};
case onnx::AttributeProto_AttributeType_STRINGS:
return {std::begin(attribute.strings()), std::end(attribute.strings())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline Tensor get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_TENSOR))
{
throw error::attribute::InvalidData{attribute.type()};
}
return Tensor{attribute.t()};
}
template <>
inline std::vector<Tensor> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_TENSOR: return {Tensor{attribute.t()}};
case onnx::AttributeProto_AttributeType_TENSORS:
return {std::begin(attribute.tensors()), std::end(attribute.tensors())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
} // namespace attribute
} // namespace detail
class Attribute class Attribute
{ {
public: public:
...@@ -142,6 +293,12 @@ namespace ngraph ...@@ -142,6 +293,12 @@ namespace ngraph
return m_attribute_proto.type(); return m_attribute_proto.type();
} }
template <typename T>
T get_value() const
{
return detail::attribute::get_value<T>(m_attribute_proto);
}
private: private:
const onnx::AttributeProto& m_attribute_proto; const onnx::AttributeProto& m_attribute_proto;
}; };
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include <algorithm>
#include <ostream> #include <ostream>
#include <string> #include <string>
...@@ -30,6 +31,23 @@ namespace ngraph ...@@ -30,6 +31,23 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace error
{
namespace node
{
struct UnknownAttribute : ngraph_error
{
explicit UnknownAttribute(const std::string& node, const std::string& name)
: ngraph_error{"Node (" + node + "): unknown attribute \'" + name + "\'"}
{
}
};
} // namespace node
} // namespace error
// forward declaration
class Graph; class Graph;
class Node class Node
...@@ -56,6 +74,34 @@ namespace ngraph ...@@ -56,6 +74,34 @@ namespace ngraph
const std::string& op_type() const { return m_node_proto.op_type(); } const std::string& op_type() const { return m_node_proto.op_type(); }
const std::string& get_name() const { return m_node_proto.name(); } const std::string& get_name() const { return m_node_proto.name(); }
const std::string& output(int index) const { return m_node_proto.output(index); } const std::string& output(int index) const { return m_node_proto.output(index); }
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const
{
auto it{std::find_if(
std::begin(m_attributes),
std::end(m_attributes),
[&](const Attribute& attribute) { return attribute.get_name() == name; })};
if (it == std::end(m_attributes))
{
return default_value;
}
return it->template get_value<T>();
}
template <typename T>
T get_attribute_value(const std::string& name) const
{
auto it{std::find_if(
std::begin(m_attributes),
std::end(m_attributes),
[&](const Attribute& attribute) { return attribute.get_name() == name; })};
if (it == std::end(m_attributes))
{
throw error::node::UnknownAttribute{get_name(), name};
}
return it->template get_value<T>();
}
private: private:
const onnx::NodeProto& m_node_proto; const onnx::NodeProto& m_node_proto;
const Graph* m_graph; const Graph* m_graph;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/constant.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace detail
{
namespace
{
template <typename T>
inline std::shared_ptr<ngraph::op::Constant>
__make_ng_constant(const element::Type& type, const Tensor& tensor)
{
return std::make_shared<ngraph::op::Constant>(
type, tensor.get_shape(), tensor.get_data<T>());
}
}
template <Tensor::Type>
inline std::shared_ptr<ngraph::op::Constant> make_ng_constant(const Tensor& tensor)
{
throw error::tensor::unsupported_data_type{tensor};
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float32>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float64>(const Tensor& tensor)
{
return __make_ng_constant<double>(element::f64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
{
return __make_ng_constant<int32_t>(element::i32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
{
return __make_ng_constant<uint32_t>(element::u32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor)
{
return __make_ng_constant<uint64_t>(element::u64, tensor);
}
} // namespace detail
inline std::shared_ptr<ngraph::op::Constant> constant(const Tensor& tensor)
{
#define _make_ng_constant(_data_type) \
case _data_type: return detail::make_ng_constant<_data_type>(tensor)
switch (tensor.get_type())
{
_make_ng_constant(Tensor::Type::float16);
_make_ng_constant(Tensor::Type::float32);
_make_ng_constant(Tensor::Type::float64);
_make_ng_constant(Tensor::Type::int32);
_make_ng_constant(Tensor::Type::uint32);
_make_ng_constant(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor};
}
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include "attribute.hpp"
#include "ngraph/frontend/onnx_import/op/add.hpp" #include "ngraph/frontend/onnx_import/op/add.hpp"
#include "ngraph/frontend/onnx_import/op/constant.hpp"
#include "ops_bridge.hpp" #include "ops_bridge.hpp"
namespace ngraph namespace ngraph
...@@ -39,6 +41,11 @@ namespace ngraph ...@@ -39,6 +41,11 @@ namespace ngraph
} // namespace error } // namespace error
NodeVector add(const Node& node) { return op::add(node); } NodeVector add(const Node& node) { return op::add(node); }
NodeVector constant(const Node& node)
{
return {op::constant(node.get_attribute_value<Tensor>("value"))};
}
class ops_bridge class ops_bridge
{ {
public: public:
...@@ -61,7 +68,12 @@ namespace ngraph ...@@ -61,7 +68,12 @@ namespace ngraph
return instance; return instance;
} }
ops_bridge() { m_map.emplace("Add", std::bind(add, std::placeholders::_1)); } ops_bridge()
{
m_map.emplace("Add", std::bind(add, std::placeholders::_1));
m_map.emplace("Constant", std::bind(constant, std::placeholders::_1));
}
NodeVector operator()(const Node& node) const NodeVector operator()(const Node& node) const
{ {
try try
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment