Unverified Commit 31ee5658 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

[ONNX] Attribute helper functions (#1468)

parent 33f4f394
...@@ -80,11 +80,12 @@ namespace ngraph ...@@ -80,11 +80,12 @@ namespace ngraph
template <> template <>
inline float get_value(const onnx::AttributeProto& attribute) inline float get_value(const onnx::AttributeProto& attribute)
{ {
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_FLOAT)) switch (attribute.type())
{ {
throw error::attribute::InvalidData{attribute.type()}; case onnx::AttributeProto_AttributeType_INT: return attribute.i();
case onnx::AttributeProto_AttributeType_FLOAT: return attribute.f();
default: throw error::attribute::InvalidData{attribute.type()};
} }
return attribute.f();
} }
template <> template <>
...@@ -92,6 +93,10 @@ namespace ngraph ...@@ -92,6 +93,10 @@ namespace ngraph
{ {
switch (attribute.type()) switch (attribute.type())
{ {
case onnx::AttributeProto_AttributeType_INT:
return {static_cast<float>(attribute.i())};
case onnx::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
case onnx::AttributeProto_AttributeType_FLOAT: return {attribute.f()}; case onnx::AttributeProto_AttributeType_FLOAT: return {attribute.f()};
case onnx::AttributeProto_AttributeType_FLOATS: case onnx::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())}; return {std::begin(attribute.floats()), std::end(attribute.floats())};
...@@ -102,11 +107,13 @@ namespace ngraph ...@@ -102,11 +107,13 @@ namespace ngraph
template <> template <>
inline double get_value(const onnx::AttributeProto& attribute) inline double get_value(const onnx::AttributeProto& attribute)
{ {
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_FLOAT)) switch (attribute.type())
{ {
throw error::attribute::InvalidData{attribute.type()}; case onnx::AttributeProto_AttributeType_FLOAT:
return static_cast<double>(attribute.f());
case onnx::AttributeProto_AttributeType_INT: return attribute.i();
default: throw error::attribute::InvalidData{attribute.type()};
} }
return static_cast<double>(attribute.f());
} }
template <> template <>
...@@ -114,6 +121,10 @@ namespace ngraph ...@@ -114,6 +121,10 @@ namespace ngraph
{ {
switch (attribute.type()) switch (attribute.type())
{ {
case onnx::AttributeProto_AttributeType_INT:
return {static_cast<double>(attribute.i())};
case onnx::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
case onnx::AttributeProto_AttributeType_FLOAT: case onnx::AttributeProto_AttributeType_FLOAT:
return {static_cast<double>(attribute.f())}; return {static_cast<double>(attribute.f())};
case onnx::AttributeProto_AttributeType_FLOATS: case onnx::AttributeProto_AttributeType_FLOATS:
......
...@@ -119,6 +119,65 @@ namespace ngraph ...@@ -119,6 +119,65 @@ namespace ngraph
return (outs << "<Node(" << node.op_type() << "): " << node.get_name() << ">"); return (outs << "<Node(" << node.op_type() << "): " << node.get_name() << ">");
} }
namespace attribute
{
/**
* @brief Get shape of kernel (filter) in pixels.
*
* @param node The Node ptr representing Conv or Pool operation.
* @return The kernel Shape object representing its dimensions (height, width, depth).
*/
inline Shape get_kernel_shape(const Node& node)
{
return node.get_attribute_value<std::vector<std::size_t>>("kernel_shape", {1, 1});
}
namespace detail
{
inline Strides get_strides_helper(const Node& node,
const std::string& name,
const Shape& kernel_shape)
{
return node.get_attribute_value<std::vector<std::size_t>>(
name, std::vector<std::size_t>(kernel_shape.size(), 1UL));
}
} // namespace detail
/**
* @brief Get number of pixels to stride operation by in each direction.
*
* @param node The Node ptr representing Conv or Pool operation.
* @param kernel_shape The shape of the kernel which we retrieve strides for.
* @return The kernel Shape object representing its dimensions (height, width, depth).
*/
inline Strides get_strides(const Node& node, const Shape& kernel_shape)
{
return detail::get_strides_helper(node, "strides", kernel_shape);
}
/**
* @brief Get number of pixels to stride operation by in each direction.
*
* @param node The Node ptr representing Conv or Pool operation.
* @return The kernel Shape object representing its dimensions (height, width, depth).
*/
inline Strides get_strides(const Node& node)
{
return get_strides(node, get_kernel_shape(node));
}
/**
* @brief Get number of pixels for filter dilation in each direction.
*
* @param node The Node ptr representing ONNX operation.
* @return The Strides object containing number of pixels for filter dilation
* (height, width, depth).
*/
inline Strides get_dilations(const Node& node)
{
return detail::get_strides_helper(node, "dilations", get_kernel_shape(node));
}
} // namespace attribute
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment