Unverified Commit da1cacde authored by Adam Procter, committed by GitHub

Add more infrastructure for specialization of cloned graphs (#2949)

* Virtualize some things that crash when layout descriptor is missing

* More shape specialization

* (very bare) skeleton for dyn elimination

* Miscellaneous

* Lift i32->int64-only restriction on constant folding for Convert

* Add constant folding for ShapeOf, and some tests for new constant folders

* Tests for DynElimination

* Rename specialize_shapes to specialize_function, and add a unit test for value substitution

* Roll back overeager API change in dyn slice bprop (it has to handle right-indexed axes; bummer)

* Add a test for dynamic usage of transpose op

* Fix warning/error about variable shadowing

* Strengthen checks in apply_permutation

* Propagate Constant shapes through Transpose

* Add CHANGE_DYNAMIC_STATE where appropriate

* PR feedback, and fix unit test failure

* Fix PR reference in comment

* PR comments

* Comments for helper funcs

* Remove unique_ptr indirection for the AlignedBuffers

* Fix incorrect indexing of AlignedBuffer vector (whoops!)

* Remove unnecessary CHANGE_DYNAMIC_STATEs

* De-update pass property unit test for const folding

* Replace mystery runes with all_pass_property_off
parent ba546455
......@@ -358,6 +358,8 @@ set (SRC
pass/cse.hpp
pass/dump_sorted.cpp
pass/dump_sorted.hpp
pass/dyn_elimination.cpp
pass/dyn_elimination.hpp
pass/fused_op_decomposition.cpp
pass/fused_op_decomposition.hpp
pass/get_output_element_elimination.cpp
......@@ -435,8 +437,8 @@ set (SRC
shape.hpp
shape_util.cpp
shape_util.hpp
specialize_shapes.cpp
specialize_shapes.hpp
specialize_function.cpp
specialize_function.hpp
state/rng_state.cpp
strides.cpp
strides.hpp
......
......@@ -165,5 +165,5 @@
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/shape_util.hpp"
#include "ngraph/specialize_shapes.hpp"
#include "ngraph/specialize_function.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -171,6 +171,54 @@ Strides op::Constant::get_strides_val() const
return output_strides;
}
Coordinate op::Constant::get_coordinate_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_coordinate = get_vector<int64_t>();
Coordinate output_coordinate(shape_size(m_shape));
std::transform(out_coordinate.begin(),
out_coordinate.end(),
output_coordinate.begin(),
[&](const int64_t& v) { return (v > 0) ? v : 0; });
return output_coordinate;
}
CoordinateDiff op::Constant::get_coordinate_diff_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_coordinate_diff = get_vector<int64_t>();
CoordinateDiff output_coordinate_diff(shape_size(m_shape));
std::transform(out_coordinate_diff.begin(),
out_coordinate_diff.end(),
output_coordinate_diff.begin(),
[&](const int64_t& v) { return (v > 0) ? v : 0; });
return output_coordinate_diff;
}
AxisVector op::Constant::get_axis_vector_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_axis_vector = get_vector<int64_t>();
AxisVector output_axis_vector(shape_size(m_shape));
std::transform(out_axis_vector.begin(),
out_axis_vector.end(),
output_axis_vector.begin(),
[&](const int64_t& v) { return (v > 0) ? v : 0; });
return output_axis_vector;
}
AxisSet op::Constant::get_axis_set_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
AxisSet output_axis_set;
for (auto& axis : get_vector<int64_t>())
{
output_axis_set.insert(axis > 0 ? axis : 0);
}
return output_axis_set;
}
shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -19,6 +19,7 @@
#include <cstring>
#include <sstream>
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/node.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/type/bfloat16.hpp"
......@@ -156,10 +157,31 @@ namespace ngraph
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Shape get_shape_val() const;
/// \brief Returns the value of the constant node as a Strides object
/// \brief Returns the value of the constant node as a Strides
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Strides get_strides_val() const;
/// \brief Returns the value of the constant node as a Coordinate
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Coordinate get_coordinate_val() const;
/// \brief Returns the value of the constant node as a
/// CoordinateDiff object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
CoordinateDiff get_coordinate_diff_val() const;
/// \brief Returns the value of the constant node as an AxisVector
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
AxisVector get_axis_vector_val() const;
/// \brief Returns the value of the constant node as an AxisSet
/// object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
/// Repeated values are allowed.
AxisSet get_axis_set_val() const;
/// \brief Wrapper around constructing a shared_ptr of a Constant
///
......
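A small usage sketch (not part of the diff) of the accessors documented above; it uses only op::Constant APIs that appear elsewhere in this change, and the clamping of negative values follows the doc comments:
// Illustrative sketch only -- not part of this commit.
#include <vector>
#include "ngraph/op/constant.hpp"

using namespace ngraph;

void constant_accessor_sketch()
{
    // A rank-1 i64 constant holding {2, 0, -1, 3}.
    auto c = op::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{2, 0, -1, 3});

    Shape as_shape = c->get_shape_val();           // {2, 0, 0, 3}: negative values clamp to zero
    AxisVector as_axes = c->get_axis_vector_val(); // {2, 0, 0, 3}
    AxisSet as_axis_set = c->get_axis_set_val();   // {0, 2, 3}: duplicates collapse into the set
}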
......@@ -49,14 +49,7 @@ void op::DynReshape::validate_and_infer_types()
set_input_is_relevant_to_shape(1);
if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
// TODO: replace with const_shape->get_shapes_val()
auto out_shape = const_shape->get_vector<int64_t>();
Shape output_shape(shape_size(const_shape->get_shape()));
std::transform(out_shape.begin(),
out_shape.end(),
output_shape.begin(),
[&](const int64_t& v) { return max(v, int64_t(0)); });
set_output_type(0, get_input_element_type(0), output_shape);
set_output_type(0, get_input_element_type(0), const_shape->get_shape_val());
}
else
{
......
......@@ -16,6 +16,7 @@
#include <iostream>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/transpose.hpp"
using namespace std;
......@@ -43,7 +44,23 @@ void op::Transpose::validate_and_infer_types()
"Input order must have shape [n], where n is the rank of arg.");
set_input_is_relevant_to_shape(1);
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
if (auto input_const = std::dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
auto permutation = input_const->get_axis_vector_val();
NODE_VALIDATION_CHECK(this,
is_valid_permutation(permutation, arg_shape.rank()),
"Permutation ",
permutation,
" is not valid for input shape ",
arg_shape);
set_output_type(
0, get_input_element_type(0), ngraph::apply_permutation(arg_shape, permutation));
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
}
}
shared_ptr<Node> op::Transpose::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -22,8 +22,10 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
......@@ -39,6 +41,7 @@
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
......@@ -763,3 +766,188 @@ void pass::ConstantFolding::construct_constant_quantize()
this->add_matcher(
quantize_matcher, constant_quantize_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
// Helper for mapping element::Types to runtime::reference::convert, which is templated on C++
// data types. fold_constant_convert determines the appropriate C++ type for "TI" (input type),
// and fold_constant_convert_helper0 determines the type for "TO" (output type) before calling this.
template <typename TI, typename TO>
shared_ptr<op::Constant> fold_constant_convert_helper1(shared_ptr<op::Constant> constant,
const element::Type& output_element_type)
{
auto out_shape = constant->get_shape();
vector<TO> out_vec(shape_size(out_shape));
runtime::reference::convert<TI, TO>(
constant->get_vector<TI>().data(), out_vec.data(), shape_size(out_shape));
return make_shared<op::Constant>(output_element_type, out_shape, out_vec);
}
// Helper for mapping element::Types to runtime::reference::convert, which is templated on C++
// data types. Used by fold_constant_convert, which determines the appropriate C++ type for "TI"
// (input type).
template <typename TI>
shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant> constant,
const element::Type& output_element_type)
{
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (output_element_type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_convert");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::boolean:
return fold_constant_convert_helper1<TI, char>(constant, output_element_type);
case element::Type_t::bf16:
return fold_constant_convert_helper1<TI, bfloat16>(constant, output_element_type);
case element::Type_t::f16:
return fold_constant_convert_helper1<TI, float16>(constant, output_element_type);
case element::Type_t::f32:
return fold_constant_convert_helper1<TI, float>(constant, output_element_type);
case element::Type_t::f64:
return fold_constant_convert_helper1<TI, double>(constant, output_element_type);
case element::Type_t::i8:
return fold_constant_convert_helper1<TI, int8_t>(constant, output_element_type);
case element::Type_t::i16:
return fold_constant_convert_helper1<TI, int16_t>(constant, output_element_type);
case element::Type_t::i32:
return fold_constant_convert_helper1<TI, int32_t>(constant, output_element_type);
case element::Type_t::i64:
return fold_constant_convert_helper1<TI, int64_t>(constant, output_element_type);
case element::Type_t::u8:
return fold_constant_convert_helper1<TI, uint8_t>(constant, output_element_type);
case element::Type_t::u16:
return fold_constant_convert_helper1<TI, uint16_t>(constant, output_element_type);
case element::Type_t::u32:
return fold_constant_convert_helper1<TI, uint32_t>(constant, output_element_type);
case element::Type_t::u64:
return fold_constant_convert_helper1<TI, uint64_t>(constant, output_element_type);
}
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
}
static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> constant,
const element::Type& output_element_type)
{
auto& input_element_type = constant->get_output_element_type(0);
if (input_element_type == output_element_type)
{
return constant;
}
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (input_element_type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_convert");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::boolean:
return fold_constant_convert_helper0<char>(constant, output_element_type);
case element::Type_t::bf16:
return fold_constant_convert_helper0<bfloat16>(constant, output_element_type);
case element::Type_t::f16:
return fold_constant_convert_helper0<float16>(constant, output_element_type);
case element::Type_t::f32:
return fold_constant_convert_helper0<float>(constant, output_element_type);
case element::Type_t::f64:
return fold_constant_convert_helper0<double>(constant, output_element_type);
case element::Type_t::i8:
return fold_constant_convert_helper0<int8_t>(constant, output_element_type);
case element::Type_t::i16:
return fold_constant_convert_helper0<int16_t>(constant, output_element_type);
case element::Type_t::i32:
return fold_constant_convert_helper0<int32_t>(constant, output_element_type);
case element::Type_t::i64:
return fold_constant_convert_helper0<int64_t>(constant, output_element_type);
case element::Type_t::u8:
return fold_constant_convert_helper0<uint8_t>(constant, output_element_type);
case element::Type_t::u16:
return fold_constant_convert_helper0<uint16_t>(constant, output_element_type);
case element::Type_t::u32:
return fold_constant_convert_helper0<uint32_t>(constant, output_element_type);
case element::Type_t::u64:
return fold_constant_convert_helper0<uint64_t>(constant, output_element_type);
}
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
}
void pass::ConstantFolding::construct_constant_convert()
{
auto constant_label = make_shared<pattern::op::Label>(
element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto convert_op = make_shared<op::Convert>(constant_label, element::i64);
auto constant_convert_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_convert_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto convert_match = static_pointer_cast<op::Convert>(m.get_match_root());
replace_node(
m.get_match_root(),
fold_constant_convert(constant_match, convert_match->get_output_element_type(0)));
return true;
};
auto convert_matcher =
make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantConvert");
this->add_matcher(convert_matcher, constant_convert_callback, all_pass_property_off);
}
// ShapeOf is a bit of an odd duck: it doesn't matter if the input's value is
// constant, as long as it has static shape.
void pass::ConstantFolding::construct_constant_shape_of()
{
auto arg_label = make_shared<pattern::op::Label>(element::i32, Shape{2, 3, 4});
auto shape_of_op = make_shared<op::ShapeOf>(arg_label);
auto constant_shape_of_callback = [arg_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_shape_of_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto arg_match = pattern_map[arg_label];
if (arg_match->get_output_partial_shape(0).is_static())
{
auto arg_shape = arg_match->get_output_shape(0);
auto replacement =
make_shared<op::Constant>(element::i64, Shape{arg_shape.size()}, arg_shape.data());
replace_node(m.get_match_root(), replacement);
return true;
}
else
{
return false;
}
};
auto shape_of_matcher =
make_shared<pattern::Matcher>(shape_of_op, "ConstantFolding.ConstantShapeOf");
this->add_matcher(shape_of_matcher, constant_shape_of_callback, all_pass_property_off);
}
......@@ -38,7 +38,9 @@ public:
DEQUANTIZE,
UNARY,
BINARY,
QUANTIZE
QUANTIZE,
CONVERT,
SHAPE_OF
};
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
......@@ -52,6 +54,8 @@ public:
construct_constant_binary();
construct_constant_quantize();
construct_constant_dequantize();
construct_constant_convert();
construct_constant_shape_of();
}
//this allows to specify the order in which matchers will be run
......@@ -72,6 +76,8 @@ public:
case CFTransformations::BINARY: construct_constant_binary(); break;
case CFTransformations::DEQUANTIZE: construct_constant_dequantize(); break;
case CFTransformations::QUANTIZE: construct_constant_quantize(); break;
case CFTransformations::CONVERT: construct_constant_convert(); break;
case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
}
}
}
......@@ -84,6 +90,8 @@ private:
void construct_constant_binary();
void construct_constant_quantize();
void construct_constant_dequantize();
void construct_constant_convert();
void construct_constant_shape_of();
ngraph::BuildNodeExecutorMap m_cfmap;
};
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "dyn_elimination.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
using namespace std;
using namespace ngraph;
pass::DynElimination::DynElimination()
: GraphRewrite()
{
construct_transpose();
}
void pass::DynElimination::construct_transpose()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto perm_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto transpose = make_shared<op::Transpose>(data_arg_label, perm_arg_label);
auto transpose_callback = [data_arg_label, perm_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto data_arg = pattern_map[data_arg_label];
auto perm_arg = static_pointer_cast<op::Constant>(pattern_map[perm_arg_label]);
// TODO(amprocte): Can't handle the case where data shape is dynamic, because static
// Reshape requires the exact output shape to be declared. See if we can come up with a
// workaround.
if (data_arg->get_output_partial_shape(0).is_dynamic())
{
return false;
}
auto& data_shape = data_arg->get_output_shape(0);
// TODO(amprocte): These should be redundant if the graph is validated. Necessary?
if (perm_arg->get_element_type() != element::i64 ||
perm_arg->get_output_partial_shape(0).is_dynamic() ||
perm_arg->get_output_shape(0).size() != 1)
{
return false;
}
auto perm = perm_arg->get_axis_vector_val();
auto output_shape = ngraph::apply_permutation(data_shape, perm);
auto replacement = std::make_shared<op::Reshape>(data_arg, perm, output_shape);
replace_node(m.get_match_root(), replacement);
return true;
};
auto transpose_matcher = make_shared<pattern::Matcher>(transpose, "DynElimination.Transpose");
add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace pass
{
class DynElimination : public GraphRewrite
{
public:
DynElimination();
private:
void construct_transpose();
};
}
}
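Usage note (not part of the diff): the new pass is intended to run after constant folding, as the dynamic backend change later in this commit does, so that shape/permutation inputs have already been folded into Constants. A minimal sketch:
// Illustrative sketch only -- mirrors the pass ordering used by the dynamic backend below.
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp"

using namespace ngraph;

void eliminate_dyn_ops(std::shared_ptr<Function> f)
{
    pass::Manager passes;
    passes.register_pass<pass::ConstantFolding>(); // fold constant shape/permutation inputs first
    passes.register_pass<pass::DynElimination>();  // then rewrite e.g. Transpose-with-Constant-perm to Reshape
    passes.run_passes(f);
}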
......@@ -16,9 +16,11 @@
#include "ngraph/runtime/dynamic/dynamic_backend.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/shape_relevance.hpp"
#include "ngraph/specialize_shapes.hpp"
#include "ngraph/specialize_function.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -79,30 +81,73 @@ bool runtime::dynamic::DynamicExecutable::call(
// (1) all shapes;
// (2) all values of shape-relevant input tensors.
NGRAPH_CHECK(m_wrapped_function->get_parameters().size() == inputs.size());
std::vector<std::shared_ptr<runtime::Tensor>> wrapped_inputs;
std::vector<element::Type> arg_element_types;
std::vector<PartialShape> arg_shapes;
for (auto& input : inputs)
std::shared_ptr<Function> clone;
{
if (auto dynamic_tensor = std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input))
{
NGRAPH_CHECK(dynamic_tensor->has_storage());
arg_element_types.push_back(dynamic_tensor->get_wrapped_tensor()->get_element_type());
arg_shapes.push_back(dynamic_tensor->get_wrapped_tensor()->get_shape());
wrapped_inputs.push_back(dynamic_tensor->get_wrapped_tensor());
}
else
// We'll use AlignedBuffers to back the base pointers, storing them in this vector for RAII
// purposes.
std::vector<AlignedBuffer> arg_buffers;
arg_buffers.reserve(inputs.size());
std::vector<void*> arg_value_base_pointers(inputs.size());
size_t i = 0;
for (auto& input : inputs)
{
arg_element_types.push_back(input->get_element_type());
arg_shapes.push_back(input->get_shape());
wrapped_inputs.push_back(input);
if (m_wrapped_function->get_parameters()[i]->is_relevant_to_shapes())
{
// TODO(amprocte): Move has_storage() to runtime::Tensor?
if (auto dynamic_tensor =
std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input))
{
NGRAPH_CHECK(dynamic_tensor->has_storage());
}
arg_buffers.emplace_back(input->get_size_in_bytes(), /*alignment=*/64);
arg_value_base_pointers[i] = arg_buffers.back().get_ptr();
// TODO(amprocte): For host-resident tensors we should be able to skip the read,
// but no API for that yet.
input->read(arg_value_base_pointers[i], 0, input->get_size_in_bytes());
}
else
{
arg_value_base_pointers[i] = nullptr;
}
if (auto dynamic_tensor =
std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input))
{
NGRAPH_CHECK(dynamic_tensor->has_storage());
arg_element_types.push_back(
dynamic_tensor->get_wrapped_tensor()->get_element_type());
arg_shapes.push_back(dynamic_tensor->get_wrapped_tensor()->get_shape());
wrapped_inputs.push_back(dynamic_tensor->get_wrapped_tensor());
}
else
{
arg_element_types.push_back(input->get_element_type());
arg_shapes.push_back(input->get_shape());
wrapped_inputs.push_back(input);
}
i++;
}
clone = specialize_function(
m_wrapped_function, arg_element_types, arg_shapes, arg_value_base_pointers);
}
// TODO: specialize_shapes needs to fill in values of shape-relevant params.
auto clone = specialize_shapes(m_wrapped_function, arg_element_types, arg_shapes);
// TODO: run constant folding and de-dynification on clone.
pass::Manager passes;
passes.register_pass<pass::ConstantFolding>();
passes.register_pass<pass::DynElimination>();
passes.run_passes(clone);
const ResultVector& results = clone->get_results();
NGRAPH_CHECK(results.size() == outputs.size());
......@@ -140,6 +185,27 @@ runtime::dynamic::DynamicTensor::DynamicTensor(
{
}
Strides runtime::dynamic::DynamicTensor::get_strides() const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"asked for strides of a dynamic tensor with no allocated storage");
return ngraph::row_major_strides(m_wrapped_tensor->get_shape());
}
size_t runtime::dynamic::DynamicTensor::get_size_in_bytes() const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"asked for size in bytes of a dynamic tensor with no allocated storage");
return get_element_count() * get_element_type().size();
}
size_t runtime::dynamic::DynamicTensor::get_element_count() const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"asked for element count of a dynamic tensor with no allocated storage");
return shape_size(m_wrapped_tensor->get_shape());
}
const element::Type& runtime::dynamic::DynamicTensor::get_element_type() const
{
if (m_wrapped_tensor == nullptr)
......
......@@ -127,6 +127,9 @@ public:
DynamicTensor(const element::Type& element_type,
const PartialShape& shape,
const std::shared_ptr<runtime::Backend>& wrapped_backend);
virtual ngraph::Strides get_strides() const override;
virtual size_t get_size_in_bytes() const override;
virtual size_t get_element_count() const override;
virtual const element::Type& get_element_type() const override;
virtual const ngraph::Shape& get_shape() const override;
virtual void write(const void* p, size_t offset, size_t n) override;
......
......@@ -58,7 +58,7 @@ namespace ngraph
/// \brief Get tensor strides
/// \return Strides
ngraph::Strides get_strides() const;
virtual ngraph::Strides get_strides() const;
/// \brief Get tensor element type
/// \return element::Type
......@@ -66,11 +66,11 @@ namespace ngraph
/// \brief Get number of elements in the tensor
/// \return number of elements in the tensor
size_t get_element_count() const;
virtual size_t get_element_count() const;
/// \brief Get the size in bytes of the tensor
/// \return number of bytes in tensor's allocation
size_t get_size_in_bytes() const;
virtual size_t get_size_in_bytes() const;
/// \brief Get tensor's unique name
/// \return tensor's name
......
......@@ -14,17 +14,20 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/specialize_shapes.hpp"
#include "ngraph/specialize_function.hpp"
#include "ngraph/op/constant.hpp"
using namespace ngraph;
std::shared_ptr<Function>
ngraph::specialize_shapes(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes)
ngraph::specialize_function(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values)
{
NGRAPH_CHECK(f->get_parameters().size() == parameter_shapes.size());
NGRAPH_CHECK(f->get_parameters().size() == parameter_element_types.size());
NGRAPH_CHECK(f->get_parameters().size() == parameter_values.size());
NodeMap m;
......@@ -35,8 +38,17 @@ std::shared_ptr<Function>
NGRAPH_CHECK(f->get_parameters()[i]->get_element_type().is_dynamic() ||
parameter_element_types[i] == f->get_parameters()[i]->get_element_type());
m[f->get_parameters()[i].get()] =
std::make_shared<op::Parameter>(parameter_element_types[i], parameter_shapes[i]);
if (parameter_values[i] != nullptr && parameter_shapes[i].is_static() &&
parameter_element_types[i].is_static())
{
m[f->get_parameters()[i].get()] = std::make_shared<op::Constant>(
parameter_element_types[i], parameter_shapes[i].to_shape(), parameter_values[i]);
}
else
{
m[f->get_parameters()[i].get()] =
std::make_shared<op::Parameter>(parameter_element_types[i], parameter_shapes[i]);
}
}
for (auto old_node : f->get_ordered_ops())
......@@ -59,7 +71,16 @@ std::shared_ptr<Function>
ParameterVector new_parameters = f->get_parameters();
for (size_t i = 0; i < new_parameters.size(); i++)
{
new_parameters[i] = std::static_pointer_cast<op::Parameter>(m[new_parameters[i].get()]);
new_parameters[i] = std::dynamic_pointer_cast<op::Parameter>(m[new_parameters[i].get()]);
// If the replacement for a Parameter is not itself a Parameter, we must have replaced it
// with a constant. We will insert a dead Parameter into the clone's parameters, in order
// to maintain the arity of the original function.
if (new_parameters[i] == nullptr)
{
new_parameters[i] =
std::make_shared<op::Parameter>(parameter_element_types[i], parameter_shapes[i]);
}
}
ResultVector new_results = f->get_results();
......
......@@ -20,17 +20,23 @@
namespace ngraph
{
/// \brief Creates a clone of a function, with the shapes of that function's parameters
/// specialized to some more specific element types and shapes.
/// \brief Creates a "specialized" clone of a function. The partial shapes and element types of
/// the function's parameters may be narrowed to more specific shapes and element types,
/// and constant values may optionally be substituted for any or all of the parameters.
/// \param f The function to be cloned.
/// \param parameter_element_types The new parameter element types to substitute.
/// \param parameter_shapes The new parameter shapes to substitute.
/// \return A clone of f, with the parameter element types and shapes specialized.
/// \throws CheckFailure if parameter_element_types or parameter_shapes is not valid
/// \param parameter_element_types The new parameter element types to substitute. Length must
/// be equal to the number of parameters of f.
/// \param parameter_shapes The new parameter shapes to substitute. Length must be equal to the
/// number of parameters of f.
/// \param parameter_values Parameter values to substitute. Length must be equal to the number
/// of parameters of f, with nullptr indicating that no substitution is to be made for
/// the corresponding parameter.
/// \return A clone of f, with the parameter element types, shapes, and values specialized.
/// \throws CheckFailure if parameter_element_types, parameter_shapes, or parameter_values is not valid
/// (see details).
/// \throws NodeValidationError if node validation fails as the clone is being constructed.
///
/// Creates a "shape-specialized" clone of an nGraph Function function.
/// Creates a "specialized" clone of an nGraph Function.
///
/// For example, suppose that a function f has three parameters with partial shapes:
///
......@@ -75,17 +81,31 @@ namespace ngraph
/// specialized to itself (e.g., specialization does not allow you to change `element::i32`
/// to `element::i64`).
///
/// Finally, it is possible to specialize parameter values. If the ith element of
/// `parameter_values` is not `nullptr`, and a fully static element type and shape have been
/// specified for the ith parameter, a `Constant` node will be created and substituted for the
/// ith parameter, with its data drawn from `parameter_values[i]`. Note that the Parameter node
/// remains (in order to maintain the arity of the function), but will no longer have any
/// users.
///
/// It is required that:
/// 1. The length of parameter_element_types and parameter_shapes is the same as the number
/// of f's parameters.
/// 1. The length of parameter_element_types, parameter_shapes, and parameter_values is the
/// same as the number of f's parameters.
/// 2. Each shape in parameter_shapes is a refinement of the shape of the corresponding
/// parameter of f. Roughly speaking, a shape s1 is said to "refine" s2 if s1 can be
/// obtained from s2 by filling in s2's question marks. See PartialShape::refines for
/// more details.
/// 3. For all i, either the element type of fp_i is dynamic, or fp_i is the same as
/// parameter_element_types[i]. (Here fp_i is the ith parameter of f.)
/// 4. For all i where parameter_values[i] != nullptr and parameter_element_types[i] is
/// static and parameter_shapes[i] is static, parameter_values[i] points to a buffer from
/// which a Constant node with element type parameter_element_types[i] and shape
/// parameter_shapes[i] can be created.
///
/// TODO(amprocte): convert this to a pass.
std::shared_ptr<Function>
specialize_shapes(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes);
specialize_function(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values);
}
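For illustration only (not part of the diff): a minimal use of the new parameter_values argument, built from ops and calls that appear elsewhere in this change; the Transpose graph and the concrete shapes are arbitrary examples:
// Illustrative sketch only -- not part of this commit.
#include <vector>
#include "ngraph/ngraph.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/specialize_function.hpp"

using namespace ngraph;

std::shared_ptr<Function> specialize_transpose_example()
{
    auto data = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto perm = std::make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
    auto t = std::make_shared<op::Transpose>(data, perm);
    auto f = std::make_shared<Function>(NodeVector{t}, ParameterVector{data, perm});

    // Narrow the data parameter to f32/{3,4}, and substitute the constant value {1,0}
    // for the permutation parameter. Its Parameter node survives, but becomes unused.
    std::vector<int64_t> perm_value{1, 0};
    return specialize_function(f,
                               {element::f32, element::i64},
                               {PartialShape{3, 4}, PartialShape{2}},
                               {nullptr, perm_value.data()});
}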
......@@ -28,6 +28,7 @@
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
......@@ -457,14 +458,35 @@ void ngraph::check_fp_values_isnan(const char* name, const double* array, size_t
}
}
template <typename T>
T ngraph::apply_permutation(T input, AxisVector order)
bool ngraph::is_valid_permutation(ngraph::AxisVector permutation, ngraph::Rank rank)
{
if (input.size() != order.size())
std::vector<bool> axis_occurs(permutation.size(), false);
for (auto& axis : permutation)
{
throw "input and order sizes don't match!";
axis_occurs[axis] = true;
}
for (size_t axis = 0; axis < permutation.size(); axis++)
{
if (!axis_occurs[axis])
{
return false;
}
}
return (rank.is_dynamic() || permutation.size() == static_cast<size_t>(rank));
}
template <typename T>
T ngraph::apply_permutation(T input, AxisVector order)
{
NGRAPH_CHECK(is_valid_permutation(order, input.size()),
"Permutation ",
order,
" is not valid for ",
input);
T output(input.size());
for (size_t i = 0; i < order.size(); i++)
......@@ -485,6 +507,35 @@ template ngraph::CoordinateDiff
template ngraph::Strides ngraph::apply_permutation<ngraph::Strides>(ngraph::Strides input,
ngraph::AxisVector order);
namespace ngraph
{
template <>
PartialShape apply_permutation(PartialShape input, AxisVector order)
{
NGRAPH_CHECK(is_valid_permutation(order, input.rank()),
"Permutation ",
order,
" is not valid for ",
input);
// Here's the special part: if the AxisVector is a viable permutation of _some_ rank and the
// input has dynamic rank, we just stick with dynamic rank.
if (input.rank().is_dynamic())
{
return input;
}
PartialShape output{PartialShape::dynamic(order.size())};
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input[order.at(i)];
}
return output;
}
}
AxisVector ngraph::get_default_order(const Shape& shape)
{
return get_default_order(shape.size());
......
......@@ -196,6 +196,7 @@ namespace ngraph
void ngraph_free(void*);
size_t round_up(size_t size, size_t alignment);
bool is_valid_permutation(ngraph::AxisVector permutation, ngraph::Rank rank = Rank::dynamic());
template <typename T>
T apply_permutation(T input, ngraph::AxisVector order);
......
......@@ -42,6 +42,7 @@ set(SRC
copy.cpp
cpio.cpp
cse.cpp
dyn_elimination.cpp
element_type.cpp
file_util.cpp
float16.cpp
......@@ -63,7 +64,7 @@ set(SRC
reshape_elimination.cpp
reshape_sinking.cpp
shape.cpp
specialize_shapes.cpp
specialize_function.cpp
tensor.cpp
type_prop.cpp
type_prop_layers.cpp
......
......@@ -260,6 +260,104 @@ TEST(constant_folding, const_quantize)
vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
ASSERT_EQ(values_quantize, values_out);
}
TEST(constant_folding, const_convert)
{
Shape input_shape{3, 4};
vector<int32_t> values_in{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
auto constant = op::Constant::create(element::f32, input_shape, values_in);
auto convert = make_shared<op::Convert>(constant, element::u64);
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Convert>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
ASSERT_EQ(new_const->get_output_element_type(0), element::u64);
auto values_out = new_const->get_vector<uint64_t>();
vector<uint64_t> values_expected{1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, shape_of)
{
Shape input_shape{3, 4, 0, 22, 608, 909, 3};
auto param = make_shared<op::Parameter>(element::boolean, input_shape);
auto shape_of = make_shared<op::ShapeOf>(param);
auto f = make_shared<Function>(shape_of, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::ShapeOf>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
ASSERT_EQ(new_const->get_output_element_type(0), element::i64);
auto values_out = new_const->get_vector<int64_t>();
ASSERT_EQ((vector<int64_t>{3, 4, 0, 22, 608, 909, 3}), values_out);
}
// A bit of an unusual case here: constant folding will not succeed on ShapeOf
// if the argument doesn't have a static shape. We want to make sure it fails
// gracefully, leaving the ShapeOf op in place.
TEST(constant_folding, shape_of_dynamic)
{
PartialShape input_shape{3, 4, Dimension::dynamic(), 22, 608, 909, 3};
auto param = make_shared<op::Parameter>(element::boolean, input_shape);
auto shape_of = make_shared<op::ShapeOf>(param);
auto f = make_shared<Function>(shape_of, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::ShapeOf>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
auto result_as_shape_of =
std::dynamic_pointer_cast<op::ShapeOf>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(result_as_shape_of);
ASSERT_EQ(result_as_shape_of->get_output_shape(0), Shape{7});
}
// Similar to shape_of_dynamic above but here even the rank is dynamic.
TEST(constant_folding, shape_of_rank_dynamic)
{
PartialShape input_shape{PartialShape::dynamic()};
auto param = make_shared<op::Parameter>(element::boolean, input_shape);
auto shape_of = make_shared<op::ShapeOf>(param);
auto f = make_shared<Function>(shape_of, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::ShapeOf>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
auto result_as_shape_of =
std::dynamic_pointer_cast<op::ShapeOf>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(result_as_shape_of);
ASSERT_TRUE(result_as_shape_of->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic()}));
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/dyn_elimination.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(dyn_elimination, transpose)
{
Shape shape_in{2, 4, 6, 8};
auto param = make_shared<op::Parameter>(element::boolean, shape_in);
auto constant_perm =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{2, 3, 1, 0});
auto transpose = make_shared<op::Transpose>(param, constant_perm);
auto f = make_shared<Function>(transpose, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
auto new_reshape =
std::dynamic_pointer_cast<op::Reshape>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_reshape);
ASSERT_EQ(new_reshape->get_input_order(), (AxisVector{2, 3, 1, 0}));
ASSERT_EQ(new_reshape->output(0).get_shape(), (Shape{6, 8, 4, 2}));
ASSERT_EQ(new_reshape->get_output_element_type(0), element::boolean);
}
// For now, we can't handle the case where the input has dynamic shapes,
// because the classic Reshape op demands a Shape. Probably won't be able to
// deal with this until/unless we make a "StaticTranspose". Just make sure
// we don't crash or mangle the graph.
TEST(dyn_elimination, transpose_dyn_shape)
{
PartialShape shape_in{2, 4, Dimension::dynamic(), 8};
auto param = make_shared<op::Parameter>(element::boolean, shape_in);
auto constant_perm =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{2, 3, 1, 0});
auto transpose = make_shared<op::Transpose>(param, constant_perm);
auto f = make_shared<Function>(transpose, ParameterVector{param});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_transpose =
std::dynamic_pointer_cast<op::Transpose>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_transpose);
ASSERT_EQ(new_transpose->get_output_element_type(0), element::boolean);
ASSERT_TRUE(new_transpose->get_output_partial_shape(0).relaxes(
PartialShape{Dimension::dynamic(), 8, 4, 2}));
}
......@@ -112,3 +112,50 @@ NGRAPH_TEST(dynamic_${BACKEND_NAME}, abc)
EXPECT_TRUE(test::all_close_f(results, expected_values));
}
}
NGRAPH_TEST(dynamic_${BACKEND_NAME}, transpose)
{
//
// Create a graph for f(x,perm) = Transpose(x,Convert<i64>(perm)). We'll do the permutation in
// i32 and cast it to i64, just for fun (and to mirror the TensorFlow test I am porting here).
//
auto x = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto perm = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto perm_i64 = make_shared<op::Convert>(perm, element::i64);
auto x_transpose = make_shared<op::Transpose>(x, perm_i64);
auto f = make_shared<Function>(NodeVector{x_transpose}, ParameterVector{x, perm});
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto ex = backend->compile(f);
auto t_r = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
std::vector<Shape> x_shapes{Shape{2, 3}, Shape{2, 3}, Shape{2, 2, 3}};
std::vector<std::vector<int32_t>> perms{{0, 1}, {1, 0}, {2, 1, 0}};
std::vector<std::vector<float>> inputs{
{1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
std::vector<Shape> expected_result_shapes{Shape{2, 3}, Shape{3, 2}, {3, 2, 2}};
// Generated with numpy, so don't worry. :)
std::vector<std::vector<float>> expected_results{
{1, 2, 3, 4, 5, 6}, {1, 4, 2, 5, 3, 6}, {1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}};
for (size_t i = 0; i < x_shapes.size(); i++)
{
auto t_x = backend->create_tensor(element::f32, x_shapes[i]);
auto t_perm = backend->create_tensor(element::i32, Shape{perms[i].size()});
copy_data(t_x, inputs[i]);
copy_data(t_perm, perms[i]);
ex->call_with_validate({t_r}, {t_x, t_perm});
ASSERT_EQ(t_r->get_shape(), expected_result_shapes[i]);
auto results = read_vector<float>(t_r);
ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS));
}
}
......@@ -12442,6 +12442,39 @@ TEST(type_prop, transpose_arg_static_input_order_static_ok)
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
TEST(type_prop, transpose_arg_static_input_order_constant_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto input_order = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{2, 1, 0, 3});
auto r = make_shared<op::Transpose>(arg, input_order);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{6, 4, 2, 8}));
}
TEST(type_prop, transpose_arg_static_input_order_constant_invalid_perm)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto input_order = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{2, 9, 0, 3});
try
{
auto r = make_shared<op::Transpose>(arg, input_order);
FAIL() << "Did not detect invalid permutation";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Permutation AxisVector{2, 9, 0, 3} is not valid for input shape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, transpose_arg_rank_static_dynamic_input_order_static_ok)
{
auto arg = make_shared<op::Parameter>(
......
......@@ -537,3 +537,74 @@ TEST(util, enum_mask_operators)
EXPECT_EQ(false, n[Type::d]);
EXPECT_EQ(true, n[Type::b]);
}
TEST(util, apply_permutation)
{
ASSERT_EQ(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{2, 1, 0, 3}), (Shape{2, 1, 0, 3}));
}
TEST(util, apply_permutation_too_short_fails)
{
ASSERT_THROW(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{0, 1, 2}), CheckFailure);
}
TEST(util, apply_permutation_too_long_fails)
{
ASSERT_THROW(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{0, 1, 2, 3, 3}), CheckFailure);
}
TEST(util, apply_permutation_oob_axis_fails)
{
ASSERT_THROW(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{0, 1, 2, 4}), CheckFailure);
}
TEST(util, apply_permutation_repeated_axis_fails)
{
ASSERT_THROW(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{0, 1, 2, 2}), CheckFailure);
}
TEST(util, apply_permutation_pshape)
{
ASSERT_TRUE(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{2, 1, 0, 3})
.same_scheme(PartialShape{2, Dimension::dynamic(), 0, 3}));
}
TEST(util, apply_permutation_pshape_rank_dynamic)
{
ASSERT_TRUE(apply_permutation(PartialShape::dynamic(), AxisVector{2, 1, 0, 3})
.same_scheme(PartialShape::dynamic()));
}
TEST(util, apply_permutation_pshape_too_short_fails)
{
ASSERT_THROW(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{0, 1, 2}),
CheckFailure);
}
TEST(util, apply_permutation_pshape_too_long_fails)
{
ASSERT_THROW(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{0, 1, 2, 3, 3}),
CheckFailure);
}
TEST(util, apply_permutation_pshape_oob_axis_fails)
{
ASSERT_THROW(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{0, 1, 2, 4}),
CheckFailure);
}
TEST(util, apply_permutation_pshape_repeated_axis_fails)
{
ASSERT_THROW(
apply_permutation(PartialShape{0, Dimension::dynamic(), 2, 3}, AxisVector{0, 1, 2, 2}),
CheckFailure);
}
TEST(util, apply_permutation_pshape_rank_dynamic_inviable_permutation_fails)
{
ASSERT_THROW(apply_permutation(PartialShape::dynamic(), AxisVector{0, 1, 2, 2}), CheckFailure);
}