Commit a5908869 authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Scott Cyphers

specialize_function with shared constants (#3845)

* Added constant sharing support

* Added unit tests for specialize_function

* Added parameter description

* Code style fix
parent 143fd0f2
...@@ -27,7 +27,7 @@ std::shared_ptr<Function> ...@@ -27,7 +27,7 @@ std::shared_ptr<Function>
const std::vector<void*>& parameter_values) const std::vector<void*>& parameter_values)
{ {
return specialize_function( return specialize_function(
f, parameter_element_types, parameter_shapes, parameter_values, false); f, parameter_element_types, parameter_shapes, parameter_values, false, false);
} }
std::shared_ptr<Function> std::shared_ptr<Function>
...@@ -35,7 +35,8 @@ std::shared_ptr<Function> ...@@ -35,7 +35,8 @@ std::shared_ptr<Function>
const std::vector<element::Type>& parameter_element_types, const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes, const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values, const std::vector<void*>& parameter_values,
bool constant_folding) bool constant_folding,
bool share_constants)
{ {
NGRAPH_CHECK(f->get_parameters().size() == parameter_shapes.size()); NGRAPH_CHECK(f->get_parameters().size() == parameter_shapes.size());
NGRAPH_CHECK(f->get_parameters().size() == parameter_element_types.size()); NGRAPH_CHECK(f->get_parameters().size() == parameter_element_types.size());
...@@ -74,7 +75,16 @@ std::shared_ptr<Function> ...@@ -74,7 +75,16 @@ std::shared_ptr<Function>
auto output = input.get_source_output(); auto output = input.get_source_output();
new_args.push_back(output.for_node(m[output.get_node()])); new_args.push_back(output.for_node(m[output.get_node()]));
} }
m[old_node.get()] = old_node->copy_with_new_inputs(new_args);
if (share_constants && as_type_ptr<op::Constant>(old_node))
{
m[old_node.get()] = old_node;
}
else
{
m[old_node.get()] = old_node->copy_with_new_inputs(new_args);
}
m[old_node.get()]->set_friendly_name(old_node->get_friendly_name()); m[old_node.get()]->set_friendly_name(old_node->get_friendly_name());
} }
......
...@@ -121,6 +121,8 @@ namespace ngraph ...@@ -121,6 +121,8 @@ namespace ngraph
/// of parameters of f, with nullptr indicating that no substitution is to be made for /// of parameters of f, with nullptr indicating that no substitution is to be made for
/// the corresponding parameter. /// the corresponding parameter.
/// \param constant_folding If flag is true, constant propagation is applied /// \param constant_folding If flag is true, constant propagation is applied
/// \param share_constants If flag is true, cloned function will have shared constants with
/// original function.
/// \return A clone of f, with the parameter element types, shapes, and values specialized. /// \return A clone of f, with the parameter element types, shapes, and values specialized.
/// \throws CheckFailure if parameter_element_types, parameter_shapes is not valid /// \throws CheckFailure if parameter_element_types, parameter_shapes is not valid
/// (see details). /// (see details).
...@@ -198,5 +200,6 @@ namespace ngraph ...@@ -198,5 +200,6 @@ namespace ngraph
const std::vector<element::Type>& parameter_element_types, const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes, const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values, const std::vector<void*>& parameter_values,
bool constant_folding); bool constant_folding,
bool share_constants);
} }
...@@ -321,3 +321,44 @@ TEST(specialize_function, value_count_wrong) ...@@ -321,3 +321,44 @@ TEST(specialize_function, value_count_wrong)
}, },
CheckFailure); CheckFailure);
} }
// Test checks that constant sharing is working
TEST(specialize_function, share_constants)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 64, 64});
auto mul_const = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1, 2, 3});
auto mul = std::make_shared<op::Multiply>(p0, mul_const, op::AutoBroadcastType::NUMPY);
auto add_const = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {4, 5, 6});
auto add = std::make_shared<op::Multiply>(mul, add_const, op::AutoBroadcastType::NUMPY);
auto f = std::make_shared<Function>(add, ParameterVector{p0});
auto f_specialized =
specialize_function(f, {element::f32}, {PartialShape{2, 3, 64, 64}}, {nullptr}, true, true);
ASSERT_EQ(mul_const->output(0).get_target_inputs().size(), 2);
ASSERT_EQ(add_const->output(0).get_target_inputs().size(), 2);
}
// Test checks that constant sharing works when constant folding replaces constants
TEST(specialize_function, share_constants_with_cf)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 64, 64});
auto mul_const = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1, 2, 3});
auto mul = std::make_shared<op::Multiply>(p0, mul_const, op::AutoBroadcastType::NUMPY);
auto add_const_1 = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {4, 5, 6});
auto add_const_2 = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1, 2, 3});
auto add_const = std::make_shared<op::Add>(add_const_1, add_const_2);
auto add = std::make_shared<op::Add>(mul, add_const, op::AutoBroadcastType::NUMPY);
auto f = std::make_shared<Function>(add, ParameterVector{p0});
auto f_specialized =
specialize_function(f, {element::f32}, {PartialShape{2, 3, 64, 64}}, {nullptr}, true, true);
ASSERT_EQ(mul_const->output(0).get_target_inputs().size(), 2);
ASSERT_EQ(add_const_1->output(0).get_target_inputs().size(), 1);
ASSERT_EQ(add_const_2->output(0).get_target_inputs().size(), 1);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment