Commit 490d6698 authored by Robert Kimball's avatar Robert Kimball Committed by Sang Ik Lee

Add ConstantToBroadcast pass (#2754)

* Add ConstantToBroadcast pass

* address review comment

* fix merge error
parent 3442879f
......@@ -291,6 +291,7 @@ set (SRC
pass/common_function_collection.hpp
pass/constant_folding.cpp
pass/constant_folding.hpp
pass/constant_to_broadcast.cpp
pass/core_fusion.cpp
pass/core_fusion.hpp
pass/cse.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/constant_to_broadcast.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
template <typename T>
static bool is_data_constant(shared_ptr<op::Constant> constant)
{
const size_t size = shape_size(constant->get_shape());
bool data_is_constant = true;
if (size > 0)
{
const T* data = constant->get_data_ptr<T>();
const T compare = data[0];
for (size_t i = 1; i < size; i++)
{
if (data[i] != compare)
{
data_is_constant = false;
break;
}
}
if (data_is_constant)
{
auto scalar_constant = make_shared<op::Constant>(
constant->get_element_type(), Shape{}, constant->get_data_ptr());
AxisSet broadcast_axes;
for (size_t i = 0; i < constant->get_output_shape(0).size(); i++)
{
broadcast_axes.insert(i);
}
auto broadcast = make_shared<op::Broadcast>(
scalar_constant, constant->get_output_shape(0), broadcast_axes);
replace_node(constant, broadcast);
}
}
return data_is_constant;
}
bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node)
{
const size_t minimum_size_of_interest = 32;
bool modified = false;
if (node->description() == "Constant")
{
auto constant = static_pointer_cast<op::Constant>(node);
size_t size = shape_size(constant->get_shape());
if (size > minimum_size_of_interest)
{
switch (constant->get_element_type().get_type_enum())
{
case element::Type_t::boolean:
case element::Type_t::i8:
case element::Type_t::u8:
{
modified = is_data_constant<uint8_t>(constant);
break;
}
case element::Type_t::bf16:
case element::Type_t::i16:
case element::Type_t::u16:
{
modified = is_data_constant<uint16_t>(constant);
break;
}
case element::Type_t::f32:
case element::Type_t::i32:
case element::Type_t::u32:
{
modified = is_data_constant<uint32_t>(constant);
break;
}
case element::Type_t::f64:
case element::Type_t::i64:
case element::Type_t::u64:
{
modified = is_data_constant<uint64_t>(constant);
break;
}
case element::Type_t::undefined:
case element::Type_t::dynamic: break;
}
}
}
return modified;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class ConstantToBroadcast;
}
}
class ngraph::pass::ConstantToBroadcast : public NodePass
{
public:
bool run_on_node(std::shared_ptr<ngraph::Node>) override;
};
......@@ -286,6 +286,12 @@ void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, si
}
void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t indent)
{
out << ::serialize(func, indent, false);
}
#if defined ENABLE_CPIO_FILE
static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, size_t indent)
{
string j = ::serialize(func, indent, true);
cpio::Writer writer(out);
......@@ -305,6 +311,7 @@ void ngraph::serialize(ostream& out, shared_ptr<ngraph::Function> func, size_t i
true);
});
}
#endif
static string serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data)
{
......
......@@ -30,7 +30,7 @@ namespace ngraph
/// indent level specified.
std::string serialize(std::shared_ptr<ngraph::Function> func, size_t indent = 0);
/// \brief Serialize a Function to as a json file
/// \brief Serialize a Function to a json file
/// \param path The path to the output file
/// \param func The Function to serialize
/// \param indent If 0 then there is no formatting applied and the resulting string is the
......@@ -40,7 +40,7 @@ namespace ngraph
std::shared_ptr<ngraph::Function> func,
size_t indent = 0);
/// \brief Serialize a Function to a CPIO file with all constant data stored as binary
/// \brief Serialize a Function to a json stream
/// \param out The output stream to which the data is serialized.
/// \param func The Function to serialize
/// \param indent If 0 then there is no formatting applied and the json is the
......
......@@ -24,6 +24,8 @@
#include <iostream>
#include <string>
#include "ngraph/pass/constant_to_broadcast.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
......@@ -41,6 +43,7 @@ SYNOPSIS
OPTIONS
-i or --input input serialized model
-o or --output output serialized model
-c or --constant_to_broacast Convert large constant constants to broadcast
)###";
}
......@@ -48,6 +51,7 @@ int main(int argc, char** argv)
{
string input;
string output;
bool c2b = false;
for (size_t i = 1; i < argc; i++)
{
string arg = argv[i];
......@@ -59,6 +63,10 @@ int main(int argc, char** argv)
{
input = argv[++i];
}
else if (arg == "-c" || arg == "--constant_to_broadcast")
{
c2b = true;
}
else if (arg == "-h" || arg == "--help")
{
help();
......@@ -66,6 +74,20 @@ int main(int argc, char** argv)
}
}
if (input.empty())
{
cout << "input file missing\n";
help();
return 1;
}
if (output.empty())
{
cout << "output file missing\n";
help();
return 1;
}
ifstream f(input);
if (f)
{
......@@ -75,6 +97,13 @@ int main(int argc, char** argv)
timer.stop();
cout << "deserialize took " << timer.get_milliseconds() << "ms\n";
if (c2b)
{
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ConstantToBroadcast>();
pass_manager.run_passes(function);
}
timer.start();
ngraph::serialize(output, function, 2);
timer.stop();
......
......@@ -53,6 +53,7 @@ set(SRC
nop_elimination.cpp
op.cpp
partial_shape.cpp
pass.cpp
pass_liveness.cpp
pass_manager.cpp
pass_memory_layout.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/constant_to_broadcast.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/serializer.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(pass, visualize_tree)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C});
ngraph::pass::Manager pm;
pm.register_pass<pass::VisualizeTree>("test_viz.png");
pm.run_passes(f);
}
TEST(pass, constant_to_broadcast)
{
Shape shape{128, 256, 1, 1};
vector<float> v = {3};
auto c = make_shared<op::Constant>(element::f32, shape, v);
auto f = make_shared<Function>(c, ParameterVector{});
{
ngraph::pass::Manager pm;
pm.register_pass<pass::VisualizeTree>("pre_constant_to_broadcast.png");
pm.run_passes(f);
}
{
ngraph::pass::Manager pm;
pm.register_pass<pass::ConstantToBroadcast>();
EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
pm.run_passes(f);
EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
}
{
ngraph::pass::Manager pm;
pm.register_pass<pass::VisualizeTree>("post_constant_to_broadcast.png");
pm.run_passes(f);
}
}
......@@ -394,19 +394,6 @@ TEST(graph_util, test_subgraph_topological_sort_control_dependencies)
ASSERT_EQ(expected, sorted);
}
TEST(pass, visualize_tree)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C});
ngraph::pass::Manager pm;
pm.register_pass<pass::VisualizeTree>("test_viz.png");
pm.run_passes(f);
}
TEST(util, enum_mask_construction)
{
enum class Type : uint32_t
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment