Commit de760a38 authored by Louis Feng's avatar Louis Feng Committed by Scott Cyphers

ngmx-372 Fuse RNN matrix operations (#724)

* test graph.

* rnn matrix fusion wip.

* WIP.

* matrix fused.

* clean up.

* additional refactor.

* fixed merge errors.

* testing.

* added test case.

* more testing.

* more testing.

* bias wip.

* debugging.

* fusing bias too.

* disabled debug outputs.

* removed debugging.

* removed testing models.

* removed debug models.

* reset nbench.

* clean up and testing.

* removed debug code.

* updated comments and var names.

* rename var.

* removed debug code.

* removed debug code.

* fixed compiler warnings.

* refactored test.

* Added const to parameter.

* address some of the review comments.

* added comments and addressed reviews.
parent e548bb03
......@@ -214,6 +214,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/pass/cpu_fusion.cpp
runtime/cpu/pass/cpu_layout.cpp
runtime/cpu/pass/cpu_nop_elimination.cpp
runtime/cpu/pass/cpu_rnn_mat_fusion.cpp
)
# LLVM binary builds are typically built without RTTI
# The built-in headers are in a version-specific directory
......
......@@ -328,3 +328,18 @@ descriptor::Output* Node::get_output_to(const shared_ptr<Node>& dst)
}
throw ngraph_error("Error: dst is not one of self's output Node");
}
// Collect every node that consumes any output of this node. A consumer
// appears once per input edge it has from this node.
NodeVector Node::get_users() const
{
    NodeVector users;
    const size_t output_count = get_output_size();
    for (size_t output_index = 0; output_index < output_count; ++output_index)
    {
        for (auto consumer_input : get_output_inputs(output_index))
        {
            users.push_back(consumer_input->get_node());
        }
    }
    return users;
}
......@@ -206,6 +206,9 @@ namespace ngraph
/// Get output descriptor that outputs to dst
descriptor::Output* get_output_to(const std::shared_ptr<Node>& dst);
/// Get all the nodes that use the current node
NodeVector get_users() const;
protected:
void add_output(const element::Type& element_type, const Shape& shape);
......
......@@ -47,6 +47,8 @@ namespace ngraph
{
}
NodeVector& operator=(const NodeVector& other) = default;
NodeVector() {}
};
}
......@@ -222,21 +222,6 @@ namespace ngraph
return cb(*this);
}
static NodeVector get_users(std::shared_ptr<Node> node)
{
NodeVector result;
for (size_t i = 0; i < node->get_output_size(); ++i)
{
for (auto input : node->get_output_inputs(i))
{
result.push_back(input->get_node());
}
}
return result;
}
bool Matcher::match(const std::shared_ptr<Node>& graph_node)
{
//clear our state
......@@ -248,8 +233,6 @@ namespace ngraph
throw "m_pattern_node or graph_node are not set!";
}
(void)get_users; //to supress an unused function warning
NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
......
This diff is collapsed.
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
// Function-level graph pass that fuses the per-time-step matrix
// operations of an RNN (the repeated Dot/Add pairs produced by slicing
// the input along the time axis) into a single batched operation.
class CPURnnMatFusion : public ngraph::pass::FunctionPass
{
public:
// Runs the fusion on the given function; returns true if the graph
// was modified.
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
}
}
}
......@@ -46,6 +46,7 @@
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_mat_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
......@@ -53,6 +54,7 @@
#include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/matcher.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
......@@ -1094,3 +1096,94 @@ TEST(cpu_fusion, batchnorm_fprop_inference_b2c2h2w1)
ASSERT_TRUE(
ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
}
// Builds a small RNN-style graph — for each time step: slice the input along
// the time axis, reshape to 2D, multiply by the transposed weights, and add a
// broadcast bias — then executes it on the CPU backend and returns one result
// tensor per time step.
//
// time_steps:    number of time slices taken from the data tensor
// data_shape:    {batch, time_steps, feature} shape of the input parameter
// weights_shape: {output_size, feature} shape of the weights parameter
// bias_shape:    {output_size} shape of the bias parameter
// *_val:         flat values copied into the corresponding input tensors
// enable_pass:   when true, run CPURnnMatFusion + CPUFusion before execution
//                and verify all per-step Dot/Add pairs collapse into exactly
//                one MatmulBias op
std::vector<shared_ptr<runtime::TensorView>>
    rnn_matrix_fusion_eval(const size_t time_steps,
                           const Shape& data_shape,
                           const Shape& weights_shape,
                           const Shape& bias_shape,
                           const vector<float>& data_val,
                           const vector<float>& weights_val,
                           const vector<float>& bias_val,
                           const bool enable_pass)
{
    auto data = make_shared<op::Parameter>(element::f32, data_shape);
    auto weights = make_shared<op::Parameter>(element::f32, weights_shape);
    auto bias = make_shared<op::Parameter>(element::f32, bias_shape);
    // results from each time step
    NodeVector results;
    for (size_t t = 0; t < time_steps; ++t)
    {
        auto data_slice = make_shared<op::Slice>(
            data, Coordinate{0, t, 0}, Coordinate{data_shape[0], t + 1, data_shape[2]});
        auto data_reshape = make_shared<op::Reshape>(
            data_slice, AxisVector{0, 1, 2}, Shape{data_shape[0], data_shape[2]});
        // Transpose weights so the per-step product is data x weights^T.
        auto weights_reshape = make_shared<op::Reshape>(
            weights, AxisVector{1, 0}, Shape{weights_shape[1], weights_shape[0]});
        auto dot = make_shared<op::Dot>(data_reshape, weights_reshape);
        auto bias_broadcast = make_shared<op::Broadcast>(bias, dot->get_shape(), AxisSet{0});
        auto add = make_shared<op::Add>(dot, bias_broadcast);
        results.push_back(add);
    }
    auto func = make_shared<Function>(results, op::ParameterVector{data, weights, bias});
    if (enable_pass)
    {
        pass::Manager pass_manager;
        pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
        pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
        pass_manager.run_passes(func);
        // check all of our dot/add are converted to a single MatmulBias op.
        size_t count = count_ops_of_type<op::MatmulBias>(func);
        EXPECT_EQ(count, 1);
    }
    auto manager = runtime::Manager::get("CPU");
    auto external = manager->compile(func);
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(external);
    shared_ptr<runtime::TensorView> data_tensor =
        backend->make_primary_tensor_view(element::f32, data->get_shape());
    shared_ptr<runtime::TensorView> weights_tensor =
        backend->make_primary_tensor_view(element::f32, weights->get_shape());
    shared_ptr<runtime::TensorView> bias_tensor =
        backend->make_primary_tensor_view(element::f32, bias->get_shape());
    std::vector<shared_ptr<runtime::TensorView>> result_tensors;
    // const& avoids copying (and atomically ref-counting) each shared_ptr.
    for (const auto& r : results)
    {
        result_tensors.push_back(backend->make_primary_tensor_view(element::f32, r->get_shape()));
    }
    copy_data(data_tensor, data_val);
    copy_data(weights_tensor, weights_val);
    copy_data(bias_tensor, bias_val);
    cf->call(result_tensors, {data_tensor, weights_tensor, bias_tensor});
    return result_tensors;
}
// Verifies that the RNN matrix fusion pass is value-preserving: the fused
// graph must produce the same per-time-step outputs as the unfused graph.
TEST(cpu_fusion, rnn_matrix_fusion_eval_pass)
{
    const size_t time_steps = 4;
    Shape data_shape{3, time_steps, 5};
    Shape weights_shape{6, data_shape[2]};
    Shape bias_shape{6};

    // Deterministic random inputs shared by both evaluations.
    test::Uniform<float> rng{0, 1, 0};
    vector<float> data_val(shape_size(data_shape));
    vector<float> weights_val(shape_size(weights_shape));
    vector<float> bias_val(shape_size(bias_shape));
    rng.initialize(data_val);
    rng.initialize(weights_val);
    rng.initialize(bias_val);

    // Evaluate once without the fusion pass (reference) and once with it.
    std::vector<shared_ptr<runtime::TensorView>> unfused_results = rnn_matrix_fusion_eval(
        time_steps, data_shape, weights_shape, bias_shape, data_val, weights_val, bias_val, false);
    std::vector<shared_ptr<runtime::TensorView>> fused_results = rnn_matrix_fusion_eval(
        time_steps, data_shape, weights_shape, bias_shape, data_val, weights_val, bias_val, true);

    const size_t num_results = unfused_results.size();
    for (size_t step = 0; step < num_results; ++step)
    {
        EXPECT_TRUE(test::all_close<float>(unfused_results[step], fused_results[step]));
    }
}
......@@ -45,12 +45,18 @@ namespace ngraph
initialize(const std::shared_ptr<runtime::TensorView>& ptv)
{
std::vector<T> vec = read_vector<T>(ptv);
initialize(vec);
write_vector(ptv, vec);
return ptv;
}
/// @brief Randomly initialize a vector
/// @param vec The tensor to initialize
void initialize(std::vector<T>& vec)
{
for (T& elt : vec)
{
elt = m_r();
}
write_vector(ptv, vec);
return ptv;
}
protected:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment