Commit a5908869 authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Scott Cyphers

specialize_function with shared constants (#3845)

* Added constant sharing support

* Added unit tests for specialize_function

* Added parameter description

* Code style fix
parent 143fd0f2
......@@ -27,7 +27,7 @@ std::shared_ptr<Function>
const std::vector<void*>& parameter_values)
{
return specialize_function(
f, parameter_element_types, parameter_shapes, parameter_values, false);
f, parameter_element_types, parameter_shapes, parameter_values, false, false);
}
std::shared_ptr<Function>
......@@ -35,7 +35,8 @@ std::shared_ptr<Function>
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values,
bool constant_folding)
bool constant_folding,
bool share_constants)
{
NGRAPH_CHECK(f->get_parameters().size() == parameter_shapes.size());
NGRAPH_CHECK(f->get_parameters().size() == parameter_element_types.size());
......@@ -74,7 +75,16 @@ std::shared_ptr<Function>
auto output = input.get_source_output();
new_args.push_back(output.for_node(m[output.get_node()]));
}
if (share_constants && as_type_ptr<op::Constant>(old_node))
{
m[old_node.get()] = old_node;
}
else
{
m[old_node.get()] = old_node->copy_with_new_inputs(new_args);
}
m[old_node.get()]->set_friendly_name(old_node->get_friendly_name());
}
......
......@@ -121,6 +121,8 @@ namespace ngraph
/// of parameters of f, with nullptr indicating that no substitution is to be made for
/// the corresponding parameter.
/// \param constant_folding If flag is true, constant propagation is applied
/// \param share_constants If flag is true, cloned function will have shared constants with
/// original function.
/// \return A clone of f, with the parameter element types, shapes, and values specialized.
/// \throws CheckFailure if parameter_element_types, parameter_shapes is not valid
/// (see details).
......@@ -198,5 +200,6 @@ namespace ngraph
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values,
bool constant_folding);
bool constant_folding,
bool share_constants);
}
......@@ -321,3 +321,44 @@ TEST(specialize_function, value_count_wrong)
},
CheckFailure);
}
// Test checks that constant sharing is working
TEST(specialize_function, share_constants)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 64, 64});
auto mul_const = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1, 2, 3});
auto mul = std::make_shared<op::Multiply>(p0, mul_const, op::AutoBroadcastType::NUMPY);
auto add_const = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {4, 5, 6});
auto add = std::make_shared<op::Multiply>(mul, add_const, op::AutoBroadcastType::NUMPY);
auto f = std::make_shared<Function>(add, ParameterVector{p0});
auto f_specialized =
specialize_function(f, {element::f32}, {PartialShape{2, 3, 64, 64}}, {nullptr}, true, true);
ASSERT_EQ(mul_const->output(0).get_target_inputs().size(), 2);
ASSERT_EQ(add_const->output(0).get_target_inputs().size(), 2);
}
// Test checks that constant sharing works when constant folding replaces constants
TEST(specialize_function, share_constants_with_cf)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 64, 64});
auto mul_const = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1, 2, 3});
auto mul = std::make_shared<op::Multiply>(p0, mul_const, op::AutoBroadcastType::NUMPY);
auto add_const_1 = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {4, 5, 6});
auto add_const_2 = op::Constant::create(element::f32, Shape{1, 3, 1, 1}, {1, 2, 3});
auto add_const = std::make_shared<op::Add>(add_const_1, add_const_2);
auto add = std::make_shared<op::Add>(mul, add_const, op::AutoBroadcastType::NUMPY);
auto f = std::make_shared<Function>(add, ParameterVector{p0});
auto f_specialized =
specialize_function(f, {element::f32}, {PartialShape{2, 3, 64, 64}}, {nullptr}, true, true);
ASSERT_EQ(mul_const->output(0).get_target_inputs().size(), 2);
ASSERT_EQ(add_const_1->output(0).get_target_inputs().size(), 1);
ASSERT_EQ(add_const_2->output(0).get_target_inputs().size(), 1);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment