Unverified Commit da1cacde authored by Adam Procter's avatar Adam Procter Committed by GitHub

Add more infrastructure for specialization of cloned graphs (#2949)

* Virtualize some things that crash when layout descriptor is missing

* More shape specialization

* (very bare) skeleton for dyn elimination

* Miscellaneous

* Lift i32->int64-only restriction on constant folding for Convert

* Add constant folding for ShapeOf, and some tests for new constant folders

* Tests for DynElimination

* Rename specialize_shapes to specialize_function, and add a unit test for value substitution

* Roll back overeager API change in dyn slice bprop (it has to handle right-indexed axes; bummer)

* Add a test for dynamic usage of transpose op

* Fix warning/error about variable shadowing

* Strengthen checks in apply_permutation

* Propagate Constant shapes through Transpose

* Add CHANGE_DYNAMIC_STATE where appropriate

* PR feedback, and fix unit test failure

* Fix PR reference in comment

* PR comments

* Comments for helper funcs

* Remove unique_ptr indirection for the AlignedBuffers

* Fix incorrect indexing of AlignedBuffer vector (whoops\!)

* Remove unnecessary CHANGE_DYAMIC_STATEs

* De-update pass property unit test for const folding

* Replace mystery runes with all_pass_property_off
parent ba546455
...@@ -358,6 +358,8 @@ set (SRC ...@@ -358,6 +358,8 @@ set (SRC
pass/cse.hpp pass/cse.hpp
pass/dump_sorted.cpp pass/dump_sorted.cpp
pass/dump_sorted.hpp pass/dump_sorted.hpp
pass/dyn_elimination.cpp
pass/dyn_elimination.hpp
pass/fused_op_decomposition.cpp pass/fused_op_decomposition.cpp
pass/fused_op_decomposition.hpp pass/fused_op_decomposition.hpp
pass/get_output_element_elimination.cpp pass/get_output_element_elimination.cpp
...@@ -435,8 +437,8 @@ set (SRC ...@@ -435,8 +437,8 @@ set (SRC
shape.hpp shape.hpp
shape_util.cpp shape_util.cpp
shape_util.hpp shape_util.hpp
specialize_shapes.cpp specialize_function.cpp
specialize_shapes.hpp specialize_function.hpp
state/rng_state.cpp state/rng_state.cpp
strides.cpp strides.cpp
strides.hpp strides.hpp
......
...@@ -165,5 +165,5 @@ ...@@ -165,5 +165,5 @@
#include "ngraph/runtime/tensor.hpp" #include "ngraph/runtime/tensor.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/shape_util.hpp" #include "ngraph/shape_util.hpp"
#include "ngraph/specialize_shapes.hpp" #include "ngraph/specialize_function.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -171,6 +171,54 @@ Strides op::Constant::get_strides_val() const ...@@ -171,6 +171,54 @@ Strides op::Constant::get_strides_val() const
return output_strides; return output_strides;
} }
Coordinate op::Constant::get_coordinate_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_coordinate = get_vector<int64_t>();
Coordinate output_coordinate(shape_size(m_shape));
std::transform(out_coordinate.begin(),
out_coordinate.end(),
output_coordinate.begin(),
[&](const int64_t& v) { return (v > 0) ? v : 0; });
return output_coordinate;
}
CoordinateDiff op::Constant::get_coordinate_diff_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_coordinate_diff = get_vector<int64_t>();
CoordinateDiff output_coordinate_diff(shape_size(m_shape));
std::transform(out_coordinate_diff.begin(),
out_coordinate_diff.end(),
output_coordinate_diff.begin(),
[&](const int64_t& v) { return (v > 0) ? v : 0; });
return output_coordinate_diff;
}
AxisVector op::Constant::get_axis_vector_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_axis_vector = get_vector<int64_t>();
AxisVector output_axis_vector(shape_size(m_shape));
std::transform(out_axis_vector.begin(),
out_axis_vector.end(),
output_axis_vector.begin(),
[&](const int64_t& v) { return (v > 0) ? v : 0; });
return output_axis_vector;
}
AxisSet op::Constant::get_axis_set_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_axis_set = get_vector<int64_t>();
AxisSet output_axis_set;
for (auto& axis : get_vector<int64_t>())
{
output_axis_set.insert(axis > 0 ? axis : 0);
}
return output_axis_set;
}
shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <cstring> #include <cstring>
#include <sstream> #include <sstream>
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/runtime/aligned_buffer.hpp" #include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/type/bfloat16.hpp" #include "ngraph/type/bfloat16.hpp"
...@@ -156,10 +157,31 @@ namespace ngraph ...@@ -156,10 +157,31 @@ namespace ngraph
/// Can only be used on element::i64 nodes and interprets /// Can only be used on element::i64 nodes and interprets
/// negative values as zeros. /// negative values as zeros.
Shape get_shape_val() const; Shape get_shape_val() const;
/// \brief Returns the value of the constant node as a Strides object /// \brief Returns the value of the constant node as a Strides
/// object
/// Can only be used on element::i64 nodes and interprets /// Can only be used on element::i64 nodes and interprets
/// negative values as zeros. /// negative values as zeros.
Strides get_strides_val() const; Strides get_strides_val() const;
/// \brief Returns the value of the constant node as a Coordinate
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Coordinate get_coordinate_val() const;
/// \brief Returns the value of the constant node as a
/// CoordinateDiff object
/// Can only be used on element::i64 nodes.
CoordinateDiff get_coordinate_diff_val() const;
/// \brief Returns the value of the constant node as an AxisVector
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
AxisVector get_axis_vector_val() const;
/// \brief Returns the value of the constant node as an AxisSet
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
/// Repeated values are allowed.
AxisSet get_axis_set_val() const;
/// \brief Wrapper around constructing a shared_ptr of a Constant /// \brief Wrapper around constructing a shared_ptr of a Constant
/// ///
......
...@@ -49,14 +49,7 @@ void op::DynReshape::validate_and_infer_types() ...@@ -49,14 +49,7 @@ void op::DynReshape::validate_and_infer_types()
set_input_is_relevant_to_shape(1); set_input_is_relevant_to_shape(1);
if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(1))) if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(1)))
{ {
// TODO: replace with const_shape->get_shapes_val() set_output_type(0, get_input_element_type(0), const_shape->get_shape_val());
auto out_shape = const_shape->get_vector<int64_t>();
Shape output_shape(shape_size(const_shape->get_shape()));
std::transform(out_shape.begin(),
out_shape.end(),
output_shape.begin(),
[&](const int64_t& v) { return max(v, int64_t(0)); });
set_output_type(0, get_input_element_type(0), output_shape);
} }
else else
{ {
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <iostream> #include <iostream>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
using namespace std; using namespace std;
...@@ -43,7 +44,23 @@ void op::Transpose::validate_and_infer_types() ...@@ -43,7 +44,23 @@ void op::Transpose::validate_and_infer_types()
"Input order must have shape [n], where n is the rank of arg."); "Input order must have shape [n], where n is the rank of arg.");
set_input_is_relevant_to_shape(1); set_input_is_relevant_to_shape(1);
if (auto input_const = std::dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
auto permutation = input_const->get_axis_vector_val();
NODE_VALIDATION_CHECK(this,
is_valid_permutation(permutation, arg_shape.rank()),
"Permutation ",
permutation,
" is not valid for input shape ",
arg_shape);
set_output_type(
0, get_input_element_type(0), ngraph::apply_permutation(arg_shape, permutation));
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank())); set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
}
} }
shared_ptr<Node> op::Transpose::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Transpose::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -22,8 +22,10 @@ ...@@ -22,8 +22,10 @@
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp" #include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/maximum.hpp" #include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp" #include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
...@@ -39,6 +41,7 @@ ...@@ -39,6 +41,7 @@
#include "ngraph/runtime/reference/abs.hpp" #include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/add.hpp" #include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/broadcast.hpp" #include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/dequantize.hpp" #include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp" #include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/maximum.hpp" #include "ngraph/runtime/reference/maximum.hpp"
...@@ -763,3 +766,188 @@ void pass::ConstantFolding::construct_constant_quantize() ...@@ -763,3 +766,188 @@ void pass::ConstantFolding::construct_constant_quantize()
this->add_matcher( this->add_matcher(
quantize_matcher, constant_quantize_callback, PassProperty::REQUIRE_STATIC_SHAPE); quantize_matcher, constant_quantize_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
// Helper for mapping element::Types to runtime::reference::convert, which is templated in C++
// data types. Used by fold_constant_convert and fold_constant_convert_helper0, which respectively
// determine the appropriate C++ types for "TI" (input type) and "TO" (output type).
template <typename TI, typename TO>
shared_ptr<op::Constant> fold_constant_convert_helper1(shared_ptr<op::Constant> constant,
const element::Type& output_element_type)
{
auto out_shape = constant->get_shape();
vector<TO> out_vec(shape_size(out_shape));
runtime::reference::convert<TI, TO>(
constant->get_vector<TI>().data(), out_vec.data(), shape_size(out_shape));
return make_shared<op::Constant>(output_element_type, out_shape, out_vec);
}
// Helper for mapping element::Types to runtime::reference::convert, which is templated in C++
// data types. Used by fold_constant_convert, which determines the appropriate C++ type for "TI"
// (input type).
template <typename TI>
shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant> constant,
const element::Type& output_element_type)
{
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (output_element_type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_convert");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::boolean:
return fold_constant_convert_helper1<TI, char>(constant, output_element_type);
case element::Type_t::bf16:
return fold_constant_convert_helper1<TI, bfloat16>(constant, output_element_type);
case element::Type_t::f16:
return fold_constant_convert_helper1<TI, float16>(constant, output_element_type);
case element::Type_t::f32:
return fold_constant_convert_helper1<TI, float>(constant, output_element_type);
case element::Type_t::f64:
return fold_constant_convert_helper1<TI, double>(constant, output_element_type);
case element::Type_t::i8:
return fold_constant_convert_helper1<TI, int8_t>(constant, output_element_type);
case element::Type_t::i16:
return fold_constant_convert_helper1<TI, int16_t>(constant, output_element_type);
case element::Type_t::i32:
return fold_constant_convert_helper1<TI, int32_t>(constant, output_element_type);
case element::Type_t::i64:
return fold_constant_convert_helper1<TI, int64_t>(constant, output_element_type);
case element::Type_t::u8:
return fold_constant_convert_helper1<TI, uint8_t>(constant, output_element_type);
case element::Type_t::u16:
return fold_constant_convert_helper1<TI, uint16_t>(constant, output_element_type);
case element::Type_t::u32:
return fold_constant_convert_helper1<TI, uint32_t>(constant, output_element_type);
case element::Type_t::u64:
return fold_constant_convert_helper1<TI, uint64_t>(constant, output_element_type);
}
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
}
static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> constant,
const element::Type& output_element_type)
{
auto& input_element_type = constant->get_output_element_type(0);
if (input_element_type == output_element_type)
{
return constant;
}
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (input_element_type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_convert");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::boolean:
return fold_constant_convert_helper0<char>(constant, output_element_type);
case element::Type_t::bf16:
return fold_constant_convert_helper0<bfloat16>(constant, output_element_type);
case element::Type_t::f16:
return fold_constant_convert_helper0<float16>(constant, output_element_type);
case element::Type_t::f32:
return fold_constant_convert_helper0<float>(constant, output_element_type);
case element::Type_t::f64:
return fold_constant_convert_helper0<double>(constant, output_element_type);
case element::Type_t::i8:
return fold_constant_convert_helper0<int8_t>(constant, output_element_type);
case element::Type_t::i16:
return fold_constant_convert_helper0<int16_t>(constant, output_element_type);
case element::Type_t::i32:
return fold_constant_convert_helper0<int32_t>(constant, output_element_type);
case element::Type_t::i64:
return fold_constant_convert_helper0<int64_t>(constant, output_element_type);
case element::Type_t::u8:
return fold_constant_convert_helper0<uint8_t>(constant, output_element_type);
case element::Type_t::u16:
return fold_constant_convert_helper0<uint16_t>(constant, output_element_type);
case element::Type_t::u32:
return fold_constant_convert_helper0<uint32_t>(constant, output_element_type);
case element::Type_t::u64:
return fold_constant_convert_helper0<uint64_t>(constant, output_element_type);
}
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
}
void pass::ConstantFolding::construct_constant_convert()
{
auto constant_label = make_shared<pattern::op::Label>(
element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto convert_op = make_shared<op::Convert>(constant_label, element::i64);
auto constant_convert_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_convert_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto convert_match = static_pointer_cast<op::Convert>(m.get_match_root());
replace_node(
m.get_match_root(),
fold_constant_convert(constant_match, convert_match->get_output_element_type(0)));
return true;
};
auto convert_matcher =
make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantConvert");
this->add_matcher(convert_matcher, constant_convert_callback, all_pass_property_off);
}
// ShapeOf is a bit of an odd duck: it doesn't matter if the input's value is
// constant, as long as it has static shape.
void pass::ConstantFolding::construct_constant_shape_of()
{
auto arg_label = make_shared<pattern::op::Label>(element::i32, Shape{2, 3, 4});
auto shape_of_op = make_shared<op::ShapeOf>(arg_label);
auto constant_shape_of_callback = [arg_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_shape_of_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto arg_match = pattern_map[arg_label];
if (arg_match->get_output_partial_shape(0).is_static())
{
auto arg_shape = arg_match->get_output_shape(0);
auto replacement =
make_shared<op::Constant>(element::i64, Shape{arg_shape.size()}, arg_shape.data());
replace_node(m.get_match_root(), replacement);
return true;
}
else
{
return false;
}
};
auto shape_of_matcher =
make_shared<pattern::Matcher>(shape_of_op, "ConstantFolding.ConstantShapeOf");
this->add_matcher(shape_of_matcher, constant_shape_of_callback, all_pass_property_off);
}
...@@ -38,7 +38,9 @@ public: ...@@ -38,7 +38,9 @@ public:
DEQUANTIZE, DEQUANTIZE,
UNARY, UNARY,
BINARY, BINARY,
QUANTIZE QUANTIZE,
CONVERT,
SHAPE_OF
}; };
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap()) ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
...@@ -52,6 +54,8 @@ public: ...@@ -52,6 +54,8 @@ public:
construct_constant_binary(); construct_constant_binary();
construct_constant_quantize(); construct_constant_quantize();
construct_constant_dequantize(); construct_constant_dequantize();
construct_constant_convert();
construct_constant_shape_of();
} }
//this allows to specify the order in which matchers will be run //this allows to specify the order in which matchers will be run
...@@ -72,6 +76,8 @@ public: ...@@ -72,6 +76,8 @@ public:
case CFTransformations::BINARY: construct_constant_binary(); break; case CFTransformations::BINARY: construct_constant_binary(); break;
case CFTransformations::DEQUANTIZE: construct_constant_dequantize(); break; case CFTransformations::DEQUANTIZE: construct_constant_dequantize(); break;
case CFTransformations::QUANTIZE: construct_constant_quantize(); break; case CFTransformations::QUANTIZE: construct_constant_quantize(); break;
case CFTransformations::CONVERT: construct_constant_convert(); break;
case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
} }
} }
} }
...@@ -84,6 +90,8 @@ private: ...@@ -84,6 +90,8 @@ private:
void construct_constant_binary(); void construct_constant_binary();
void construct_constant_quantize(); void construct_constant_quantize();
void construct_constant_dequantize(); void construct_constant_dequantize();
void construct_constant_convert();
void construct_constant_shape_of();
ngraph::BuildNodeExecutorMap m_cfmap; ngraph::BuildNodeExecutorMap m_cfmap;
}; };
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "dyn_elimination.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
using namespace std;
using namespace ngraph;
pass::DynElimination::DynElimination()
: GraphRewrite()
{
construct_transpose();
}
void pass::DynElimination::construct_transpose()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto perm_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto transpose = make_shared<op::Transpose>(data_arg_label, perm_arg_label);
auto transpose_callback = [data_arg_label, perm_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto data_arg = pattern_map[data_arg_label];
auto perm_arg = static_pointer_cast<op::Constant>(pattern_map[perm_arg_label]);
// TODO(amprocte): Can't handle the case where data shape is dynamic, because static
// Reshape requries the exact output shape to be declared. See if we can come up with a
// workaround.
if (data_arg->get_output_partial_shape(0).is_dynamic())
{
return false;
}
auto& data_shape = data_arg->get_output_shape(0);
// TODO(amprocte): These should be redundant if the graph is validated. Necessary?
if (perm_arg->get_element_type() != element::i64 ||
perm_arg->get_output_partial_shape(0).is_dynamic() ||
perm_arg->get_output_shape(0).size() != 1)
{
return false;
}
auto perm = perm_arg->get_axis_vector_val();
auto output_shape = ngraph::apply_permutation(data_shape, perm);
auto replacement = std::make_shared<op::Reshape>(data_arg, perm, output_shape);
replace_node(m.get_match_root(), replacement);
return true;
};
auto transpose_matcher = make_shared<pattern::Matcher>(transpose, "DynElimination.Transpose");
add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace pass
{
class DynElimination : public GraphRewrite
{
public:
DynElimination();
private:
void construct_transpose();
};
}
}
...@@ -16,9 +16,11 @@ ...@@ -16,9 +16,11 @@
#include "ngraph/runtime/dynamic/dynamic_backend.hpp" #include "ngraph/runtime/dynamic/dynamic_backend.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/shape_relevance.hpp" #include "ngraph/pass/shape_relevance.hpp"
#include "ngraph/specialize_shapes.hpp" #include "ngraph/specialize_function.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -79,16 +81,51 @@ bool runtime::dynamic::DynamicExecutable::call( ...@@ -79,16 +81,51 @@ bool runtime::dynamic::DynamicExecutable::call(
// (1) all shapes; // (1) all shapes;
// (2) all values of shape-relevant input tensors. // (2) all values of shape-relevant input tensors.
NGRAPH_CHECK(m_wrapped_function->get_parameters().size() == inputs.size());
std::vector<std::shared_ptr<runtime::Tensor>> wrapped_inputs; std::vector<std::shared_ptr<runtime::Tensor>> wrapped_inputs;
std::vector<element::Type> arg_element_types; std::vector<element::Type> arg_element_types;
std::vector<PartialShape> arg_shapes; std::vector<PartialShape> arg_shapes;
std::shared_ptr<Function> clone;
{
// We'll use AlignedBuffers to back the base pointers, storing them in this vector for RAII
// purposes.
std::vector<AlignedBuffer> arg_buffers;
arg_buffers.reserve(inputs.size());
std::vector<void*> arg_value_base_pointers(inputs.size());
size_t i = 0;
for (auto& input : inputs) for (auto& input : inputs)
{ {
if (auto dynamic_tensor = std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input)) if (m_wrapped_function->get_parameters()[i]->is_relevant_to_shapes())
{
// TODO(amprocte): Move has_storage() to runtime::Tensor?
if (auto dynamic_tensor =
std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input))
{
NGRAPH_CHECK(dynamic_tensor->has_storage());
}
arg_buffers.emplace_back(input->get_size_in_bytes(), /*alignment=*/64);
arg_value_base_pointers[i] = arg_buffers.back().get_ptr();
// TODO(amprocte): For host-resident tensors we should be able to skip the read,
// but no API for that yet.
input->read(arg_value_base_pointers[i], 0, input->get_size_in_bytes());
}
else
{
arg_value_base_pointers[i] = nullptr;
}
if (auto dynamic_tensor =
std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input))
{ {
NGRAPH_CHECK(dynamic_tensor->has_storage()); NGRAPH_CHECK(dynamic_tensor->has_storage());
arg_element_types.push_back(dynamic_tensor->get_wrapped_tensor()->get_element_type()); arg_element_types.push_back(
dynamic_tensor->get_wrapped_tensor()->get_element_type());
arg_shapes.push_back(dynamic_tensor->get_wrapped_tensor()->get_shape()); arg_shapes.push_back(dynamic_tensor->get_wrapped_tensor()->get_shape());
wrapped_inputs.push_back(dynamic_tensor->get_wrapped_tensor()); wrapped_inputs.push_back(dynamic_tensor->get_wrapped_tensor());
} }
...@@ -98,11 +135,19 @@ bool runtime::dynamic::DynamicExecutable::call( ...@@ -98,11 +135,19 @@ bool runtime::dynamic::DynamicExecutable::call(
arg_shapes.push_back(input->get_shape()); arg_shapes.push_back(input->get_shape());
wrapped_inputs.push_back(input); wrapped_inputs.push_back(input);
} }
i++;
}
clone = specialize_function(
m_wrapped_function, arg_element_types, arg_shapes, arg_value_base_pointers);
} }
// TODO: specialize_shapes needs to fill in values of shape-relevant params. pass::Manager passes;
auto clone = specialize_shapes(m_wrapped_function, arg_element_types, arg_shapes); passes.register_pass<pass::ConstantFolding>();
// TODO: run constant folding and de-dynification on clone. passes.register_pass<pass::DynElimination>();
passes.run_passes(clone);
const ResultVector& results = clone->get_results(); const ResultVector& results = clone->get_results();
NGRAPH_CHECK(results.size() == outputs.size()); NGRAPH_CHECK(results.size() == outputs.size());
...@@ -140,6 +185,27 @@ runtime::dynamic::DynamicTensor::DynamicTensor( ...@@ -140,6 +185,27 @@ runtime::dynamic::DynamicTensor::DynamicTensor(
{ {
} }
Strides runtime::dynamic::DynamicTensor::get_strides() const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"asked for strides of a dynamic tensor with no allocated storage");
return ngraph::row_major_strides(m_wrapped_tensor->get_shape());
}
size_t runtime::dynamic::DynamicTensor::get_size_in_bytes() const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"asked for size in bytes of a dynamic tensor with no allocated storage");
return get_element_count() * get_element_type().size();
}
size_t runtime::dynamic::DynamicTensor::get_element_count() const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"asked for element count of a dynamic tensor with no allocated storage");
return shape_size(m_wrapped_tensor->get_shape());
}
const element::Type& runtime::dynamic::DynamicTensor::get_element_type() const const element::Type& runtime::dynamic::DynamicTensor::get_element_type() const
{ {
if (m_wrapped_tensor == nullptr) if (m_wrapped_tensor == nullptr)
......
...@@ -127,6 +127,9 @@ public: ...@@ -127,6 +127,9 @@ public:
DynamicTensor(const element::Type& element_type, DynamicTensor(const element::Type& element_type,
const PartialShape& shape, const PartialShape& shape,
const std::shared_ptr<runtime::Backend>& wrapped_backend); const std::shared_ptr<runtime::Backend>& wrapped_backend);
virtual ngraph::Strides get_strides() const override;
virtual size_t get_size_in_bytes() const override;
virtual size_t get_element_count() const override;
virtual const element::Type& get_element_type() const override; virtual const element::Type& get_element_type() const override;
virtual const ngraph::Shape& get_shape() const override; virtual const ngraph::Shape& get_shape() const override;
virtual void write(const void* p, size_t offset, size_t n) override; virtual void write(const void* p, size_t offset, size_t n) override;
......
...@@ -58,7 +58,7 @@ namespace ngraph ...@@ -58,7 +58,7 @@ namespace ngraph
/// \brief Get tensor strides /// \brief Get tensor strides
/// \return Strides /// \return Strides
ngraph::Strides get_strides() const; virtual ngraph::Strides get_strides() const;
/// \brief Get tensor element type /// \brief Get tensor element type
/// \return element::Type /// \return element::Type
...@@ -66,11 +66,11 @@ namespace ngraph ...@@ -66,11 +66,11 @@ namespace ngraph
/// \brief Get number of elements in the tensor /// \brief Get number of elements in the tensor
/// \return number of elements in the tensor /// \return number of elements in the tensor
size_t get_element_count() const; virtual size_t get_element_count() const;
/// \brief Get the size in bytes of the tensor /// \brief Get the size in bytes of the tensor
/// \return number of bytes in tensor's allocation /// \return number of bytes in tensor's allocation
size_t get_size_in_bytes() const; virtual size_t get_size_in_bytes() const;
/// \brief Get tensor's unique name /// \brief Get tensor's unique name
/// \return tensor's name /// \return tensor's name
......
...@@ -14,17 +14,20 @@ ...@@ -14,17 +14,20 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "ngraph/specialize_shapes.hpp" #include "ngraph/specialize_function.hpp"
#include "ngraph/op/constant.hpp"
using namespace ngraph; using namespace ngraph;
std::shared_ptr<Function> std::shared_ptr<Function>
ngraph::specialize_shapes(std::shared_ptr<Function> f, ngraph::specialize_function(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types, const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes) const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values)
{ {
NGRAPH_CHECK(f->get_parameters().size() == parameter_shapes.size()); NGRAPH_CHECK(f->get_parameters().size() == parameter_shapes.size());
NGRAPH_CHECK(f->get_parameters().size() == parameter_element_types.size()); NGRAPH_CHECK(f->get_parameters().size() == parameter_element_types.size());
NGRAPH_CHECK(f->get_parameters().size() == parameter_values.size());
NodeMap m; NodeMap m;
...@@ -35,9 +38,18 @@ std::shared_ptr<Function> ...@@ -35,9 +38,18 @@ std::shared_ptr<Function>
NGRAPH_CHECK(f->get_parameters()[i]->get_element_type().is_dynamic() || NGRAPH_CHECK(f->get_parameters()[i]->get_element_type().is_dynamic() ||
parameter_element_types[i] == f->get_parameters()[i]->get_element_type()); parameter_element_types[i] == f->get_parameters()[i]->get_element_type());
if (parameter_values[i] != nullptr && parameter_shapes[i].is_static() &&
parameter_element_types[i].is_static())
{
m[f->get_parameters()[i].get()] = std::make_shared<op::Constant>(
parameter_element_types[i], parameter_shapes[i].to_shape(), parameter_values[i]);
}
else
{
m[f->get_parameters()[i].get()] = m[f->get_parameters()[i].get()] =
std::make_shared<op::Parameter>(parameter_element_types[i], parameter_shapes[i]); std::make_shared<op::Parameter>(parameter_element_types[i], parameter_shapes[i]);
} }
}
for (auto old_node : f->get_ordered_ops()) for (auto old_node : f->get_ordered_ops())
{ {
...@@ -59,7 +71,16 @@ std::shared_ptr<Function> ...@@ -59,7 +71,16 @@ std::shared_ptr<Function>
ParameterVector new_parameters = f->get_parameters(); ParameterVector new_parameters = f->get_parameters();
for (size_t i = 0; i < new_parameters.size(); i++) for (size_t i = 0; i < new_parameters.size(); i++)
{ {
new_parameters[i] = std::static_pointer_cast<op::Parameter>(m[new_parameters[i].get()]); new_parameters[i] = std::dynamic_pointer_cast<op::Parameter>(m[new_parameters[i].get()]);
// If the replacement for a Parameter is not itself a Parameter, we must have replaced it
// with a constant. We will insert a dead Parameter into the clone's parameters, in order
// to maintain the arity of the original function.
if (new_parameters[i] == nullptr)
{
new_parameters[i] =
std::make_shared<op::Parameter>(parameter_element_types[i], parameter_shapes[i]);
}
} }
ResultVector new_results = f->get_results(); ResultVector new_results = f->get_results();
......
...@@ -20,17 +20,23 @@ ...@@ -20,17 +20,23 @@
namespace ngraph namespace ngraph
{ {
/// \brief Creates a clone of a function, with the shapes of that function's parameters /// \brief Creates a "specialized" clone of a function. The partial shapes and element types of
/// specialized to some more specific element types and shapes. /// the function's parameters may be narrowed to more specific shapes and element types,
/// and constant values may optionally be substituted for any or all of the parameters.
/// \param f The function to be cloned. /// \param f The function to be cloned.
/// \param parameter_element_types The new parameter element types to substitute. /// \param parameter_element_types The new parameter element types to substitute. Length must
/// \param parameter_shapes The new parameter shapes to substitute. /// be equal to the number of parameters of f.
/// \return A clone of f, with the parameter element types and shapes specialized. /// \param parameter_shapes The new parameter shapes to substitute. Length must be equal to the
/// \throws CheckFailure if parameter_element_types or parameter_shapes is not valid /// number of parameters of f.
/// \param parameter_values Parameter values to substitute. Length must be equal to the number
/// of parameters of f, with nullptr indicating that no substitution is to be made for
/// the corresponding parameter.
/// \return A clone of f, with the parameter element types, shapes, and values specialized.
/// \throws CheckFailure if parameter_element_types, parameter_shapes is not valid
/// (see details). /// (see details).
/// \throws NodeValidationError if node validation fails as the clone is being constructed. /// \throws NodeValidationError if node validation fails as the clone is being constructed.
/// ///
/// Creates a "shape-specialized" clone of an nGraph Function function. /// Creates a "specialized" clone of an nGraph Function.
/// ///
/// For example, suppose that a function f has three parameters with partial shapes: /// For example, suppose that a function f has three parameters with partial shapes:
/// ///
...@@ -75,17 +81,31 @@ namespace ngraph ...@@ -75,17 +81,31 @@ namespace ngraph
/// specialized to itself (e.g., specialization does not allow you to change `element::i32` /// specialized to itself (e.g., specialization does not allow you to change `element::i32`
/// to `element::i64`). /// to `element::i64`).
/// ///
/// Finally, it is possible to specialize parameter values. If the ith element of
/// `parameter_values` is not `nullptr`, and fully static element type and shape has been
/// specified for the ith parameter, a `Constant` node will be created and substituted for the
/// ith parameter, with its data drawn from `parameter_values[i]`. Note that the Parameter node
/// remains (in order to maintain the arity of the function), but will no longer have any
/// users.
///
/// It is required that: /// It is required that:
/// 1. The length of parameter_element_types and parameter_shapes is the same as the number /// 1. The length of parameter_element_types, parameter_shapes, and parameter_values is the
/// of f's parameters. /// same as the number of f's parameters.
/// 2. Each shape in parameter_shapes is a refinement of the shape of the corresponding /// 2. Each shape in parameter_shapes is a refinement of the shape of the corresponding
/// parameter of f. Roughly speaking, a shape s1 is said to "refine" s2 if s1 can be /// parameter of f. Roughly speaking, a shape s1 is said to "refine" s2 if s1 can be
/// obtained from s2 by filling in s2's question marks. See PartialShape::refines for /// obtained from s2 by filling in s2's question marks. See PartialShape::refines for
/// more details. /// more details.
/// 3. For all i, either the element type of fp_i is dynamic, or fp_i is the same as /// 3. For all i, either the element type of fp_i is dynamic, or fp_i is the same as
/// parameter_element_types[i]. (Here fp_i is the ith parameter of f.) /// parameter_element_types[i]. (Here fp_i is the ith parameter of f.)
/// 4. For all i where parameter_values[i] != nullptr and parameter_element_types[i] is
/// static and parameter_shapes[i] is static, parameter_values[i] points to a buffer from
/// which a Constant node with element type parameter_element_types[i] and shape
/// parameter_shapes[i] can be created.
///
/// TODO(amprocte): convert this to a pass.
std::shared_ptr<Function> std::shared_ptr<Function>
specialize_shapes(std::shared_ptr<Function> f, specialize_function(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types, const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes); const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values);
} }
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -457,14 +458,35 @@ void ngraph::check_fp_values_isnan(const char* name, const double* array, size_t ...@@ -457,14 +458,35 @@ void ngraph::check_fp_values_isnan(const char* name, const double* array, size_t
} }
} }
template <typename T> bool ngraph::is_valid_permutation(ngraph::AxisVector permutation, ngraph::Rank rank)
T ngraph::apply_permutation(T input, AxisVector order)
{ {
if (input.size() != order.size()) std::vector<bool> axis_occurs(permutation.size(), false);
for (auto& axis : permutation)
{
axis_occurs[axis] = true;
}
for (size_t axis = 0; axis < permutation.size(); axis++)
{ {
throw "input and order sizes don't match!"; if (!axis_occurs[axis])
{
return false;
}
} }
return (rank.is_dynamic() || permutation.size() == static_cast<size_t>(rank));
}
template <typename T>
T ngraph::apply_permutation(T input, AxisVector order)
{
NGRAPH_CHECK(is_valid_permutation(order, input.size()),
"Permutation ",
order,
" is not valid for ",
input);
T output(input.size()); T output(input.size());
for (size_t i = 0; i < order.size(); i++) for (size_t i = 0; i < order.size(); i++)
...@@ -485,6 +507,35 @@ template ngraph::CoordinateDiff ...@@ -485,6 +507,35 @@ template ngraph::CoordinateDiff
template ngraph::Strides ngraph::apply_permutation<ngraph::Strides>(ngraph::Strides input, template ngraph::Strides ngraph::apply_permutation<ngraph::Strides>(ngraph::Strides input,
ngraph::AxisVector order); ngraph::AxisVector order);
namespace ngraph
{
template <>
PartialShape apply_permutation(PartialShape input, AxisVector order)
{
NGRAPH_CHECK(is_valid_permutation(order, input.rank()),
"Permutation ",
order,
" is not valid for ",
input);
// Here's the special part: if AxisVector is a viable permutation of _some_ rank, and input
// has dynamic rank, we just stick with dynamic rank.
if (input.rank().is_dynamic())
{
return input;
}
PartialShape output{PartialShape::dynamic(order.size())};
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input[order.at(i)];
}
return output;
}
}
AxisVector ngraph::get_default_order(const Shape& shape) AxisVector ngraph::get_default_order(const Shape& shape)
{ {
return get_default_order(shape.size()); return get_default_order(shape.size());
......
...@@ -196,6 +196,7 @@ namespace ngraph ...@@ -196,6 +196,7 @@ namespace ngraph
void ngraph_free(void*); void ngraph_free(void*);
size_t round_up(size_t size, size_t alignment); size_t round_up(size_t size, size_t alignment);
bool is_valid_permutation(ngraph::AxisVector permutation, ngraph::Rank rank = Rank::dynamic());
template <typename T> template <typename T>
T apply_permutation(T input, ngraph::AxisVector order); T apply_permutation(T input, ngraph::AxisVector order);
......
...@@ -42,6 +42,7 @@ set(SRC ...@@ -42,6 +42,7 @@ set(SRC
copy.cpp copy.cpp
cpio.cpp cpio.cpp
cse.cpp cse.cpp
dyn_elimination.cpp
element_type.cpp element_type.cpp
file_util.cpp file_util.cpp
float16.cpp float16.cpp
...@@ -63,7 +64,7 @@ set(SRC ...@@ -63,7 +64,7 @@ set(SRC
reshape_elimination.cpp reshape_elimination.cpp
reshape_sinking.cpp reshape_sinking.cpp
shape.cpp shape.cpp
specialize_shapes.cpp specialize_function.cpp
tensor.cpp tensor.cpp
type_prop.cpp type_prop.cpp
type_prop_layers.cpp type_prop_layers.cpp
......
...@@ -260,6 +260,104 @@ TEST(constant_folding, const_quantize) ...@@ -260,6 +260,104 @@ TEST(constant_folding, const_quantize)
vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5}; vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
ASSERT_EQ(values_quantize, values_out); ASSERT_EQ(values_quantize, values_out);
} }
TEST(constant_folding, const_convert)
{
Shape input_shape{3, 4};
vector<int32_t> values_in{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
auto constant = op::Constant::create(element::f32, input_shape, values_in);
auto convert = make_shared<op::Convert>(constant, element::u64);
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Convert>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
ASSERT_EQ(new_const->get_output_element_type(0), element::u64);
auto values_out = new_const->get_vector<uint64_t>();
vector<uint64_t> values_expected{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, shape_of)
{
Shape input_shape{3, 4, 0, 22, 608, 909, 3};
auto param = make_shared<op::Parameter>(element::boolean, input_shape);
auto shape_of = make_shared<op::ShapeOf>(param);
auto f = make_shared<Function>(shape_of, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::ShapeOf>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
auto values_out = new_const->get_vector<int64_t>();
ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
}
// A bit of an unusual case here: constant folding will not succeed on ShapeOf
// if the argument doesn't have dynamic shape. We want to make sure it fails
// gracefully, leaving the ShapeOf op in place.
TEST(constant_folding, shape_of_dynamic)
{
PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
auto param = make_shared<op::Parameter>(element::boolean, input_shape);
auto shape_of = make_shared<op::ShapeOf>(param);
auto f = make_shared<Function>(shape_of, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::ShapeOf>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
auto result_as_shape_of =
std::dynamic_pointer_cast<op::ShapeOf>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(result_as_shape_of);
ASSERT_EQ(result_as_shape_of->get_output_shape(0), Shape{7});
}
// Similar to shape_of_dynamic above but here even the rank is dynamic.
TEST(constant_folding, shape_of_rank_dynamic)
{
PartialShape input_shape{PartialShape::dynamic()};
auto param = make_shared<op::Parameter>(element::boolean, input_shape);
auto shape_of = make_shared<op::ShapeOf>(param);
auto f = make_shared<Function>(shape_of, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::ShapeOf>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
auto result_as_shape_of =
std::dynamic_pointer_cast<op::ShapeOf>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(result_as_shape_of);
ASSERT_TRUE(result_as_shape_of->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic()}));
}
TEST(constant_folding, pass_property) TEST(constant_folding, pass_property)
{ {
auto pass = std::make_shared<ngraph::pass::ConstantFolding>(); auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/dyn_elimination.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(dyn_elimination, transpose)
{
Shape shape_in{2, 4, 6, 8};
auto param = make_shared<op::Parameter>(element::boolean, shape_in);
auto constant_perm =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{2, 3, 1, 0});
auto transpose = make_shared<op::Transpose>(param, constant_perm);
auto f = make_shared<Function>(transpose, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
auto new_reshape =
std::dynamic_pointer_cast<op::Reshape>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_reshape);
ASSERT_EQ(new_reshape->get_input_order(), (AxisVector{2, 3, 1, 0}));
ASSERT_EQ(new_reshape->output(0).get_shape(), (Shape{6, 8, 4, 2}));
ASSERT_EQ(new_reshape->get_output_element_type(0), element::boolean);
}
// For now, we can't handle the case where the input has dynamic shapes,
// because the classic Reshape op demands a Shape. Probably won't be able to
// deal with this until/unless we make a "StaticTranspose". Just make sure
// we don't crash or mangle the graph.
TEST(dyn_elimination, transpose_dyn_shape)
{
PartialShape shape_in{2, 4, Dimension::dynamic(), 8};
auto param = make_shared<op::Parameter>(element::boolean, shape_in);
auto constant_perm =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{2, 3, 1, 0});
auto transpose = make_shared<op::Transpose>(param, constant_perm);
auto f = make_shared<Function>(transpose, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_transpose =
std::dynamic_pointer_cast<op::Transpose>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_transpose);
ASSERT_EQ(new_transpose->get_output_element_type(0), element::boolean);
ASSERT_TRUE(new_transpose->get_output_partial_shape(0).relaxes(
PartialShape{Dimension::dynamic(), 8, 4, 2}));
}
...@@ -112,3 +112,50 @@ NGRAPH_TEST(dynamic_${BACKEND_NAME}, abc) ...@@ -112,3 +112,50 @@ NGRAPH_TEST(dynamic_${BACKEND_NAME}, abc)
EXPECT_TRUE(test::all_close_f(results, expected_values)); EXPECT_TRUE(test::all_close_f(results, expected_values));
} }
} }
NGRAPH_TEST(dynamic_${BACKEND_NAME}, transpose)
{
//
// Create a graph for f(x,perm) = Transpose(x,Convert<i64>(perm)). We'll do the permutation in
// i32 and cast it to i64, just for fun (and to mirror the TensorFlow test I am porting here).
//
auto x = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto perm = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto perm_i64 = make_shared<op::Convert>(perm, element::i64);
auto x_transpose = make_shared<op::Transpose>(x, perm_i64);
auto f = make_shared<Function>(NodeVector{x_transpose}, ParameterVector{x, perm});
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto ex = backend->compile(f);
auto t_r = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
std::vector<Shape> x_shapes{Shape{2, 3}, Shape{2, 3}, Shape{2, 2, 3}};
std::vector<std::vector<int32_t>> perms{{0, 1}, {1, 0}, {2, 1, 0}};
std::vector<std::vector<float>> inputs{
{1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
std::vector<Shape> expected_result_shapes{Shape{2, 3}, Shape{3, 2}, {3, 2, 2}};
// Generated with numpy, so don't worry. :)
std::vector<std::vector<float>> expected_results{
{1, 2, 3, 4, 5, 6}, {1, 4, 2, 5, 3, 6}, {1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}};
for (size_t i = 0; i < x_shapes.size(); i++)
{
auto t_x = backend->create_tensor(element::f32, x_shapes[i]);
auto t_perm = backend->create_tensor(element::i32, Shape{perms[i].size()});
copy_data(t_x, inputs[i]);
copy_data(t_perm, perms[i]);
ex->call_with_validate({t_r}, {t_x, t_perm});
ASSERT_EQ(t_r->get_shape(), expected_result_shapes[i]);
auto results = read_vector<float>(t_r);
ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS));
}
}
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/specialize_shapes.hpp" #include "ngraph/specialize_function.hpp"
using namespace ngraph; using namespace ngraph;
// Simple case: create a function with static parameter shapes and "specialize" them to the same // Simple case: create a function with static parameter shapes and "specialize" them to the same
// shapes. // shapes.
TEST(specialize_shapes, et_shape_static) TEST(specialize_function, et_shape_static)
{ {
auto p0 = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3}); auto p0 = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::i32, Shape{1, 2, 3}); auto p1 = std::make_shared<op::Parameter>(element::i32, Shape{1, 2, 3});
...@@ -33,15 +33,19 @@ TEST(specialize_shapes, et_shape_static) ...@@ -33,15 +33,19 @@ TEST(specialize_shapes, et_shape_static)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
auto g = specialize_shapes( std::vector<void*> param_vals{nullptr, nullptr};
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
auto g = specialize_function(f,
{element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
param_vals);
ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3})); ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(g->get_output_element_type(0), element::f32); ASSERT_EQ(g->get_output_element_type(0), element::f32);
} }
// Test specialization of dynamic element types. // Test specialization of dynamic element types.
TEST(specialize_shapes, et_dynamic_shape_static) TEST(specialize_function, et_dynamic_shape_static)
{ {
auto p0 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3}); auto p0 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3}); auto p1 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
...@@ -51,15 +55,19 @@ TEST(specialize_shapes, et_dynamic_shape_static) ...@@ -51,15 +55,19 @@ TEST(specialize_shapes, et_dynamic_shape_static)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
auto g = specialize_shapes( std::vector<void*> param_vals{nullptr, nullptr};
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
auto g = specialize_function(f,
{element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
param_vals);
ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3})); ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(g->get_output_element_type(0), element::f32); ASSERT_EQ(g->get_output_element_type(0), element::f32);
} }
// Test specialization of rank-dynamic shapes. // Test specialization of rank-dynamic shapes.
TEST(specialize_shapes, et_static_shape_rank_dynamic) TEST(specialize_function, et_static_shape_rank_dynamic)
{ {
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic()); auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic()); auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
...@@ -69,15 +77,19 @@ TEST(specialize_shapes, et_static_shape_rank_dynamic) ...@@ -69,15 +77,19 @@ TEST(specialize_shapes, et_static_shape_rank_dynamic)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
auto g = specialize_shapes( std::vector<void*> param_vals{nullptr, nullptr};
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
auto g = specialize_function(f,
{element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
param_vals);
ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3})); ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(g->get_output_element_type(0), element::f32); ASSERT_EQ(g->get_output_element_type(0), element::f32);
} }
// Test specialization of rank-static dynamic shapes. // Test specialization of rank-static dynamic shapes.
TEST(specialize_shapes, et_static_shape_rank_static_dynamic) TEST(specialize_function, et_static_shape_rank_static_dynamic)
{ {
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3)); auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3)); auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));
...@@ -87,17 +99,56 @@ TEST(specialize_shapes, et_static_shape_rank_static_dynamic) ...@@ -87,17 +99,56 @@ TEST(specialize_shapes, et_static_shape_rank_static_dynamic)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
auto g = specialize_shapes( std::vector<void*> param_vals{nullptr, nullptr};
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
auto g = specialize_function(f,
{element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
param_vals);
ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3})); ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(g->get_output_element_type(0), element::f32); ASSERT_EQ(g->get_output_element_type(0), element::f32);
} }
// Test specialization of values to a shape-dynamic parameters.
TEST(specialize_function, et_static_shape_rank_static_dynamic_subst_val)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
std::vector<int32_t> p1_subst_vals{5, 0, 3, 8, 5, 8};
std::vector<void*> param_vals{nullptr, p1_subst_vals.data()};
auto g = specialize_function(f,
{element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
param_vals);
ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(g->get_output_element_type(0), element::f32);
auto plus_node = std::dynamic_pointer_cast<op::Add>(g->get_results().at(0)->get_argument(0));
ASSERT_TRUE(plus_node);
auto convert_node = std::dynamic_pointer_cast<op::Convert>(plus_node->get_argument(1));
ASSERT_TRUE(convert_node);
auto const_node = std::dynamic_pointer_cast<op::Constant>(convert_node->get_argument(0));
ASSERT_TRUE(const_node);
ASSERT_EQ(const_node->get_output_element_type(0), element::i32);
ASSERT_EQ(const_node->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(const_node->get_vector<int32_t>(), p1_subst_vals);
}
// Test specialization of rank-dynamic shapes to a case where validation will fail. // Test specialization of rank-dynamic shapes to a case where validation will fail.
// //
// (The input shapes we provide at specialization time are inconsistent.) // (The input shapes we provide at specialization time are inconsistent.)
TEST(specialize_shapes, et_static_shape_rank_dynamic_validation_fails) TEST(specialize_function, et_static_shape_rank_dynamic_validation_fails)
{ {
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic()); auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic()); auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
...@@ -107,10 +158,14 @@ TEST(specialize_shapes, et_static_shape_rank_dynamic_validation_fails) ...@@ -107,10 +158,14 @@ TEST(specialize_shapes, et_static_shape_rank_dynamic_validation_fails)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
std::vector<void*> param_vals{nullptr, nullptr};
ASSERT_THROW( ASSERT_THROW(
{ {
specialize_shapes( specialize_function(f,
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}}); {element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}},
param_vals);
}, },
NodeValidationFailure); NodeValidationFailure);
} }
...@@ -118,7 +173,7 @@ TEST(specialize_shapes, et_static_shape_rank_dynamic_validation_fails) ...@@ -118,7 +173,7 @@ TEST(specialize_shapes, et_static_shape_rank_dynamic_validation_fails)
// Test specialization of dynamic element types to a case where validation will fail. // Test specialization of dynamic element types to a case where validation will fail.
// //
// (The input element types we provide at specialization time are inconsistent.) // (The input element types we provide at specialization time are inconsistent.)
TEST(specialize_shapes, et_dynamic_shape_static_validation_fails) TEST(specialize_function, et_dynamic_shape_static_validation_fails)
{ {
auto p0 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3}); auto p0 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3}); auto p1 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
...@@ -128,10 +183,14 @@ TEST(specialize_shapes, et_dynamic_shape_static_validation_fails) ...@@ -128,10 +183,14 @@ TEST(specialize_shapes, et_dynamic_shape_static_validation_fails)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
std::vector<void*> param_vals{nullptr, nullptr};
ASSERT_THROW( ASSERT_THROW(
{ {
specialize_shapes( specialize_function(f,
f, {element::u32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}}); {element::u32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
param_vals);
}, },
NodeValidationFailure); NodeValidationFailure);
} }
...@@ -142,7 +201,7 @@ TEST(specialize_shapes, et_dynamic_shape_static_validation_fails) ...@@ -142,7 +201,7 @@ TEST(specialize_shapes, et_dynamic_shape_static_validation_fails)
// (Note that we are testing for a different exception class here because the failure is in // (Note that we are testing for a different exception class here because the failure is in
// specialize_shape's pre-checks, which use NGRAPH_CHECK, rather than inside validation as we // specialize_shape's pre-checks, which use NGRAPH_CHECK, rather than inside validation as we
// reconstruct the graph.) // reconstruct the graph.)
TEST(specialize_shapes, et_static_shape_rank_static_dynamic_rank_mismatch) TEST(specialize_function, et_static_shape_rank_static_dynamic_rank_mismatch)
{ {
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3)); auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3)); auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));
...@@ -152,10 +211,14 @@ TEST(specialize_shapes, et_static_shape_rank_static_dynamic_rank_mismatch) ...@@ -152,10 +211,14 @@ TEST(specialize_shapes, et_static_shape_rank_static_dynamic_rank_mismatch)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
std::vector<void*> param_vals{nullptr, nullptr};
ASSERT_THROW( ASSERT_THROW(
{ {
specialize_shapes( specialize_function(f,
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}}); {element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}},
param_vals);
}, },
CheckFailure); CheckFailure);
} }
...@@ -166,7 +229,7 @@ TEST(specialize_shapes, et_static_shape_rank_static_dynamic_rank_mismatch) ...@@ -166,7 +229,7 @@ TEST(specialize_shapes, et_static_shape_rank_static_dynamic_rank_mismatch)
// (Note that we are testing for a different exception class here because the failure is in // (Note that we are testing for a different exception class here because the failure is in
// specialize_shape's pre-checks, which use NGRAPH_CHECK, rather than inside validation as we // specialize_shape's pre-checks, which use NGRAPH_CHECK, rather than inside validation as we
// reconstruct the graph.) // reconstruct the graph.)
TEST(specialize_shapes, et_static_shape_rank_static_dynamic_dim_mismatch) TEST(specialize_function, et_static_shape_rank_static_dynamic_dim_mismatch)
{ {
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3}); auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
auto p1 = auto p1 =
...@@ -177,16 +240,20 @@ TEST(specialize_shapes, et_static_shape_rank_static_dynamic_dim_mismatch) ...@@ -177,16 +240,20 @@ TEST(specialize_shapes, et_static_shape_rank_static_dynamic_dim_mismatch)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
std::vector<void*> param_vals{nullptr, nullptr};
ASSERT_THROW( ASSERT_THROW(
{ {
specialize_shapes( specialize_function(f,
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 9, 4}}); {element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 9, 4}},
param_vals);
}, },
CheckFailure); CheckFailure);
} }
// Test for failure when we supply the wrong number of replacement element types. // Test for failure when we supply the wrong number of replacement element types.
TEST(specialize_shapes, et_count_wrong) TEST(specialize_function, et_count_wrong)
{ {
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3}); auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3}); auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});
...@@ -196,17 +263,20 @@ TEST(specialize_shapes, et_count_wrong) ...@@ -196,17 +263,20 @@ TEST(specialize_shapes, et_count_wrong)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
std::vector<void*> param_vals{nullptr, nullptr};
ASSERT_THROW( ASSERT_THROW(
{ {
specialize_shapes(f, specialize_function(f,
{element::f32, element::i32, element::u32}, {element::f32, element::i32, element::u32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}}); {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
param_vals);
}, },
CheckFailure); CheckFailure);
} }
// Test for failure when we supply the wrong number of replacement shapes. // Test for failure when we supply the wrong number of replacement shapes.
TEST(specialize_shapes, shape_count_wrong) TEST(specialize_function, shape_count_wrong)
{ {
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3}); auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3}); auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});
...@@ -216,12 +286,38 @@ TEST(specialize_shapes, shape_count_wrong) ...@@ -216,12 +286,38 @@ TEST(specialize_shapes, shape_count_wrong)
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1}); auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
std::vector<void*> param_vals{nullptr, nullptr};
ASSERT_THROW( ASSERT_THROW(
{ {
specialize_shapes( specialize_function(
f, f,
{element::f32, element::i32}, {element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}, PartialShape{4, 5, 6}}); {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}, PartialShape{4, 5, 6}},
param_vals);
},
CheckFailure);
}
// Test for failure when we supply the wrong number of replacement parameter values.
TEST(specialize_function, value_count_wrong)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
std::vector<void*> param_vals{nullptr, nullptr, nullptr};
ASSERT_THROW(
{
specialize_function(f,
{element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
param_vals);
}, },
CheckFailure); CheckFailure);
} }
...@@ -12442,6 +12442,39 @@ TEST(type_prop, transpose_arg_static_input_order_static_ok) ...@@ -12442,6 +12442,39 @@ TEST(type_prop, transpose_arg_static_input_order_static_ok)
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4))); EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
} }
TEST(type_prop, transpose_arg_static_input_order_constant_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto input_order = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{2, 1, 0, 3});
auto r = make_shared<op::Transpose>(arg, input_order);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{6, 4, 2, 8}));
}
TEST(type_prop, transpose_arg_static_input_order_constant_invalid_perm)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto input_order = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{2, 9, 0, 3});
try
{
auto r = make_shared<op::Transpose>(arg, input_order);
FAIL() << "Did not detect invalid permutation";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Permutation AxisVector{2, 9, 0, 3} is not valid for input shape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, transpose_arg_rank_static_dynamic_input_order_static_ok) TEST(type_prop, transpose_arg_rank_static_dynamic_input_order_static_ok)
{ {
auto arg = make_shared<op::Parameter>( auto arg = make_shared<op::Parameter>(
......
...@@ -537,3 +537,74 @@ TEST(util, enum_mask_operators) ...@@ -537,3 +537,74 @@ TEST(util, enum_mask_operators)
EXPECT_EQ(false, n[Type::d]); EXPECT_EQ(false, n[Type::d]);
EXPECT_EQ(true, n[Type::b]); EXPECT_EQ(true, n[Type::b]);
} }
TEST(util, apply_permutation)
{
ASSERT_EQ(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{2, 1, 0, 3}), (Shape{2, 1, 0, 3}));
}
TEST(util, apply_permutation_too_short_fails)
{
ASSERT_THROW(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{0, 1, 2}), CheckFailure);
}
TEST(util, apply_permutation_too_long_fails)
{
ASSERT_THROW(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{0, 1, 2, 3, 3}), CheckFailure);
}
TEST(util, apply_permutation_oob_axis_fails)
{
ASSERT_THROW(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{0, 1, 2, 4}), CheckFailure);
}
TEST(util, apply_permutation_repeated_axis_fails)
{
ASSERT_THROW(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{0, 1, 2, 2}), CheckFailure);
}
TEST(util, apply_permutation_pshape)
{
ASSERT_TRUE(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{2, 1, 0, 3})
.same_scheme(PartialShape{2, Dimension::dynamic(), 0, 3}));
}
TEST(util, apply_permutation_pshape_rank_dynamic)
{
ASSERT_TRUE(apply_permutation(PartialShape::dynamic(), AxisVector{2, 1, 0, 3})
.same_scheme(PartialShape::dynamic()));
}
TEST(util, apply_permutation_pshape_too_short_fails)
{
ASSERT_THROW(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{0, 1, 2}),
CheckFailure);
}
TEST(util, apply_permutation_pshape_too_long_fails)
{
ASSERT_THROW(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{0, 1, 2, 3, 3}),
CheckFailure);
}
TEST(util, apply_permutation_pshape_oob_axis_fails)
{
ASSERT_THROW(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{0, 1, 2, 4}),
CheckFailure);
}
TEST(util, apply_permutation_pshape_repeated_axis_fails)
{
ASSERT_THROW(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{0, 1, 2, 2}),
CheckFailure);
}
TEST(util, apply_permutation_pshape_rank_dynamic_inviable_permutation_fails)
{
ASSERT_THROW(apply_permutation(PartialShape::dynamic(), AxisVector{0, 1, 2, 2}), CheckFailure);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment