Commit 490d6698 authored by Robert Kimball's avatar Robert Kimball Committed by Sang Ik Lee

Add ConstantToBroadcast pass (#2754)

* Add ConstantToBroadcast pass

* address review comment

* fix merge error
parent 3442879f
...@@ -291,6 +291,7 @@ set (SRC ...@@ -291,6 +291,7 @@ set (SRC
pass/common_function_collection.hpp pass/common_function_collection.hpp
pass/constant_folding.cpp pass/constant_folding.cpp
pass/constant_folding.hpp pass/constant_folding.hpp
pass/constant_to_broadcast.cpp
pass/core_fusion.cpp pass/core_fusion.cpp
pass/core_fusion.hpp pass/core_fusion.hpp
pass/cse.cpp pass/cse.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/constant_to_broadcast.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
/// \brief Checks whether every element of a Constant holds the same value and, if so,
///        replaces the Constant in the graph with a Broadcast of a scalar Constant.
///
/// Elements are read through get_data_ptr<T>() and compared with T's operator!=, so the
/// comparison is done on whatever representation T gives to the stored data.
///
/// \tparam T The type used to read and compare the constant's elements.
/// \param constant The Constant node to inspect (and possibly replace).
/// \return true if all elements compared equal (the node has then been replaced), or if the
///         constant has zero elements (in which case nothing is replaced); false otherwise.
template <typename T>
static bool is_data_constant(shared_ptr<op::Constant> constant)
{
    const size_t element_count = shape_size(constant->get_shape());
    if (element_count == 0)
    {
        // No data to compare; reported as uniform but the graph is left untouched.
        return true;
    }

    const T* const first = constant->get_data_ptr<T>();
    const T* const last = first + element_count;
    const T reference = *first;
    for (const T* cursor = first + 1; cursor != last; ++cursor)
    {
        if (*cursor != reference)
        {
            // At least one element differs, so the constant cannot be a broadcast scalar.
            return false;
        }
    }

    // The scalar is built from the start of the original buffer, i.e. from element 0.
    auto scalar = make_shared<op::Constant>(
        constant->get_element_type(), Shape{}, constant->get_data_ptr());

    // The source is a scalar, so every axis of the result shape is a broadcast axis.
    const Shape& result_shape = constant->get_output_shape(0);
    AxisSet all_axes;
    for (size_t axis = 0; axis < result_shape.size(); ++axis)
    {
        all_axes.insert(axis);
    }

    auto replacement = make_shared<op::Broadcast>(scalar, result_shape, all_axes);
    replace_node(constant, replacement);
    return true;
}
/// \brief Replaces a large Constant whose elements are all identical with a Broadcast of a
///        scalar Constant.
///
/// Only nodes whose description() is "Constant" and whose element count is strictly
/// greater than 32 are examined. Each handled element type is dispatched to an unsigned
/// integer instantiation of is_data_constant:
///   boolean, i8, u8  -> uint8_t
///   bf16, i16, u16   -> uint16_t
///   f32, i32, u32    -> uint32_t
///   f64, i64, u64    -> uint64_t
/// Constants of undefined or dynamic element type are left alone.
///
/// \param node The node to examine.
/// \return The result of is_data_constant for a qualifying Constant; false otherwise.
bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node)
{
    if (node->description() != "Constant")
    {
        return false;
    }

    // Small constants are not worth converting.
    const size_t minimum_size_of_interest = 32;
    auto as_constant = static_pointer_cast<op::Constant>(node);
    const size_t element_count = shape_size(as_constant->get_shape());
    if (element_count <= minimum_size_of_interest)
    {
        return false;
    }

    switch (as_constant->get_element_type().get_type_enum())
    {
    case element::Type_t::boolean:
    case element::Type_t::i8:
    case element::Type_t::u8: return is_data_constant<uint8_t>(as_constant);

    case element::Type_t::bf16:
    case element::Type_t::i16:
    case element::Type_t::u16: return is_data_constant<uint16_t>(as_constant);

    case element::Type_t::f32:
    case element::Type_t::i32:
    case element::Type_t::u32: return is_data_constant<uint32_t>(as_constant);

    case element::Type_t::f64:
    case element::Type_t::i64:
    case element::Type_t::u64: return is_data_constant<uint64_t>(as_constant);

    case element::Type_t::undefined:
    case element::Type_t::dynamic: break;
    }
    return false;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
    namespace pass
    {
        /// \brief Node pass that is run on each node through run_on_node.
        class ConstantToBroadcast : public NodePass
        {
        public:
            /// \brief Examines a single node.
            /// \return true if the pass modified the graph, false otherwise.
            bool run_on_node(std::shared_ptr<ngraph::Node>) override;
        };
    }
}
...@@ -286,6 +286,12 @@ void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, si ...@@ -286,6 +286,12 @@ void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, si
} }
void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t indent) void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t indent)
{
out << ::serialize(func, indent, false);
}
#if defined ENABLE_CPIO_FILE
static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, size_t indent)
{ {
string j = ::serialize(func, indent, true); string j = ::serialize(func, indent, true);
cpio::Writer writer(out); cpio::Writer writer(out);
...@@ -305,6 +311,7 @@ void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t i ...@@ -305,6 +311,7 @@ void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t i
true); true);
}); });
} }
#endif
static string serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data) static string serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data)
{ {
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
/// indent level specified. /// indent level specified.
std::string serialize(std::shared_ptr<ngraph::Function> func, size_t indent = 0); std::string serialize(std::shared_ptr<ngraph::Function> func, size_t indent = 0);
/// \brief Serialize a Function to as a json file /// \brief Serialize a Function to a json file
/// \param path The path to the output file /// \param path The path to the output file
/// \param func The Function to serialize /// \param func The Function to serialize
/// \param indent If 0 then there is no formatting applied and the resulting string is the /// \param indent If 0 then there is no formatting applied and the resulting string is the
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
std::shared_ptr<ngraph::Function> func, std::shared_ptr<ngraph::Function> func,
size_t indent = 0); size_t indent = 0);
/// \brief Serialize a Function to a CPIO file with all constant data stored as binary /// \brief Serialize a Function to a json stream
/// \param out The output stream to which the data is serialized. /// \param out The output stream to which the data is serialized.
/// \param func The Function to serialize /// \param func The Function to serialize
/// \param indent If 0 then there is no formatting applied and the json is the /// \param indent If 0 then there is no formatting applied and the json is the
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include "ngraph/pass/constant_to_broadcast.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/serializer.hpp" #include "ngraph/serializer.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -41,6 +43,7 @@ SYNOPSIS ...@@ -41,6 +43,7 @@ SYNOPSIS
OPTIONS OPTIONS
-i or --input input serialized model -i or --input input serialized model
-o or --output output serialized model -o or --output output serialized model
-c or --constant_to_broadcast Convert large constants with uniform data to broadcast
)###"; )###";
} }
...@@ -48,6 +51,7 @@ int main(int argc, char** argv) ...@@ -48,6 +51,7 @@ int main(int argc, char** argv)
{ {
string input; string input;
string output; string output;
bool c2b = false;
for (size_t i = 1; i < argc; i++) for (size_t i = 1; i < argc; i++)
{ {
string arg = argv[i]; string arg = argv[i];
...@@ -59,6 +63,10 @@ int main(int argc, char** argv) ...@@ -59,6 +63,10 @@ int main(int argc, char** argv)
{ {
input = argv[++i]; input = argv[++i];
} }
else if (arg == "-c" || arg == "--constant_to_broadcast")
{
c2b = true;
}
else if (arg == "-h" || arg == "--help") else if (arg == "-h" || arg == "--help")
{ {
help(); help();
...@@ -66,6 +74,20 @@ int main(int argc, char** argv) ...@@ -66,6 +74,20 @@ int main(int argc, char** argv)
} }
} }
if (input.empty())
{
cout << "input file missing\n";
help();
return 1;
}
if (output.empty())
{
cout << "output file missing\n";
help();
return 1;
}
ifstream f(input); ifstream f(input);
if (f) if (f)
{ {
...@@ -75,6 +97,13 @@ int main(int argc, char** argv) ...@@ -75,6 +97,13 @@ int main(int argc, char** argv)
timer.stop(); timer.stop();
cout << "deserialize took " << timer.get_milliseconds() << "ms\n"; cout << "deserialize took " << timer.get_milliseconds() << "ms\n";
if (c2b)
{
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ConstantToBroadcast>();
pass_manager.run_passes(function);
}
timer.start(); timer.start();
ngraph::serialize(output, function, 2); ngraph::serialize(output, function, 2);
timer.stop(); timer.stop();
......
...@@ -53,6 +53,7 @@ set(SRC ...@@ -53,6 +53,7 @@ set(SRC
nop_elimination.cpp nop_elimination.cpp
op.cpp op.cpp
partial_shape.cpp partial_shape.cpp
pass.cpp
pass_liveness.cpp pass_liveness.cpp
pass_manager.cpp pass_manager.cpp
pass_memory_layout.cpp pass_memory_layout.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/constant_to_broadcast.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/serializer.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
// Runs the VisualizeTree pass over a small (A + B) * C function, writing "test_viz.png".
// The test has no assertions; it only exercises the pass.
TEST(pass, visualize_tree)
{
    Shape param_shape{2, 2};
    auto lhs = make_shared<op::Parameter>(element::f32, param_shape);
    auto rhs = make_shared<op::Parameter>(element::f32, param_shape);
    auto scale = make_shared<op::Parameter>(element::f32, param_shape);
    auto func = make_shared<Function>((lhs + rhs) * scale, ParameterVector{lhs, rhs, scale});

    ngraph::pass::Manager manager;
    manager.register_pass<pass::VisualizeTree>("test_viz.png");
    manager.run_passes(func);
}
// Builds a function consisting of one {128, 256, 1, 1} f32 Constant created from the
// single value 3, then checks that ConstantToBroadcast takes the Broadcast count in the
// function from 0 to 1. The graph is rendered to a PNG before and after the conversion.
TEST(pass, constant_to_broadcast)
{
    Shape const_shape{128, 256, 1, 1};
    vector<float> fill_value = {3};
    auto big_constant = make_shared<op::Constant>(element::f32, const_shape, fill_value);
    auto func = make_shared<Function>(big_constant, ParameterVector{});

    // Renders the current state of the function with a fresh pass manager.
    const auto render_graph = [&func](const string& file_name) {
        ngraph::pass::Manager viz_manager;
        viz_manager.register_pass<pass::VisualizeTree>(file_name);
        viz_manager.run_passes(func);
    };

    render_graph("pre_constant_to_broadcast.png");

    {
        ngraph::pass::Manager convert_manager;
        convert_manager.register_pass<pass::ConstantToBroadcast>();
        // No Broadcast exists until the pass has actually run.
        EXPECT_EQ(count_ops_of_type<op::Broadcast>(func), 0);
        convert_manager.run_passes(func);
        EXPECT_EQ(count_ops_of_type<op::Broadcast>(func), 1);
    }

    render_graph("post_constant_to_broadcast.png");
}
...@@ -394,19 +394,6 @@ TEST(graph_util, test_subgraph_topological_sort_control_dependencies) ...@@ -394,19 +394,6 @@ TEST(graph_util, test_subgraph_topological_sort_control_dependencies)
ASSERT_EQ(expected, sorted); ASSERT_EQ(expected, sorted);
} }
TEST(pass, visualize_tree)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C});
ngraph::pass::Manager pm;
pm.register_pass<pass::VisualizeTree>("test_viz.png");
pm.run_passes(f);
}
TEST(util, enum_mask_construction) TEST(util, enum_mask_construction)
{ {
enum class Type : uint32_t enum class Type : uint32_t
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment