Commit 2d75f665 authored by Nick Korovaiko, committed by Scott Cyphers

Zero Dimension Tensor Elimination (#617)

*  zero dimension tensor elimination init

* more ops + refactor + tests

* revert pattern.cpp

* add internal zero-length test

* address Scott's feedback

* fix comp errors

* proper static init

* get rid of unique-ptr

* refactor hashmap into virtual get_default_values on op classes

* fix formatting
parent 96604f12
......@@ -114,9 +114,11 @@ set (SRC
pass/pass.cpp
pass/reshape_elimination.cpp
pass/result_copy_elimination.cpp
pass/zero_dim_tensor_elimination.cpp
pass/validate_graph.cpp
pass/visualize_tree.cpp
pass/core_fusion.cpp
pass/zero_dim_tensor_elimination.cpp
pattern/matcher.cpp
runtime/aligned_buffer.cpp
runtime/backend.cpp
......
......@@ -409,3 +409,11 @@ std::shared_ptr<Node> ngraph::make_zero(const element::Type& element_type, const
}
return zero;
}
std::shared_ptr<Node> ngraph::make_constant_from_string(std::string val,
const element::Type& element_type,
const Shape& shape)
{
auto cvals = std::vector<std::string>(shape_size(shape), val);
return std::make_shared<op::Constant>(element_type, shape, cvals);
}
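For context, a minimal usage sketch of this helper (the element type and shape here are illustrative, not part of the commit):
//hypothetical usage: a 2x2 f32 constant where every element is "1"
auto one = ngraph::make_constant_from_string("1", element::f32, Shape{2, 2});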
......@@ -123,4 +123,8 @@ namespace ngraph
const std::shared_ptr<Node>& new_node);
std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape);
std::shared_ptr<Node> make_constant_from_string(std::string val,
const element::Type& element_type,
const Shape& shape);
}
......@@ -191,6 +191,7 @@ namespace ngraph
/// Get all the nodes that use the current node
NodeVector get_users() const;
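/// \return The default value for this op, used by zero-dimension tensor
/// elimination; nullptr (the base default) means the op has no replacement.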
virtual std::shared_ptr<Node> get_default_value() const { return nullptr; }
protected:
void add_output(const element::Type& element_type, const Shape& shape);
......
......@@ -16,6 +16,7 @@
#pragma once
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
......@@ -87,6 +88,11 @@ namespace ngraph
{
return m_include_padding_in_avg_computation;
}
/// \return The default value for AvgPool.
virtual std::shared_ptr<Node> get_default_value() const override
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
protected:
Shape m_window_shape;
......
......@@ -17,6 +17,7 @@
#pragma once
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
......@@ -138,6 +139,12 @@ namespace ngraph
const CoordinateDiff& get_padding_above() const { return m_padding_above; }
/// \return The input data dilation strides.
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
/// \return The default value for Convolution.
virtual std::shared_ptr<Node> get_default_value() const override
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
protected:
Strides m_window_movement_strides;
Strides m_window_dilation_strides;
......
......@@ -16,6 +16,7 @@
#pragma once
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
......@@ -83,6 +84,12 @@ namespace ngraph
const Shape& get_padding_below() const { return m_padding_below; }
/// \return The above-padding shape.
const Shape& get_padding_above() const { return m_padding_above; }
/// \return The default value for MaxPool.
virtual std::shared_ptr<Node> get_default_value() const override
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......
......@@ -15,6 +15,7 @@
*******************************************************************************/
#include "ngraph/op/pad.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -120,7 +121,16 @@ shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) const
*/
void op::Pad::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
throw invalid_argument("Autodiff is not yet implemented for Pad");
}
std::shared_ptr<Node> op::Pad::get_default_value() const
{
AxisSet axes{};
for (size_t i = 0; i < get_shape().size(); i++)
{
axes.insert(i);
}
return std::make_shared<op::Broadcast>(
m_inputs.at(1).get_output().get_node(), get_shape(), axes);
}
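Unlike the constant-string defaults on AvgPool, Convolution, and MaxPool above, Pad's default depends on its second input (the pad value), so it is defined out-of-line as a Broadcast of that value over the output shape. A minimal sketch of the equivalence, with illustrative 1-D shapes (mirroring the zero_const_pad test below):
auto A = std::make_shared<op::Parameter>(element::f32, Shape{0});
auto B = std::make_shared<op::Parameter>(element::f32, Shape{});
//padding zero-length A below and above by 2 yields Shape{4}, every element B
auto pad = std::make_shared<op::Pad>(A, B, Shape{2}, Shape{2}, Shape{0});
//pad->get_default_value() is then equivalent to Broadcast(B, Shape{4}, AxisSet{0})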
......@@ -80,6 +80,9 @@ namespace ngraph
const Shape& get_padding_above() const { return m_padding_above; }
/// \return The interior padding sizes.
const Shape& get_padding_interior() const { return m_padding_interior; }
/// \return The default value for Pad.
virtual std::shared_ptr<Node> get_default_value() const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......
......@@ -16,6 +16,7 @@
#pragma once
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/arithmetic_reduction.hpp"
namespace ngraph
......@@ -84,6 +85,12 @@ namespace ngraph
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Product(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
/// \return The default value for Product.
virtual std::shared_ptr<Node> get_default_value() const override
{
return ngraph::make_constant_from_string("1", get_element_type(), get_shape());
}
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
......
......@@ -17,6 +17,7 @@
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/arithmetic_reduction.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
......@@ -90,6 +91,12 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for Sum.
virtual std::shared_ptr<Node> get_default_value() const override
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include <set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sum.hpp"
#include "zero_dim_tensor_elimination.hpp"
using namespace ngraph;
static bool has_zero_dim(std::shared_ptr<Node> node)
{
if (node->get_output_size() != 1)
{
throw ngraph_error("has_zero_dim is called on multi-output op");
}
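//a tensor is zero-length iff the product of its dimensions is zero,
//e.g. shape_size(Shape{1, 1, 0}) == 0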
return shape_size(node->get_shape()) == 0;
}
static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function> f)
{
std::set<std::shared_ptr<Node>> zero_length_nodes;
for (auto n : f->get_ordered_ops())
{
if (n->is_output() || n->is_parameter() || n->get_outputs().size() > 1)
{
continue;
}
if (has_zero_dim(n))
{
zero_length_nodes.insert(n);
}
}
//all zero-length ops should feed into results: if, after removing every
//node that feeds a result from zero_length_nodes, there are still nodes
//left in zero_length_nodes, the graph contains INTERNAL zero-length
//nodes, which violates our assumption
for (auto r : f->get_results())
{
auto n = r->get_input_op(0);
if (zero_length_nodes.count(n) != 0)
{
zero_length_nodes.erase(n);
}
}
return !zero_length_nodes.empty();
}
bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
{
bool replaced = false;
//we need to visit every node, since a sum (or any other
//zero-length-tensor-to-scalar op) could occur as an internal node
//(i.e. a node that isn't an argument to `op::Result`)
for (auto n : f->get_ordered_ops())
{
//don't try to replace `op::Result`;
//all multi-output nodes feed into `GetOutputElement`,
//and if any `GetOutputElement` is zero-length
//we replace it with a signalling constant,
//so we never have to deal with multi-output nodes directly
if (n->is_output() || n->is_parameter() || n->get_outputs().size() > 1)
{
continue;
}
if (has_zero_dim(n))
{
//we don't have to create a fresh constant every time, but this is the
//easiest approach, and it's CSE's job to eliminate duplicates
auto cvals = std::vector<std::string>(0);
auto constant =
std::make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals);
replace_node(n, constant);
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << constant->get_name();
replaced = true;
continue;
}
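//otherwise, if this node has a default value and its first argument is
//zero-length, replace the node with that default (e.g. 0 for Sum, 1 for Product)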
auto new_node = n->get_default_value();
if (!new_node || !has_zero_dim(n->get_input_op(0)))
{
continue;
}
replaced = true;
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << new_node->get_name();
replace_node(n, new_node);
}
if (verify_no_internal_zero_length_ops(f))
{
throw ngraph_error("internal zero-length nodes remain in the graph");
}
return replaced;
}
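For reference, the pass is registered with the pass manager like any other FunctionPass (the tests below do exactly this; `f` is a std::shared_ptr<ngraph::Function>):
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.run_passes(f);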
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class ZeroDimTensorElimination;
}
}
class ngraph::pass::ZeroDimTensorElimination : public FunctionPass
{
public:
ZeroDimTensorElimination()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
......@@ -57,6 +57,7 @@ set (SRC
util/benchmark.cpp
util.cpp
uuid.cpp
zero_dim_tensor_elimination.cpp
)
add_subdirectory(models)
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(zero_dim_tensor_elimination, zero_sum)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto sum = std::make_shared<op::Sum>(A, AxisSet{0});
auto abs_node = std::make_shared<op::Abs>(A);
auto sum_node = std::make_shared<op::Sum>(abs_node, AxisSet{0});
auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
auto f = std::make_shared<Function>(NodeVector{sum_node, constant}, op::ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Sum>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_const_conv)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 0});
auto weights = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto convolution = std::make_shared<op::Convolution>(
A, weights, Strides{1}, Strides{1}, CoordinateDiff{2}, CoordinateDiff{2});
auto abs_node = std::make_shared<op::Abs>(convolution);
auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
auto f =
std::make_shared<Function>(NodeVector{abs_node, constant}, op::ParameterVector{A, weights});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Convolution>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_const_avg_pool)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 0});
auto avg_pool =
std::make_shared<op::AvgPool>(A, Shape{1}, Strides{1}, Shape{2}, Shape{2}, true);
auto abs_node = std::make_shared<op::Abs>(avg_pool);
auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
auto f = std::make_shared<Function>(NodeVector{abs_node, constant}, op::ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::AvgPool>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_const_pad)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::f32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::f32, Shape{});
auto pad = std::make_shared<op::Pad>(A, B, Shape{2}, Shape{2}, Shape{0});
auto abs_node = std::make_shared<op::Abs>(pad);
auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
auto f = std::make_shared<Function>(NodeVector{abs_node, constant}, op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
}
TEST(zero_dim_tensor_elimination, zero_const_slice)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::f32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::f32, Shape{});
auto slice = make_shared<op::Slice>(A, Coordinate{0}, Coordinate{0});
auto pad = std::make_shared<op::Pad>(A, B, Shape{2}, Shape{2}, Shape{0});
auto abs_node = std::make_shared<op::Abs>(pad);
auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
auto f = std::make_shared<Function>(NodeVector{abs_node, constant}, op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
}