Commit 6e877f6c authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Add support for shape-specialization of functions (#2322)

* Add support for shape-specializing functions

* Move PR comment into docstring for specialize_shapes

* Fix some broken unit tests

* Add stub cases for op_tbl-dependent stuff

* Revert "Add stub cases for op_tbl-dependent stuff"

This reverts commit 2153967dc2fe544ca78a99548c8bb3cdfefc8470.
parent 2ece2d9c
...@@ -332,6 +332,8 @@ set (SRC ...@@ -332,6 +332,8 @@ set (SRC
shape.hpp shape.hpp
shape_util.cpp shape_util.cpp
shape_util.hpp shape_util.hpp
specialize_shapes.cpp
specialize_shapes.hpp
state/rng_state.cpp state/rng_state.cpp
strides.cpp strides.cpp
strides.hpp strides.hpp
......
...@@ -133,4 +133,5 @@ ...@@ -133,4 +133,5 @@
#include "ngraph/runtime/tensor.hpp" #include "ngraph/runtime/tensor.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/shape_util.hpp" #include "ngraph/shape_util.hpp"
#include "ngraph/specialize_shapes.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/specialize_shapes.hpp"
using namespace ngraph;
using ReplacementMap = std::map<Node*, std::shared_ptr<Node>>;
std::shared_ptr<Function>
ngraph::specialize_shapes(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes)
{
NGRAPH_ASSERT(f->get_parameters().size() == parameter_shapes.size());
NGRAPH_ASSERT(f->get_parameters().size() == parameter_element_types.size());
ReplacementMap m;
for (size_t i = 0; i < parameter_shapes.size(); i++)
{
NGRAPH_ASSERT(
parameter_shapes[i].refines(f->get_parameters()[i]->get_output_partial_shape(0)));
NGRAPH_ASSERT(f->get_parameters()[i]->get_element_type().is_dynamic() ||
parameter_element_types[i] == f->get_parameters()[i]->get_element_type());
m[f->get_parameters()[i].get()] =
std::make_shared<op::Parameter>(parameter_element_types[i], parameter_shapes[i]);
}
for (auto old_node : f->get_ordered_ops())
{
if (old_node->is_parameter())
{
continue;
}
NodeVector new_args = old_node->get_arguments();
for (size_t i = 0; i < new_args.size(); i++)
{
new_args[i] = m[new_args[i].get()];
}
m[old_node.get()] = old_node->copy_with_new_args(new_args);
}
ParameterVector new_parameters = f->get_parameters();
for (size_t i = 0; i < new_parameters.size(); i++)
{
new_parameters[i] = std::static_pointer_cast<op::Parameter>(m[new_parameters[i].get()]);
}
ResultVector new_results = f->get_results();
for (size_t i = 0; i < new_results.size(); i++)
{
new_results[i] = std::static_pointer_cast<op::Result>(m[new_results[i].get()]);
}
return std::make_shared<Function>(new_results, new_parameters);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/function.hpp"
namespace ngraph
{
/// \brief Creates a clone of a function, with the shapes of that function's parameters
/// specialized to some more specific element types and shapes.
/// \param f The function to be cloned.
/// \param parameter_element_types The new parameter element types to substitute.
/// \param parameter_shapes The new parameter shapes to substitute.
/// \return A clone of f, with the parameter element types and shapes specialized.
/// \throws AssertionFailure if parameter_element_types or parameter_shapes is not valid
/// (see details).
/// \throws NodeValidationError if node validation fails as the clone is being constructed.
///
/// Creates a "shape-specialized" clone of an nGraph Function function.
///
/// For example, suppose that a function f has three parameters with partial shapes:
///
/// ```
/// param0: ?
/// param1: {1,?,3}
/// param2: {?,?,4}
/// ```
///
/// Shape specialization would allow us to create a clone of f where the shapes are (for
/// example):
///
/// ```
/// param0: {1,2}
/// param1: {1,5,3}
/// param2: {3,?,4}
/// ```
///
/// But not (for example):
///
/// ```
/// param1: {1,5,3,4} // rank doesn't match {1,?,3}
/// param1: {2,?,3} // the "2" doesn't match the "1"
/// param1: {?,?,3} // the new shape is too relaxed: it doesn't require 1 for the first dim
/// ```
///
/// Note that validation errors can potentially occur during cloning. For example:
///
/// ```
/// n = Parameter{shape=?}
/// m = Parameter{shape=?}
/// x = n + m
/// f = Function(x,{n,m})
/// ```
///
/// If we specialize n to the shape `{1,2,3}` and m to the shape `{4,5,6}`, cloning will fail
/// because when we reconstruct the new x node, it will see that the shapes are inconsistent
/// for elementwise add.
///
/// Specialization of element types is also possible: `element::dynamic` can be specialized
/// to a concrete element type or left dynamic; but a concrete element type can only be
/// specialized to itself (e.g., specialization does not allow you to change `element::i32`
/// to `element::i64`).
///
/// It is required that:
/// 1. The length of parameter_element_types and parameter_shapes is the same as the number
/// of f's parameters.
/// 2. Each shape in parameter_shapes is a refinement of the shape of the corresponding
/// parameter of f. Roughly speaking, a shape s1 is said to "refine" s2 if s1 can be
/// obtained from s2 by filling in s2's question marks. See PartialShape::refines for
/// more details.
/// 3. For all i, either the element type of fp_i is dynamic, or fp_i is the same as
/// parameter_element_types[i]. (Here fp_i is the ith parameter of f.)
std::shared_ptr<Function>
specialize_shapes(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes);
}
...@@ -52,6 +52,7 @@ set(SRC ...@@ -52,6 +52,7 @@ set(SRC
reshape_sinking.cpp reshape_sinking.cpp
serialize.cpp serialize.cpp
shape.cpp shape.cpp
specialize_shapes.cpp
tensor.cpp tensor.cpp
type_prop.cpp type_prop.cpp
util.cpp util.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/specialize_shapes.hpp"
using namespace ngraph;
// Simple case: create a function with static parameter shapes and "specialize" them to the same
// shapes.
TEST(specialize_shapes, et_shape_static)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::i32, Shape{1, 2, 3});
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
auto g = specialize_shapes(
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(g->get_output_element_type(0), element::f32);
}
// Test specialization of dynamic element types.
TEST(specialize_shapes, et_dynamic_shape_static)
{
auto p0 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
auto g = specialize_shapes(
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(g->get_output_element_type(0), element::f32);
}
// Test specialization of rank-dynamic shapes.
TEST(specialize_shapes, et_static_shape_rank_dynamic)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
auto g = specialize_shapes(
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(g->get_output_element_type(0), element::f32);
}
// Test specialization of rank-static dynamic shapes.
TEST(specialize_shapes, et_static_shape_rank_static_dynamic)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
auto g = specialize_shapes(
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(g->get_output_element_type(0), element::f32);
}
// Test specialization of rank-dynamic shapes to a case where validation will fail.
//
// (The input shapes we provide at specialization time are inconsistent.)
TEST(specialize_shapes, et_static_shape_rank_dynamic_validation_fails)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
ASSERT_THROW(
{
specialize_shapes(
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}});
},
NodeValidationFailure);
}
// Test specialization of dynamic element types to a case where validation will fail.
//
// (The input element types we provide at specialization time are inconsistent.)
TEST(specialize_shapes, et_dynamic_shape_static_validation_fails)
{
auto p0 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
ASSERT_THROW(
{
specialize_shapes(
f, {element::u32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
},
NodeValidationFailure);
}
// Test specialization of rank-static dynamic shapes, where the replacement shapes have the wrong
// rank.
//
// (Note that we are testing for a different exception class here because the failure is in
// specialize_shape's pre-checks, which use NGRAPH_ASSERT, rather than inside validation as we
// reconstruct the graph.)
TEST(specialize_shapes, et_static_shape_rank_static_dynamic_rank_mismatch)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
ASSERT_THROW(
{
specialize_shapes(
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}});
},
AssertionFailure);
}
// Test specialization of rank-static dynamic shapes, where the replacement shapes have wrong
// dimensions.
//
// (Note that we are testing for a different exception class here because the failure is in
// specialize_shape's pre-checks, which use NGRAPH_ASSERT, rather than inside validation as we
// reconstruct the graph.)
TEST(specialize_shapes, et_static_shape_rank_static_dynamic_dim_mismatch)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
auto p1 =
std::make_shared<op::Parameter>(element::i32, PartialShape{1, Dimension::dynamic(), 3});
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
ASSERT_THROW(
{
specialize_shapes(
f, {element::f32, element::i32}, {PartialShape{1, 2, 3}, PartialShape{1, 9, 4}});
},
AssertionFailure);
}
// Test for failure when we supply the wrong number of replacement element types.
TEST(specialize_shapes, et_count_wrong)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
ASSERT_THROW(
{
specialize_shapes(f,
{element::f32, element::i32, element::u32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}});
},
AssertionFailure);
}
// Test for failure when we supply the wrong number of replacement shapes.
TEST(specialize_shapes, shape_count_wrong)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});
auto k = std::make_shared<op::Convert>(p1, element::f32);
auto a = p0 + k;
auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});
ASSERT_THROW(
{
specialize_shapes(
f,
{element::f32, element::i32},
{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}, PartialShape{4, 5, 6}});
},
AssertionFailure);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment