Unverified Commit 31ee5658 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

[ONNX] Attribute helper functions (#1468)

parent 33f4f394
......@@ -80,11 +80,12 @@ namespace ngraph
template <>
inline float get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_FLOAT))
switch (attribute.type())
{
throw error::attribute::InvalidData{attribute.type()};
case onnx::AttributeProto_AttributeType_INT: return attribute.i();
case onnx::AttributeProto_AttributeType_FLOAT: return attribute.f();
default: throw error::attribute::InvalidData{attribute.type()};
}
return attribute.f();
}
template <>
......@@ -92,6 +93,10 @@ namespace ngraph
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_INT:
return {static_cast<float>(attribute.i())};
case onnx::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
case onnx::AttributeProto_AttributeType_FLOAT: return {attribute.f()};
case onnx::AttributeProto_AttributeType_FLOATS:
return {std::begin(attribute.floats()), std::end(attribute.floats())};
......@@ -102,11 +107,13 @@ namespace ngraph
template <>
inline double get_value(const onnx::AttributeProto& attribute)
{
if (unlikely(attribute.type() != onnx::AttributeProto_AttributeType_FLOAT))
switch (attribute.type())
{
throw error::attribute::InvalidData{attribute.type()};
}
case onnx::AttributeProto_AttributeType_FLOAT:
return static_cast<double>(attribute.f());
case onnx::AttributeProto_AttributeType_INT: return attribute.i();
default: throw error::attribute::InvalidData{attribute.type()};
}
}
template <>
......@@ -114,6 +121,10 @@ namespace ngraph
{
switch (attribute.type())
{
case onnx::AttributeProto_AttributeType_INT:
return {static_cast<double>(attribute.i())};
case onnx::AttributeProto_AttributeType_INTS:
return {std::begin(attribute.ints()), std::end(attribute.ints())};
case onnx::AttributeProto_AttributeType_FLOAT:
return {static_cast<double>(attribute.f())};
case onnx::AttributeProto_AttributeType_FLOATS:
......
......@@ -119,6 +119,65 @@ namespace ngraph
return (outs << "<Node(" << node.op_type() << "): " << node.get_name() << ">");
}
namespace attribute
{
/**
* @brief Get shape of kernel (filter) in pixels.
*
* @param node The Node ptr representing Conv or Pool operation.
* @return The kernel Shape object representing its dimensions (height, width, depth).
*/
inline Shape get_kernel_shape(const Node& node)
{
return node.get_attribute_value<std::vector<std::size_t>>("kernel_shape", {1, 1});
}
namespace detail
{
inline Strides get_strides_helper(const Node& node,
const std::string& name,
const Shape& kernel_shape)
{
return node.get_attribute_value<std::vector<std::size_t>>(
name, std::vector<std::size_t>(kernel_shape.size(), 1UL));
}
} // namespace detail
/**
* @brief Get number of pixels to stride operation by in each direction.
*
* @param node The Node ptr representing Conv or Pool operation.
* @param kernel_shape The shape of the kernel which we retrieve strides for.
* @return The kernel Shape object representing its dimensions (height, width, depth).
*/
inline Strides get_strides(const Node& node, const Shape& kernel_shape)
{
return detail::get_strides_helper(node, "strides", kernel_shape);
}
/**
* @brief Get number of pixels to stride operation by in each direction.
*
* @param node The Node ptr representing Conv or Pool operation.
* @return The kernel Shape object representing its dimensions (height, width, depth).
*/
inline Strides get_strides(const Node& node)
{
return get_strides(node, get_kernel_shape(node));
}
/**
* @brief Get number of pixels for filter dilation in each direction.
*
* @param node The Node ptr representing ONNX operation.
* @return The Strides object containing number of pixels for filter dilation
* (height, width, depth).
*/
inline Strides get_dilations(const Node& node)
{
return detail::get_strides_helper(node, "dilations", get_kernel_shape(node));
}
} // namespace attribute
} // namespace onnx_import
} // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment