Commit ef778693 authored by Louis Feng's avatar Louis Feng Committed by Sang Ik Lee

Addes backprop to BatchDot op, allows fusion in training. (#2297)

* batch dot bprop WIP.

* WIP.

* testing.

* clean up debug code.

* comments and var name change.

* clean up.

* format style, batch dot differentiable pass.

* removed debug output.

* added unit test to autodiff, refactored make_function -> make_function_from_file.

* fixed build warning.

* fixed gpu build error.

* clang format fix.

* all test_tools.cpp to find SERIALIZED_ZOO

* remove cmake redef.

* fix unused macro.

* making test cpu only.

* testing build var

* macro test

* verbose makefile test

* style fix

* verbose make

* test/util needs test/models.

* removed debug output.

* refactor fusion type.

* refactor fusion type.
parent ea8407de
......@@ -360,6 +360,7 @@ if (NGRAPH_UNIT_TEST_ENABLE)
add_subdirectory(test)
message(STATUS "unit tests enabled")
else()
add_subdirectory(test/models)
add_subdirectory(test/util)
message(STATUS "unit tests disabled")
endif()
......
......@@ -34,6 +34,14 @@ namespace ngraph
class NodePass;
class CallGraphPass;
class Manager;
enum FusionType
{
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
// i.e. implement `generate_adjoints`
DIFFERENTIABLE_FUSIONS = 0x1,
REGULAR_FUSIONS = 0x2,
ALL_FUSIONS = 0xFFFFFFFF
};
}
}
......
......@@ -16,6 +16,7 @@
#include "batch_dot.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -69,3 +70,41 @@ op::BatchDot::BatchDot(shared_ptr<Node> a, shared_ptr<Node> b, bool transpose_a,
set_output_type(0, a->get_element_type(), dot_shape);
}
void op::BatchDot::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0); // NxIxK
auto a = get_inputs().at(0).get_output().get_node(); // NxIxJ (maybe transposed)
auto b = get_inputs().at(1).get_output().get_node(); // NxJxK (maybe transposed)
auto batch_transpose = [](const shared_ptr<Node>& node) {
const auto& batch_shape = node->get_shape();
// index 0 is the batch, only transposing the others.
AxisVector input_order{0, 2, 1};
Shape output_shape{batch_shape[0], batch_shape[2], batch_shape[1]};
return make_shared<op::Reshape>(node, input_order, output_shape);
};
// if b is already transposed, it does not need to be transposed again
auto delta_dot_b = make_shared<op::BatchDot>(delta, b, false, !m_transpose_b); // IK.KJ->IJ
// if a is transposed, the result need to be transposed to match original a shape.
if (m_transpose_a)
{
adjoints.add_delta(a, batch_transpose(delta_dot_b));
}
else
{
adjoints.add_delta(a, delta_dot_b);
}
auto a_dot_delta = make_shared<BatchDot>(a, delta, !m_transpose_a, false); // JI.IK->JK
if (m_transpose_b)
{
adjoints.add_delta(b, batch_transpose(a_dot_delta));
}
else
{
adjoints.add_delta(b, a_dot_delta);
}
}
......@@ -37,6 +37,9 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
bool m_transpose_a;
bool m_transpose_b;
......
......@@ -53,7 +53,8 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
protected:
Strides m_window_movement_strides;
......
......@@ -49,7 +49,8 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
protected:
Strides m_window_movement_strides;
......
......@@ -35,27 +35,16 @@ namespace ngraph
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
// 30 different fusion groups that we can nest/mix&match/etc
// should be good enough for quite a while
enum fusions
{
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
// i.e. implement `generate_adjoints`
DIFFERENTIABLE_FUSIONS = 0x1,
REGULAR_FUSIONS = 0x2,
ALL = 0xFFFFFFFF
};
CPUFusion(int fusions = ALL)
CPUFusion(ngraph::pass::FusionType fusions = ngraph::pass::ALL_FUSIONS)
: GraphRewrite()
{
if (fusions & DIFFERENTIABLE_FUSIONS)
if (fusions & ngraph::pass::DIFFERENTIABLE_FUSIONS)
{
construct_conv_bias();
construct_sigmoid_multiply();
}
if (fusions & REGULAR_FUSIONS)
if (fusions & ngraph::pass::REGULAR_FUSIONS)
{
construct_matmul();
construct_matmulbias();
......
......@@ -578,18 +578,23 @@ bool runtime::cpu::pass::CPUBatchFusion::run_on_function(std::shared_ptr<Functio
const Node& node = *n;
if (TI(node) == TI(op::Concat))
{
auto fused_node = fuse_batch_dot(n);
if (fused_node)
if (m_fusion_type & ngraph::pass::DIFFERENTIABLE_FUSIONS)
{
if (auto fused_node = fuse_batch_dot(n))
{
func->replace_node(n, fused_node);
modified = true;
}
else if (auto fused_conv = fuse_group_convolution(n))
}
if (m_fusion_type & ngraph::pass::REGULAR_FUSIONS)
{
if (auto fused_conv = fuse_group_convolution(n))
{
func->replace_node(n, fused_conv);
modified = true;
}
}
}
}
return modified;
}
......@@ -29,12 +29,22 @@ namespace ngraph
class CPURnnMatFusion : public ngraph::pass::FunctionPass
{
public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
virtual bool
run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
class CPUBatchFusion : public ngraph::pass::FunctionPass
{
public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
CPUBatchFusion(ngraph::pass::FusionType type = ngraph::pass::ALL_FUSIONS)
: FunctionPass()
, m_fusion_type(type)
{
}
virtual bool
run_on_function(std::shared_ptr<ngraph::Function> function) override;
private:
ngraph::pass::FusionType m_fusion_type;
};
}
}
......
......@@ -21,13 +21,18 @@
#include "gtest/gtest.h"
// clang-format off
#define AUTODIFF_BACKEND_${BACKEND_NAME}
// clang-format on
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/cpu/pass/cpu_mat_fusion.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
using namespace std;
using namespace ngraph;
......@@ -847,6 +852,30 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_dot_tensor3_tensor3)
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), make_graph, {x0, x1}, .01f, .01f));
}
#ifdef AUTODIFF_BACKEND_CPU
NGRAPH_TEST(${BACKEND_NAME}, backwards_batchdot_tensor2_tensor2)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
std::string backend_name = "${BACKEND_NAME}";
const std::string file_name("mxnet/batch_dot_3.json");
auto f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f);
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> args;
for (shared_ptr<op::Parameter> param : f->get_parameters())
{
args.push_back(rng.initialize(backend->create_tensor<float>(param->get_shape())));
}
auto g = make_function_from_file(file_name);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.run_passes(g);
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), f, g, args, .01f, .01f));
}
#endif
NGRAPH_TEST(${BACKEND_NAME}, backwards_exp)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
......@@ -1760,3 +1789,9 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n4d2c3h2w2)
backend->call_with_validate(handle, {da, db}, {a, b, c});
ASSERT_EQ(read_vector<int>(da), expected);
}
// clang-format off
#ifdef AUTODIFF_BACKEND_${BACKEND_NAME}
#undef AUTODIFF_BACKEND_${BACKEND_NAME}
#endif
// clang-format on
......@@ -278,8 +278,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, ParameterVector{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
......@@ -300,8 +299,7 @@ TEST(cpu_fusion, commutative_matmul_bias)
auto add = broadcast + dot;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, ParameterVector{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
......@@ -323,8 +321,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, ParameterVector{W, x, b});
pass_manager.run_passes(func);
auto gmm = graph->get_argument(0);
......@@ -345,8 +342,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias)
auto graph = make_shared<op::Abs>(re_dot);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, ParameterVector{W, x});
pass_manager.run_passes(func);
size_t mmb = count_ops_of_type<op::MatmulBias>(func);
......@@ -360,8 +356,7 @@ TEST(cpu_fusion, gemm_mlp)
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
pass_manager.run_passes(func);
auto mmbs = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmbs, 3);
......@@ -372,8 +367,7 @@ TEST(cpu_fusion, fuse_fprop_bn)
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_before_fusion.png");
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_after_fusion.png");
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop_b2c3h2w2.json");
const string json_string = file_util::read_file_to_string(json_path);
......@@ -503,8 +497,7 @@ TEST(cpu_fusion, fuse_conv_bias)
{
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::DIFFERENTIABLE_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::DIFFERENTIABLE_FUSIONS);
const string json_path = file_util::path_join(SERIALIZED_ZOO, "conv_bias.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
......@@ -851,8 +844,7 @@ TEST(cpu_fusion, fuse_conv_relu)
auto func = make_shared<Function>(abs_node, ParameterVector{A, weights});
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
pass_manager.run_passes(func);
size_t cb = count_ops_of_type<op::ConvolutionRelu>(func);
ASSERT_GT(cb, 0);
......@@ -1280,8 +1272,7 @@ std::vector<shared_ptr<runtime::Tensor>> rnn_matrix_fusion_eval(const size_t tim
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
pass_manager.run_passes(func);
// check all of our dot/add are converted to a single MatmulBias op.
size_t count = count_ops_of_type<op::MatmulBias>(func);
......@@ -1340,8 +1331,7 @@ TEST(cpu_fusion, rnn_fusion_from_json_model)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json");
const string json_string = file_util::read_file_to_string(json_path);
......@@ -2301,20 +2291,11 @@ TEST(cpu_fusion, fuse_1_layer_rnn)
}
}
static std::shared_ptr<Function> make_function(const std::string& file_name)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
return func;
}
TEST(cpu_fusion, rnn_fusion_1lstm_cell)
{
const std::string file_name("mxnet/1_lstm_cell_forward.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args;
......@@ -2335,8 +2316,8 @@ TEST(cpu_fusion, rnn_fusion_1lstm_cell)
TEST(cpu_fusion, rnn_fusion_1rnn_layer_3lstm_cell)
{
const std::string file_name("mxnet/1rnn_layer_3lstm_cell.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args;
......@@ -2357,8 +2338,8 @@ TEST(cpu_fusion, rnn_fusion_1rnn_layer_3lstm_cell)
TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell)
{
const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args;
......@@ -3016,8 +2997,8 @@ TEST(cpu_fusion, fuse_batch_dot_forward)
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
const std::string file_name("mxnet/batch_dot_3.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
pass_manager.run_passes(cpu_f);
test::Uniform<float> rng(0.0f, 1.0f);
vector<vector<float>> args;
......@@ -3036,11 +3017,42 @@ TEST(cpu_fusion, fuse_batch_dot_forward)
}
}
TEST(cpu_fusion, fuse_batch_dot_backward)
{
const std::string file_name("mxnet/batch_dot_3.json");
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.run_passes(cpu_f);
auto int_df = autodiff::backprop_function(int_f);
auto cpu_df = autodiff::backprop_function(cpu_f);
test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : cpu_df->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_df, args, "INTERPRETER");
auto cpu_results = execute(cpu_df, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, fuse_rnn_across_layer_2layer_3timestep)
{
const std::string file_name("mxnet/2layer_3timestep_ic100oc100.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args;
......
......@@ -183,15 +183,6 @@ TEST(DISABLED_gpu_fusion, fuse_1_layer_rnn)
}
}
static std::shared_ptr<Function> make_function(const std::string& file_name)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
return func;
}
TEST(gpu_fusion, lstm_analytic)
{
auto input_xt = std::make_shared<op::Parameter>(element::f32, Shape{1, 1});
......@@ -432,8 +423,8 @@ TEST(gpu_fusion, fuse_2_layer_rnn_1lstm_analytic)
TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1lstm_cell)
{
const std::string file_name("mxnet/1_lstm_cell_forward.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
auto gpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
......@@ -454,8 +445,8 @@ TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1lstm_cell)
TEST(DISABLED_gpu_fusion, rnn_fusion_inter_vs_gpu_1rnn_layer_3lstm_cell)
{
const std::string file_name("mxnet/1rnn_layer_3lstm_cell.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
auto gpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
......@@ -476,8 +467,8 @@ TEST(DISABLED_gpu_fusion, rnn_fusion_inter_vs_gpu_1rnn_layer_3lstm_cell)
TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_2rnn_layer_3lstm_cell)
{
const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
auto gpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
......@@ -516,8 +507,8 @@ TEST(gpu_fusion, fuse_rnn_across_layer)
TEST(gpu_fusion, fuse_rnn_across_2layer_1timestep)
{
const std::string file_name("mxnet/2rnn_layer_1timestep.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
auto gpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
......
......@@ -301,3 +301,12 @@ string
return ss.str();
}
std::shared_ptr<Function> make_function_from_file(const std::string& file_name)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
return func;
}
......@@ -38,6 +38,7 @@ namespace ngraph
bool validate_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes);
std::shared_ptr<ngraph::Function> make_test_graph();
std::shared_ptr<ngraph::Function> make_function_from_file(const std::string& file_name);
template <typename T>
void copy_data(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment