Commit da928d61 authored by amy.zhuang's avatar amy.zhuang

Merge branch 'ayzhuang/in-place-concat' of…

Merge branch 'ayzhuang/in-place-concat' of https://github.com/NervanaSystems/ngraph into ayzhuang/in-place-concat
parents 5e0fdc60 f42f5d1d
...@@ -11,23 +11,24 @@ Dequantize ...@@ -11,23 +11,24 @@ Dequantize
Description Description
=========== ===========
Produces a tensor of element type ``type`` and the same shape as ``input`` Produces a tensor of element type ``type`` and the same shape as ``input``
where the value of each coordinate :math:`i` of ``output`` is the corresponding coordinate of where the value of each coordinate :math:`i` of ``output`` is the corresponding coordinate of
``input`` plus ``offset`` quantity multiplied by ``scale``. The coordinate :math:`j` of ``input`` minus ``offset`` quantity multiplied by ``scale``.
``scale`` and ``offset`` is the coordinate of ``output`` projected onto ``axes``. The coordinate :math:`j` of ``scale`` and ``offset`` is the coordinate of ``output``
projected onto ``axes``.
Inputs Inputs
------ ------
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+------------------------------------------+
| Name | Element Type | Shape | | Name | Element Type | Shape |
+=================+=========================+=======================================+ +=================+=========================+==========================================+
| ``input`` | Any quantized type | Any | | ``input`` | Any quantized type | Any |
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+------------------------------------------+
| ``scale`` | Same as ``output`` | ``input`` shape projected on ``axes`` | | ``scale`` | Same as ``output`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+------------------------------------------+
| ``offset`` | Same as ``input`` | ``input`` shape projected on ``axes`` | | ``offset`` | Same as ``input`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+------------------------------------------+
Attributes Attributes
---------- ----------
...@@ -35,20 +36,22 @@ Attributes ...@@ -35,20 +36,22 @@ Attributes
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+----------------------------------------------------------------+
| Name | Description | | Name | Description |
+===============================+================================================================+ +===============================+================================================================+
| ``type`` | ``output`` element type | | ``type`` | ``output`` element type; any real type |
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+----------------------------------------------------------------+
| ``axes`` | Axis positions on which ``scale`` and ``offset`` are specified | | ``axes`` | Axis positions on which ``scale`` and ``offset`` are specified |
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+----------------------------------------------------------------+
Outputs Outputs
------- -------
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+---------------------------------------+
| Name | Element Type | Shape | | Name | Element Type | Shape |
+=================+=========================+=======================================+ +=================+=========================+=======================================+
| ``output`` | is_real() | Same as ``input`` | | ``output`` | ``type`` | Same as ``input`` |
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+---------------------------------------+
Mathematical Definition Mathematical Definition
...@@ -56,8 +59,7 @@ Mathematical Definition ...@@ -56,8 +59,7 @@ Mathematical Definition
.. math:: .. math::
\mathtt{output}_{i} = (\mathtt{input}_{i} + \mathtt{offset}_{j}) \mathtt{scale}_{j} \mathtt{output}_{i,j} = (\mathtt{input}_{i,j} - \mathtt{offset}_{j}) \mathtt{scale}_{j}
C++ Interface C++ Interface
============= =============
......
...@@ -11,24 +11,24 @@ Quantize ...@@ -11,24 +11,24 @@ Quantize
Description Description
=========== ===========
Produces a tensor of element type ``type`` and the same shape as ``input`` Produces a tensor of element type ``type`` and the same shape as ``input``
where the value of each coordinate :math:`i` of ``output`` is the corresponding where the value of each coordinate :math:`i` of ``output`` is the corresponding coordinate of
coordinate of ``input`` divided by ``scale`` rounded as specified by ``input`` divided by ``scale`` rounded as specified by ``round_mode`` plus ``offset``.
``round_mode`` minus ``offset``. The coordinate :math:`j` of ``scale`` and The coordinate :math:`j` of ``scale`` and ``offset`` is the coordinate of ``output``
``offset`` is the coordinate of ``output`` projected onto ``axes``. projected onto ``axes``.
Inputs Inputs
------ ------
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+------------------------------------------+
| Name | Element Type | Shape | | Name | Element Type | Shape |
+=================+=========================+=======================================+ +=================+=========================+==========================================+
| ``input`` | is_real() | Any | | ``input`` | Any real type | Any |
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+------------------------------------------+
| ``scale`` | Same as ``input`` | ``input`` shape projected on ``axes`` | | ``scale`` | Same as ``input`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+------------------------------------------+
| ``offset`` | Same as ``output`` | ``input`` shape projected on ``axes`` | | ``offset`` | Same as ``output`` | ``input`` shape projected onto ``axes`` |
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+------------------------------------------+
Attributes Attributes
---------- ----------
...@@ -36,7 +36,7 @@ Attributes ...@@ -36,7 +36,7 @@ Attributes
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+----------------------------------------------------------------+
| Name | Description | | Name | Description |
+===============================+================================================================+ +===============================+================================================================+
| ``type`` | The output element type, which must be a quantized type | | ``type`` | ``output`` element type; any quantized type |
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+----------------------------------------------------------------+
| ``axes`` | Axis positions on which ``scale`` and ``offset`` are specified | | ``axes`` | Axis positions on which ``scale`` and ``offset`` are specified |
+-------------------------------+----------------------------------------------------------------+ +-------------------------------+----------------------------------------------------------------+
...@@ -51,7 +51,7 @@ Outputs ...@@ -51,7 +51,7 @@ Outputs
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+---------------------------------------+
| Name | Element Type | Shape | | Name | Element Type | Shape |
+=================+=========================+=======================================+ +=================+=========================+=======================================+
| ``output`` | type | Same as ``input`` | | ``output`` | ``type`` | Same as ``input`` |
+-----------------+-------------------------+---------------------------------------+ +-----------------+-------------------------+---------------------------------------+
Mathematical Definition Mathematical Definition
...@@ -59,9 +59,7 @@ Mathematical Definition ...@@ -59,9 +59,7 @@ Mathematical Definition
.. math:: .. math::
\mathtt{output}_{i} = \mathtt{round}(\frac{\mathtt{input}_{i}}{\mathtt{scale}_{j}}) - \mathtt{offset}_{j} \mathtt{output}_{i,j} = \mathtt{round}\left(\frac{\mathtt{input}_{i,j}}{\mathtt{scale}_{j}}\right) + \mathtt{offset}_{j}
C++ Interface C++ Interface
============= =============
......
...@@ -25,15 +25,15 @@ namespace ngraph ...@@ -25,15 +25,15 @@ namespace ngraph
namespace op namespace op
{ {
/// \brief Dequantize operation /// \brief Dequantize operation
/// Maps quantized input (q) to real output (r) using scale (s) and offset (o) /// Maps quantized input (q) to real output (r) using scale (s) and offset (o):
/// q = (r + o) * s /// r = (q - o) * s
class Dequantize : public ngraph::op::Op class Dequantize : public ngraph::op::Op
{ {
public: public:
/// \brief Constructs a Dequantize operation /// \brief Constructs a Dequantize operation
/// \param input quantized input /// \param input quantized input
/// \param scale element type: same as `type`, shape: input shape projected along `axes` /// \param scale scale used for mapping
/// \param offset element type: same as `input`, shape: input shape projected along `axes` /// \param offset offset used for mapping
/// \param type output element type /// \param type output element type
/// \param axes axis positions on which `scale` and `offset` are specified /// \param axes axis positions on which `scale` and `offset` are specified
Dequantize(std::shared_ptr<Node> input, Dequantize(std::shared_ptr<Node> input,
......
...@@ -32,40 +32,56 @@ op::Pad::Pad(const shared_ptr<Node>& arg, ...@@ -32,40 +32,56 @@ op::Pad::Pad(const shared_ptr<Node>& arg,
, m_padding_interior(padding_interior) , m_padding_interior(padding_interior)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
}
void op::Pad::validate_and_infer_types()
{
element::Type result_et;
NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == get_input_element_type(1)) NODE_VALIDATION_ASSERT(
this, element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)))
<< "Argument element types do not match (arg0 element type: " << get_input_element_type(0) << "Argument element types do not match (arg0 element type: " << get_input_element_type(0)
<< ", arg1 element type: " << get_input_element_type(1) << ")."; << ", arg1 element type: " << get_input_element_type(1) << ").";
NODE_VALIDATION_ASSERT(this, get_input_shape(1) == Shape{}) NODE_VALIDATION_ASSERT(this, get_input_partial_shape(1).compatible(PartialShape{}))
<< "Argument for padding value is not a scalar (shape: " << get_input_shape(1) << ")."; << "Argument for padding value is not a scalar (shape: " << get_input_partial_shape(1)
<< ").";
auto arg_shape = get_input_shape(0); auto arg_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() == padding_below.size()) NODE_VALIDATION_ASSERT(this,
<< "Rank for padding below does not match the rank of the data argument (padding below: " m_padding_below.size() == m_padding_above.size() &&
<< padding_below << ", data argument shape: " << arg_shape << ")."; m_padding_below.size() == m_padding_interior.size())
<< "Ranks for padding below (" << m_padding_below << "), padding above (" << m_padding_above
<< ") and interior padding (" << m_padding_interior << ") "
<< "do not match.";
NODE_VALIDATION_ASSERT(this, arg_shape.size() == padding_above.size()) size_t implied_rank = m_padding_below.size();
<< "Rank for padding above does not match the rank of the data argument (padding above: "
<< padding_above << ", data argument shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, arg_shape.size() == padding_interior.size()) NODE_VALIDATION_ASSERT(this, arg_shape.rank().compatible(implied_rank))
<< "Rank for interior padding does not match the rank of the data argument (interior " << "Rank for padding below/padding above/interior padding does not match the rank of the "
"padding: " << "data argument (padding below: " << m_padding_below << ", "
<< padding_interior << ", data argument shape: " << arg_shape << ")."; << ", padding above: " << m_padding_above << ", interior padding: " << m_padding_interior
<< ").";
Shape result_shape; std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic());
for (size_t i = 0; i < arg_shape.size(); i++) if (arg_shape.rank().is_static())
{ {
result_shape.push_back( for (size_t i = 0; i < implied_rank; i++)
padding_below[i] + {
subtract_or_zero(arg_shape[i] * (padding_interior[i] + 1), padding_interior[i]) + if (arg_shape[i].is_static())
padding_above[i]); {
result_dims[i] =
m_padding_below[i] +
subtract_or_zero(size_t(arg_shape[i]) * (m_padding_interior[i] + 1),
m_padding_interior[i]) +
m_padding_above[i];
}
}
} }
set_output_type(0, get_input_element_type(0), result_shape); set_output_type(0, result_et, PartialShape(result_dims));
} }
shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -51,6 +51,7 @@ namespace ngraph ...@@ -51,6 +51,7 @@ namespace ngraph
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
protected: protected:
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
Shape m_padding_below; Shape m_padding_below;
......
...@@ -25,8 +25,8 @@ namespace ngraph ...@@ -25,8 +25,8 @@ namespace ngraph
namespace op namespace op
{ {
/// \brief Quantize operation /// \brief Quantize operation
/// Maps real input (r) to quantized output (q) using scale (s), offset (o) and round mode /// Maps real input (r) to quantized output (q) using scale (s), offset (o) and round mode:
/// q = ROUND(r / s) - o /// q = ROUND(r / s) + o
class Quantize : public ngraph::op::Op class Quantize : public ngraph::op::Op
{ {
public: public:
...@@ -42,8 +42,8 @@ namespace ngraph ...@@ -42,8 +42,8 @@ namespace ngraph
/// \brief Constructs a Quantize operation /// \brief Constructs a Quantize operation
/// \param input real input /// \param input real input
/// \param scale element type: same as `input`, shape: `input` shape projected along `axes` /// \param scale scale used for mapping
/// \param offset element type: same as `type`, shape: `input` shape projected along `axes` /// \param offset offset used for mapping
/// \param type output element type /// \param type output element type
/// \param axes axis positions on which `scale` and `offset` are specified /// \param axes axis positions on which `scale` and `offset` are specified
/// \param round_mode describes how to perform ROUND function /// \param round_mode describes how to perform ROUND function
......
...@@ -278,7 +278,20 @@ extern "C" void delete_backend(runtime::Backend* backend) ...@@ -278,7 +278,20 @@ extern "C" void delete_backend(runtime::Backend* backend)
runtime::intelgpu::IntelGPUBackend::IntelGPUBackend() runtime::intelgpu::IntelGPUBackend::IntelGPUBackend()
{ {
ocl_engine = make_shared<cldnn::engine>(); bool profiling = false;
if (getenv("NGRAPH_INTELGPU_STAT") != nullptr)
{
profiling = true;
}
if (getenv("NGRAPH_INTELGPU_DISABLE_OPTIMIZATIONS") != nullptr)
{
m_disable_backend_optimizations = true;
}
cldnn::engine_configuration cldnn_configuration(profiling);
ocl_engine = make_shared<cldnn::engine>(cldnn_configuration);
} }
shared_ptr<runtime::Tensor> shared_ptr<runtime::Tensor>
...@@ -304,17 +317,21 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func) ...@@ -304,17 +317,21 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
} }
cldnn::topology topology; cldnn::topology topology;
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::NopElimination>(); if (!m_disable_backend_optimizations)
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>(); {
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>(); ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
// GetOutputElementElimination must be after CommonSubexpressionElimination pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>(); pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.run_passes(func); // GetOutputElementElimination must be after CommonSubexpressionElimination
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
pass_manager.run_passes(func);
}
for (shared_ptr<Node> op : func->get_ops()) for (shared_ptr<Node> op : func->get_ops())
{ {
...@@ -1461,3 +1478,76 @@ bool runtime::intelgpu::IntelGPUBackend::call(shared_ptr<Function> func, ...@@ -1461,3 +1478,76 @@ bool runtime::intelgpu::IntelGPUBackend::call(shared_ptr<Function> func,
return true; return true;
} }
void runtime::intelgpu::IntelGPUBackend::remove_compiled_function(shared_ptr<Function> func)
{
ocl_networks.erase(func);
}
void runtime::intelgpu::IntelGPUBackend::enable_performance_data(shared_ptr<Function> func,
bool enable)
{
FunctionInstance& instance = ocl_networks[func];
if (instance.ocl_network != nullptr)
{
throw runtime_error("Performance data collection must be enabled prior to compiling.");
}
instance.m_performance_counters_enabled = enable;
}
// The cldnn::network contains something like "generic_layer_0_Parameter_254_0" names
// This function should return "Parameter_254" from the example above
static string convert_cldnn_names(shared_ptr<Function> func, const string& cldnn_name)
{
const string key("_");
string result;
const size_t last_key = cldnn_name.rfind(key);
const size_t pre_last_key = cldnn_name.rfind(key, last_key - 1);
const size_t pre_pre_last_key = cldnn_name.rfind(key, pre_last_key - 1);
if (pre_pre_last_key == std::string::npos)
{
result = cldnn_name.substr(0, last_key);
}
else
{
result = cldnn_name.substr(pre_pre_last_key + 1, last_key - pre_pre_last_key - 1);
}
return result;
}
vector<runtime::PerformanceCounter>
runtime::intelgpu::IntelGPUBackend::get_performance_data(shared_ptr<Function> func) const
{
vector<runtime::PerformanceCounter> rc;
auto it = ocl_networks.find(func);
if (it != ocl_networks.end())
{
const shared_ptr<cldnn::network> network = it->second.ocl_network;
if (network != nullptr && it->second.m_performance_counters_enabled)
{
const map<cldnn::primitive_id, cldnn::event>& primitives =
network->get_executed_primitives();
for (const auto& p : primitives)
{
// Let's generate the primitive name that matches to the name in Function
const string primitive_name = convert_cldnn_names(func, p.first);
size_t usec = 0;
for (const auto& q : p.second.get_profiling_info())
{
usec += chrono::duration_cast<
chrono::duration<int64_t, chrono::milliseconds::period>>(
q.value->value())
.count();
}
const runtime::PerformanceCounter perf_counter(primitive_name.c_str(), usec, 1);
rc.push_back(perf_counter);
}
}
}
return rc;
}
...@@ -53,13 +53,21 @@ public: ...@@ -53,13 +53,21 @@ public:
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs, const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override; const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
void remove_compiled_function(std::shared_ptr<Function> func) override;
void enable_performance_data(std::shared_ptr<Function> func, bool enable) override;
std::vector<PerformanceCounter>
get_performance_data(std::shared_ptr<Function> func) const override;
private: private:
class FunctionInstance class FunctionInstance
{ {
public: public:
std::shared_ptr<cldnn::network> ocl_network = nullptr; std::shared_ptr<cldnn::network> ocl_network = nullptr;
bool m_performance_counters_enabled = false;
}; };
std::map<std::shared_ptr<Function>, FunctionInstance> ocl_networks; std::map<std::shared_ptr<Function>, FunctionInstance> ocl_networks;
std::shared_ptr<cldnn::engine> ocl_engine; std::shared_ptr<cldnn::engine> ocl_engine;
bool m_disable_backend_optimizations = false;
}; };
...@@ -70,6 +70,13 @@ vector<PerfShape> to_perf_shape(shared_ptr<Function> f, ...@@ -70,6 +70,13 @@ vector<PerfShape> to_perf_shape(shared_ptr<Function> f,
for (const runtime::PerformanceCounter& p : perf_data) for (const runtime::PerformanceCounter& p : perf_data)
{ {
auto node = node_map[p.name()]; auto node = node_map[p.name()];
if (node == nullptr)
{
ostringstream os;
os << "Can't find \"" << p.name() << "\" in Function \"" << f->get_name() << "\".";
throw runtime_error(os.str());
}
Shape shape = node->get_outputs()[0].get_shape(); Shape shape = node->get_outputs()[0].get_shape();
result.push_back(PerfShape(p, shape)); result.push_back(PerfShape(p, shape));
} }
......
...@@ -6691,7 +6691,8 @@ TEST(type_prop, pad_deduce_below_padding_wrong_rank) ...@@ -6691,7 +6691,8 @@ TEST(type_prop, pad_deduce_below_padding_wrong_rank)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Rank for padding below does not match the rank of the data argument")); std::string("Ranks for padding below (Shape{5, 3, 0, 6}), padding above (Shape{6, 9, "
"4}) and interior padding (Shape{2, 3, 0}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -6717,9 +6718,10 @@ TEST(type_prop, pad_deduce_above_padding_wrong_rank) ...@@ -6717,9 +6718,10 @@ TEST(type_prop, pad_deduce_above_padding_wrong_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Ranks for padding below (Shape{5, 3, 0}), "
std::string("Rank for padding above does not match the rank of the data argument")); "padding above (Shape{6, 9}) and interior "
"padding (Shape{2, 3, 0}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -6747,7 +6749,158 @@ TEST(type_prop, pad_deduce_interior_padding_wrong_rank) ...@@ -6747,7 +6749,158 @@ TEST(type_prop, pad_deduce_interior_padding_wrong_rank)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Rank for interior padding does not match the rank of the data argument")); std::string("Ranks for padding below (Shape{5, 3, 0}), padding above (Shape{6, 9, 4}) "
"and interior padding (Shape{2, 3, 0, 9, 3}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, pad_partial_data_rank_dynamic_padding_rank_dynamic_ok)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
Shape padding_below{2, 4, 6};
Shape padding_above{8, 2, 3};
Shape padding_interior{1, 0, 1};
auto pad = make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
ASSERT_EQ(pad->get_output_element_type(0), element::f32);
ASSERT_TRUE(pad->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, pad_partial_data_rank_dynamic_padding_rank_dynamic_attribs_rank_inconsistent)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
Shape padding_below{2, 4, 6};
Shape padding_above{8, 2, 3, 0};
Shape padding_interior{1, 0, 1};
try
{
auto pad =
make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
FAIL() << "Inconsistent attribute ranks not detected (rank-dynamic/rank-dynamic arguments)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Ranks for padding below (Shape{2, 4, 6}), padding above (Shape{8, 2, 3, "
"0}) and interior padding (Shape{1, 0, 1}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, pad_partial_data_rank_static_dynamic_padding_rank_dynamic_ok)
{
auto param0 = make_shared<op::Parameter>(
element::f32,
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
Shape padding_below{2, 4, 6};
Shape padding_above{8, 2, 3};
Shape padding_interior{1, 0, 1};
auto pad = make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
ASSERT_EQ(pad->get_output_element_type(0), element::f32);
ASSERT_TRUE(pad->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, pad_partial_data_rank_static_dynamic_some_dims_known_padding_rank_dynamic_ok)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{3, 5, Dimension::dynamic()});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
Shape padding_below{2, 4, 6};
Shape padding_above{8, 2, 3};
Shape padding_interior{1, 0, 1};
auto pad = make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
ASSERT_EQ(pad->get_output_element_type(0), element::f32);
ASSERT_TRUE(
pad->get_output_partial_shape(0).same_scheme(PartialShape{15, 11, Dimension::dynamic()}));
}
TEST(type_prop, pad_partial_data_rank_dynamic_padding_static_ok)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, Shape{});
Shape padding_below{2, 4, 6};
Shape padding_above{8, 2, 3};
Shape padding_interior{1, 0, 1};
auto pad = make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
ASSERT_EQ(pad->get_output_element_type(0), element::f32);
ASSERT_TRUE(pad->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop, pad_partial_data_rank_dynamic_padding_static_wrong_padding_rank)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 3, 8});
Shape padding_below{2, 4, 6};
Shape padding_above{8, 2, 3};
Shape padding_interior{1, 0, 1};
try
{
auto pad =
make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
FAIL() << "Wrong padding rank not detected (rank-dynamic/static arguments)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument for padding value is not a scalar (shape: {2,3,8})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, pad_partial_data_rank_dynamic_padding_static_attribs_rank_inconsistent)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, Shape{});
Shape padding_below{2, 4, 6};
Shape padding_above{8, 2, 3, 4};
Shape padding_interior{1, 0, 1};
try
{
auto pad =
make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
FAIL() << "Wrong padding rank not detected (rank-dynamic/static arguments)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Ranks for padding below (Shape{2, 4, 6}), padding above (Shape{8, 2, 3, "
"4}) and interior padding (Shape{1, 0, 1}) do not match"));
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment