Commit ad12723b authored by Artur Wojcik's avatar Artur Wojcik Committed by Scott Cyphers

onnx [3]: add 'constant' operator (#1197)

* onnx: add 'constant' operator
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: getting attribute value by name
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix code style
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix clang compilation warnings
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: incorporate review comments
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent e3ad1b31
......@@ -37,6 +37,7 @@ add_library(onnx_import STATIC
model.hpp
node.cpp
op/add.hpp
op/constant.hpp
ops_bridge.cpp
tensor.hpp
value_info.hpp)
......
......@@ -20,6 +20,9 @@
#include "onnx.pb.h"
#include "tensor.hpp"
#define likely(__x) __builtin_expect(!!(__x), 1)
#define unlikely(__x) __builtin_expect(!!(__x), 0)
namespace ngraph
{
namespace onnx_import
......@@ -33,9 +36,9 @@ namespace ngraph
{
namespace detail
{
struct attribute : ngraph_error
struct Attribute : ngraph_error
{
attribute(std::string msg, onnx::AttributeProto_AttributeType type)
Attribute(std::string msg, onnx::AttributeProto_AttributeType type)
: ngraph_error{std::move(msg) + ": " +
onnx::AttributeProto_AttributeType_Name(type)}
{
......@@ -44,18 +47,18 @@ namespace ngraph
} // namespace detail
struct invalid_data : detail::attribute
struct InvalidData : detail::Attribute
{
explicit invalid_data(onnx::AttributeProto_AttributeType type)
: attribute{"invalid attribute type", type}
explicit InvalidData(onnx::AttributeProto_AttributeType type)
: Attribute{"invalid attribute type", type}
{
}
};
struct unsupported_type : detail::attribute
struct UnsupportedType : detail::Attribute
{
explicit unsupported_type(onnx::AttributeProto_AttributeType type)
: attribute("unsupported attribute type", type)
explicit UnsupportedType(onnx::AttributeProto_AttributeType type)
: Attribute{"unsupported attribute type", type}
{
}
};
......@@ -64,6 +67,154 @@ namespace ngraph
} // namespace error
namespace detail
{
namespace attribute
{
template <typename T>
inline T get_value(const onnx::AttributeProto& attribute)
{
throw error::attribute::UnsupportedType{attribute.type()};
}
template <>
inline float get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_FLOAT))
{
throw error::attribute::InvalidData{attribute.type()};
}
return attribute.f();
}
template <>
inline std::vector<float> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_FLOAT: return {attribute.f()};
case onnx::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline double get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_FLOAT))
{
throw error::attribute::InvalidData{attribute.type()};
}
return static_cast<double>(attribute.f());
}
template <>
inline std::vector<double> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_FLOAT:
return {static_cast<double>(attribute.f())};
case onnx::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline std::size_t get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_INT))
{
throw error::attribute::InvalidData{attribute.type()};
}
return static_cast<std::size_t>(attribute.i());
}
template <>
inline std::vector<std::size_t> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_INT:
return {static_cast<std::size_t>(attribute.i())};
case onnx::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline int64_t get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_INT))
{
throw error::attribute::InvalidData{attribute.type()};
}
return attribute.i();
}
template <>
inline std::vector<int64_t> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_INT: return {attribute.i()};
case onnx::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline const std::string& get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_STRING))
{
throw error::attribute::InvalidData{attribute.type()};
}
return attribute.s();
}
template <>
inline std::vector<std::string> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_STRING: return {attribute.s()};
case onnx::AttributeProto_AttributeType_STRINGS:
return {std::begin(attribute.strings()), std::end(attribute.strings())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
inline Tensor get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_TENSOR))
{
throw error::attribute::InvalidData{attribute.type()};
}
return Tensor{attribute.t()};
}
template <>
inline std::vector<Tensor> get_value(const onnx::AttributeProto& attribute)
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_TENSOR: return {Tensor{attribute.t()}};
case onnx::AttributeProto_AttributeType_TENSORS:
return {std::begin(attribute.tensors()), std::end(attribute.tensors())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}
} // namespace attribute
} // namespace detail
class Attribute
{
public:
......@@ -142,6 +293,12 @@ namespace ngraph
return m_attribute_proto.type();
}
template <typename T>
T get_value() const
{
return detail::attribute::get_value<T>(m_attribute_proto);
}
private:
const onnx::AttributeProto& m_attribute_proto;
};
......
......@@ -16,6 +16,7 @@
#pragma once
#include <algorithm>
#include <ostream>
#include <string>
......@@ -30,6 +31,23 @@ namespace ngraph
{
namespace onnx_import
{
namespace error
{
namespace node
{
struct UnknownAttribute : ngraph_error
{
explicit UnknownAttribute(const std::string& node, const std::string& name)
: ngraph_error{"Node (" + node + "): unknown attribute \'" + name + "\'"}
{
}
};
} // namespace node
} // namespace error
// forward declaration
class Graph;
class Node
......@@ -56,6 +74,34 @@ namespace ngraph
const std::string& op_type() const { return m_node_proto.op_type(); }
const std::string& get_name() const { return m_node_proto.name(); }
const std::string& output(int index) const { return m_node_proto.output(index); }
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const
{
auto it{std::find_if(
std::begin(m_attributes),
std::end(m_attributes),
[&](const Attribute& attribute) { return attribute.get_name() == name; })};
if (it == std::end(m_attributes))
{
return default_value;
}
return it->template get_value<T>();
}
template <typename T>
T get_attribute_value(const std::string& name) const
{
auto it{std::find_if(
std::begin(m_attributes),
std::end(m_attributes),
[&](const Attribute& attribute) { return attribute.get_name() == name; })};
if (it == std::end(m_attributes))
{
throw error::node::UnknownAttribute{get_name(), name};
}
return it->template get_value<T>();
}
private:
const onnx::NodeProto& m_node_proto;
const Graph* m_graph;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/constant.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace detail
{
namespace
{
template <typename T>
inline std::shared_ptr<ngraph::op::Constant>
__make_ng_constant(const element::Type& type, const Tensor& tensor)
{
return std::make_shared<ngraph::op::Constant>(
type, tensor.get_shape(), tensor.get_data<T>());
}
}
template <Tensor::Type>
inline std::shared_ptr<ngraph::op::Constant> make_ng_constant(const Tensor& tensor)
{
throw error::tensor::unsupported_data_type{tensor};
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float32>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float64>(const Tensor& tensor)
{
return __make_ng_constant<double>(element::f64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
{
return __make_ng_constant<int32_t>(element::i32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
{
return __make_ng_constant<uint32_t>(element::u32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor)
{
return __make_ng_constant<uint64_t>(element::u64, tensor);
}
} // namespace detail
inline std::shared_ptr<ngraph::op::Constant> constant(const Tensor& tensor)
{
#define _make_ng_constant(_data_type) \
case _data_type: return detail::make_ng_constant<_data_type>(tensor)
switch (tensor.get_type())
{
_make_ng_constant(Tensor::Type::float16);
_make_ng_constant(Tensor::Type::float32);
_make_ng_constant(Tensor::Type::float64);
_make_ng_constant(Tensor::Type::int32);
_make_ng_constant(Tensor::Type::uint32);
_make_ng_constant(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor};
}
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -17,7 +17,9 @@
#include <algorithm>
#include <functional>
#include "attribute.hpp"
#include "ngraph/frontend/onnx_import/op/add.hpp"
#include "ngraph/frontend/onnx_import/op/constant.hpp"
#include "ops_bridge.hpp"
namespace ngraph
......@@ -39,6 +41,11 @@ namespace ngraph
} // namespace error
NodeVector add(const Node& node) { return op::add(node); }
NodeVector constant(const Node& node)
{
return {op::constant(node.get_attribute_value<Tensor>("value"))};
}
class ops_bridge
{
public:
......@@ -61,7 +68,12 @@ namespace ngraph
return instance;
}
ops_bridge() { m_map.emplace("Add", std::bind(add, std::placeholders::_1)); }
ops_bridge()
{
m_map.emplace("Add", std::bind(add, std::placeholders::_1));
m_map.emplace("Constant", std::bind(constant, std::placeholders::_1));
}
NodeVector operator()(const Node& node) const
{
try
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment