Commit ef778693 authored by Louis Feng, committed by Sang Ik Lee

Adds backprop to the BatchDot op, allowing fusion in training. (#2297)

* batch dot bprop WIP.

* WIP.

* testing.

* clean up debug code.

* comments and var name change.

* clean up.

* format style, batch dot differentiable pass.

* removed debug output.

* added unit test to autodiff, refactored make_function -> make_function_from_file.

* fixed build warning.

* fixed gpu build error.

* clang format fix.

* allow test_tools.cpp to find SERIALIZED_ZOO

* remove cmake redef.

* fix unused macro.

* making test cpu only.

* testing build var

* macro test

* verbose makefile test

* style fix

* verbose make

* test/util needs test/models.

* removed debug output.

* refactor fusion type.

* refactor fusion type.
parent ea8407de
...@@ -360,6 +360,7 @@ if (NGRAPH_UNIT_TEST_ENABLE) ...@@ -360,6 +360,7 @@ if (NGRAPH_UNIT_TEST_ENABLE)
add_subdirectory(test) add_subdirectory(test)
message(STATUS "unit tests enabled") message(STATUS "unit tests enabled")
else() else()
add_subdirectory(test/models)
add_subdirectory(test/util) add_subdirectory(test/util)
message(STATUS "unit tests disabled") message(STATUS "unit tests disabled")
endif() endif()
......
...@@ -34,6 +34,14 @@ namespace ngraph ...@@ -34,6 +34,14 @@ namespace ngraph
class NodePass; class NodePass;
class CallGraphPass; class CallGraphPass;
class Manager; class Manager;
enum FusionType
{
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
// i.e. implement `generate_adjoints`
DIFFERENTIABLE_FUSIONS = 0x1,
REGULAR_FUSIONS = 0x2,
ALL_FUSIONS = 0xFFFFFFFF
};
} }
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "batch_dot.hpp" #include "batch_dot.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -69,3 +70,41 @@ op::BatchDot::BatchDot(shared_ptr<Node> a, shared_ptr<Node> b, bool transpose_a, ...@@ -69,3 +70,41 @@ op::BatchDot::BatchDot(shared_ptr<Node> a, shared_ptr<Node> b, bool transpose_a,
set_output_type(0, a->get_element_type(), dot_shape); set_output_type(0, a->get_element_type(), dot_shape);
} }
void op::BatchDot::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0); // NxIxK
auto a = get_inputs().at(0).get_output().get_node(); // NxIxJ (maybe transposed)
auto b = get_inputs().at(1).get_output().get_node(); // NxJxK (maybe transposed)
auto batch_transpose = [](const shared_ptr<Node>& node) {
const auto& batch_shape = node->get_shape();
// index 0 is the batch, only transposing the others.
AxisVector input_order{0, 2, 1};
Shape output_shape{batch_shape[0], batch_shape[2], batch_shape[1]};
return make_shared<op::Reshape>(node, input_order, output_shape);
};
// if b is already transposed, it does not need to be transposed again
auto delta_dot_b = make_shared<op::BatchDot>(delta, b, false, !m_transpose_b); // IK.KJ->IJ
// if a is transposed, the result need to be transposed to match original a shape.
if (m_transpose_a)
{
adjoints.add_delta(a, batch_transpose(delta_dot_b));
}
else
{
adjoints.add_delta(a, delta_dot_b);
}
auto a_dot_delta = make_shared<BatchDot>(a, delta, !m_transpose_a, false); // JI.IK->JK
if (m_transpose_b)
{
adjoints.add_delta(b, batch_transpose(a_dot_delta));
}
else
{
adjoints.add_delta(b, a_dot_delta);
}
}
...@@ -37,6 +37,9 @@ namespace ngraph ...@@ -37,6 +37,9 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private: private:
bool m_transpose_a; bool m_transpose_a;
bool m_transpose_b; bool m_transpose_b;
......
...@@ -53,7 +53,8 @@ namespace ngraph ...@@ -53,7 +53,8 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override; virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
protected: protected:
Strides m_window_movement_strides; Strides m_window_movement_strides;
......
...@@ -49,7 +49,8 @@ namespace ngraph ...@@ -49,7 +49,8 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override; virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
protected: protected:
Strides m_window_movement_strides; Strides m_window_movement_strides;
......
...@@ -35,27 +35,16 @@ namespace ngraph ...@@ -35,27 +35,16 @@ namespace ngraph
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{ {
public: public:
// 30 different fusion groups that we can nest/mix&match/etc CPUFusion(ngraph::pass::FusionType fusions = ngraph::pass::ALL_FUSIONS)
// should be good enough for quite a while
enum fusions
{
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
// i.e. implement `generate_adjoints`
DIFFERENTIABLE_FUSIONS = 0x1,
REGULAR_FUSIONS = 0x2,
ALL = 0xFFFFFFFF
};
CPUFusion(int fusions = ALL)
: GraphRewrite() : GraphRewrite()
{ {
if (fusions & DIFFERENTIABLE_FUSIONS) if (fusions & ngraph::pass::DIFFERENTIABLE_FUSIONS)
{ {
construct_conv_bias(); construct_conv_bias();
construct_sigmoid_multiply(); construct_sigmoid_multiply();
} }
if (fusions & REGULAR_FUSIONS) if (fusions & ngraph::pass::REGULAR_FUSIONS)
{ {
construct_matmul(); construct_matmul();
construct_matmulbias(); construct_matmulbias();
......
...@@ -578,18 +578,23 @@ bool runtime::cpu::pass::CPUBatchFusion::run_on_function(std::shared_ptr<Functio ...@@ -578,18 +578,23 @@ bool runtime::cpu::pass::CPUBatchFusion::run_on_function(std::shared_ptr<Functio
const Node& node = *n; const Node& node = *n;
if (TI(node) == TI(op::Concat)) if (TI(node) == TI(op::Concat))
{ {
auto fused_node = fuse_batch_dot(n); if (m_fusion_type & ngraph::pass::DIFFERENTIABLE_FUSIONS)
if (fused_node) {
if (auto fused_node = fuse_batch_dot(n))
{ {
func->replace_node(n, fused_node); func->replace_node(n, fused_node);
modified = true; modified = true;
} }
else if (auto fused_conv = fuse_group_convolution(n)) }
if (m_fusion_type & ngraph::pass::REGULAR_FUSIONS)
{
if (auto fused_conv = fuse_group_convolution(n))
{ {
func->replace_node(n, fused_conv); func->replace_node(n, fused_conv);
modified = true; modified = true;
} }
} }
} }
}
return modified; return modified;
} }
...@@ -29,12 +29,22 @@ namespace ngraph ...@@ -29,12 +29,22 @@ namespace ngraph
class CPURnnMatFusion : public ngraph::pass::FunctionPass class CPURnnMatFusion : public ngraph::pass::FunctionPass
{ {
public: public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override; virtual bool
run_on_function(std::shared_ptr<ngraph::Function> function) override;
}; };
class CPUBatchFusion : public ngraph::pass::FunctionPass class CPUBatchFusion : public ngraph::pass::FunctionPass
{ {
public: public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override; CPUBatchFusion(ngraph::pass::FusionType type = ngraph::pass::ALL_FUSIONS)
: FunctionPass()
, m_fusion_type(type)
{
}
virtual bool
run_on_function(std::shared_ptr<ngraph::Function> function) override;
private:
ngraph::pass::FusionType m_fusion_type;
}; };
} }
} }
......
...@@ -21,13 +21,18 @@ ...@@ -21,13 +21,18 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
// clang-format off
#define AUTODIFF_BACKEND_${BACKEND_NAME}
// clang-format on
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/cpu/pass/cpu_mat_fusion.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp" #include "ngraph/runtime/reference/avg_pool.hpp"
#include "util/autodiff/backprop_function.hpp" #include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_compare.hpp" #include "util/autodiff/numeric_compare.hpp"
#include "util/random.hpp" #include "util/random.hpp"
#include "util/test_control.hpp" #include "util/test_control.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -847,6 +852,30 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_dot_tensor3_tensor3) ...@@ -847,6 +852,30 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_dot_tensor3_tensor3)
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), make_graph, {x0, x1}, .01f, .01f)); EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), make_graph, {x0, x1}, .01f, .01f));
} }
#ifdef AUTODIFF_BACKEND_CPU
NGRAPH_TEST(${BACKEND_NAME}, backwards_batchdot_tensor2_tensor2)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
std::string backend_name = "${BACKEND_NAME}";
const std::string file_name("mxnet/batch_dot_3.json");
auto f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f);
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> args;
for (shared_ptr<op::Parameter> param : f->get_parameters())
{
args.push_back(rng.initialize(backend->create_tensor<float>(param->get_shape())));
}
auto g = make_function_from_file(file_name);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.run_passes(g);
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), f, g, args, .01f, .01f));
}
#endif
NGRAPH_TEST(${BACKEND_NAME}, backwards_exp) NGRAPH_TEST(${BACKEND_NAME}, backwards_exp)
{ {
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
...@@ -1760,3 +1789,9 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n4d2c3h2w2) ...@@ -1760,3 +1789,9 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n4d2c3h2w2)
backend->call_with_validate(handle, {da, db}, {a, b, c}); backend->call_with_validate(handle, {da, db}, {a, b, c});
ASSERT_EQ(read_vector<int>(da), expected); ASSERT_EQ(read_vector<int>(da), expected);
} }
// clang-format off
#ifdef AUTODIFF_BACKEND_${BACKEND_NAME}
#undef AUTODIFF_BACKEND_${BACKEND_NAME}
#endif
// clang-format on
...@@ -278,8 +278,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic) ...@@ -278,8 +278,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
auto add = dot + broadcast; auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add); auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, ParameterVector{A, B, C}); auto func = make_shared<Function>(graph, ParameterVector{A, B, C});
pass_manager.run_passes(func); pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr); ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
...@@ -300,8 +299,7 @@ TEST(cpu_fusion, commutative_matmul_bias) ...@@ -300,8 +299,7 @@ TEST(cpu_fusion, commutative_matmul_bias)
auto add = broadcast + dot; auto add = broadcast + dot;
auto graph = make_shared<op::Abs>(add); auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, ParameterVector{A, B, C}); auto func = make_shared<Function>(graph, ParameterVector{A, B, C});
pass_manager.run_passes(func); pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr); ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
...@@ -323,8 +321,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_bias) ...@@ -323,8 +321,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
auto graph = make_shared<op::Abs>(add); auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, ParameterVector{W, x, b}); auto func = make_shared<Function>(graph, ParameterVector{W, x, b});
pass_manager.run_passes(func); pass_manager.run_passes(func);
auto gmm = graph->get_argument(0); auto gmm = graph->get_argument(0);
...@@ -345,8 +342,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias) ...@@ -345,8 +342,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias)
auto graph = make_shared<op::Abs>(re_dot); auto graph = make_shared<op::Abs>(re_dot);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, ParameterVector{W, x}); auto func = make_shared<Function>(graph, ParameterVector{W, x});
pass_manager.run_passes(func); pass_manager.run_passes(func);
size_t mmb = count_ops_of_type<op::MatmulBias>(func); size_t mmb = count_ops_of_type<op::MatmulBias>(func);
...@@ -360,8 +356,7 @@ TEST(cpu_fusion, gemm_mlp) ...@@ -360,8 +356,7 @@ TEST(cpu_fusion, gemm_mlp)
stringstream ss(json_string); stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss); shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.run_passes(func); pass_manager.run_passes(func);
auto mmbs = count_ops_of_type<op::MatmulBias>(func); auto mmbs = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmbs, 3); ASSERT_EQ(mmbs, 3);
...@@ -372,8 +367,7 @@ TEST(cpu_fusion, fuse_fprop_bn) ...@@ -372,8 +367,7 @@ TEST(cpu_fusion, fuse_fprop_bn)
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_before_fusion.png"); pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_before_fusion.png");
pass_manager.register_pass<ngraph::pass::ReshapeElimination>(); pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_after_fusion.png"); pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_after_fusion.png");
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop_b2c3h2w2.json"); const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop_b2c3h2w2.json");
const string json_string = file_util::read_file_to_string(json_path); const string json_string = file_util::read_file_to_string(json_path);
...@@ -503,8 +497,7 @@ TEST(cpu_fusion, fuse_conv_bias) ...@@ -503,8 +497,7 @@ TEST(cpu_fusion, fuse_conv_bias)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ReshapeElimination>(); pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::DIFFERENTIABLE_FUSIONS);
runtime::cpu::pass::CPUFusion::DIFFERENTIABLE_FUSIONS);
const string json_path = file_util::path_join(SERIALIZED_ZOO, "conv_bias.json"); const string json_path = file_util::path_join(SERIALIZED_ZOO, "conv_bias.json");
const string json_string = file_util::read_file_to_string(json_path); const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string); stringstream ss(json_string);
...@@ -851,8 +844,7 @@ TEST(cpu_fusion, fuse_conv_relu) ...@@ -851,8 +844,7 @@ TEST(cpu_fusion, fuse_conv_relu)
auto func = make_shared<Function>(abs_node, ParameterVector{A, weights}); auto func = make_shared<Function>(abs_node, ParameterVector{A, weights});
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.run_passes(func); pass_manager.run_passes(func);
size_t cb = count_ops_of_type<op::ConvolutionRelu>(func); size_t cb = count_ops_of_type<op::ConvolutionRelu>(func);
ASSERT_GT(cb, 0); ASSERT_GT(cb, 0);
...@@ -1280,8 +1272,7 @@ std::vector<shared_ptr<runtime::Tensor>> rnn_matrix_fusion_eval(const size_t tim ...@@ -1280,8 +1272,7 @@ std::vector<shared_ptr<runtime::Tensor>> rnn_matrix_fusion_eval(const size_t tim
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.run_passes(func); pass_manager.run_passes(func);
// check all of our dot/add are converted to a single MatmulBias op. // check all of our dot/add are converted to a single MatmulBias op.
size_t count = count_ops_of_type<op::MatmulBias>(func); size_t count = count_ops_of_type<op::MatmulBias>(func);
...@@ -1340,8 +1331,7 @@ TEST(cpu_fusion, rnn_fusion_from_json_model) ...@@ -1340,8 +1331,7 @@ TEST(cpu_fusion, rnn_fusion_from_json_model)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>( pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
const string json_path = const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json"); file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json");
const string json_string = file_util::read_file_to_string(json_path); const string json_string = file_util::read_file_to_string(json_path);
...@@ -2301,20 +2291,11 @@ TEST(cpu_fusion, fuse_1_layer_rnn) ...@@ -2301,20 +2291,11 @@ TEST(cpu_fusion, fuse_1_layer_rnn)
} }
} }
static std::shared_ptr<Function> make_function(const std::string& file_name)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
return func;
}
TEST(cpu_fusion, rnn_fusion_1lstm_cell) TEST(cpu_fusion, rnn_fusion_1lstm_cell)
{ {
const std::string file_name("mxnet/1_lstm_cell_forward.json"); const std::string file_name("mxnet/1_lstm_cell_forward.json");
auto cpu_f = make_function(file_name); auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function(file_name); auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f); test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args; vector<vector<float>> args;
...@@ -2335,8 +2316,8 @@ TEST(cpu_fusion, rnn_fusion_1lstm_cell) ...@@ -2335,8 +2316,8 @@ TEST(cpu_fusion, rnn_fusion_1lstm_cell)
TEST(cpu_fusion, rnn_fusion_1rnn_layer_3lstm_cell) TEST(cpu_fusion, rnn_fusion_1rnn_layer_3lstm_cell)
{ {
const std::string file_name("mxnet/1rnn_layer_3lstm_cell.json"); const std::string file_name("mxnet/1rnn_layer_3lstm_cell.json");
auto cpu_f = make_function(file_name); auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function(file_name); auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f); test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args; vector<vector<float>> args;
...@@ -2357,8 +2338,8 @@ TEST(cpu_fusion, rnn_fusion_1rnn_layer_3lstm_cell) ...@@ -2357,8 +2338,8 @@ TEST(cpu_fusion, rnn_fusion_1rnn_layer_3lstm_cell)
TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell) TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell)
{ {
const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json"); const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json");
auto cpu_f = make_function(file_name); auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function(file_name); auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f); test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args; vector<vector<float>> args;
...@@ -3016,8 +2997,8 @@ TEST(cpu_fusion, fuse_batch_dot_forward) ...@@ -3016,8 +2997,8 @@ TEST(cpu_fusion, fuse_batch_dot_forward)
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
const std::string file_name("mxnet/batch_dot_3.json"); const std::string file_name("mxnet/batch_dot_3.json");
auto cpu_f = make_function(file_name); auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function(file_name); auto int_f = make_function_from_file(file_name);
pass_manager.run_passes(cpu_f); pass_manager.run_passes(cpu_f);
test::Uniform<float> rng(0.0f, 1.0f); test::Uniform<float> rng(0.0f, 1.0f);
vector<vector<float>> args; vector<vector<float>> args;
...@@ -3036,11 +3017,42 @@ TEST(cpu_fusion, fuse_batch_dot_forward) ...@@ -3036,11 +3017,42 @@ TEST(cpu_fusion, fuse_batch_dot_forward)
} }
} }
TEST(cpu_fusion, fuse_batch_dot_backward)
{
const std::string file_name("mxnet/batch_dot_3.json");
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.run_passes(cpu_f);
auto int_df = autodiff::backprop_function(int_f);
auto cpu_df = autodiff::backprop_function(cpu_f);
test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : cpu_df->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_df, args, "INTERPRETER");
auto cpu_results = execute(cpu_df, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, fuse_rnn_across_layer_2layer_3timestep) TEST(cpu_fusion, fuse_rnn_across_layer_2layer_3timestep)
{ {
const std::string file_name("mxnet/2layer_3timestep_ic100oc100.json"); const std::string file_name("mxnet/2layer_3timestep_ic100oc100.json");
auto cpu_f = make_function(file_name); auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function(file_name); auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f); test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args; vector<vector<float>> args;
......
...@@ -183,15 +183,6 @@ TEST(DISABLED_gpu_fusion, fuse_1_layer_rnn) ...@@ -183,15 +183,6 @@ TEST(DISABLED_gpu_fusion, fuse_1_layer_rnn)
} }
} }
static std::shared_ptr<Function> make_function(const std::string& file_name)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
return func;
}
TEST(gpu_fusion, lstm_analytic) TEST(gpu_fusion, lstm_analytic)
{ {
auto input_xt = std::make_shared<op::Parameter>(element::f32, Shape{1, 1}); auto input_xt = std::make_shared<op::Parameter>(element::f32, Shape{1, 1});
...@@ -432,8 +423,8 @@ TEST(gpu_fusion, fuse_2_layer_rnn_1lstm_analytic) ...@@ -432,8 +423,8 @@ TEST(gpu_fusion, fuse_2_layer_rnn_1lstm_analytic)
TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1lstm_cell) TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1lstm_cell)
{ {
const std::string file_name("mxnet/1_lstm_cell_forward.json"); const std::string file_name("mxnet/1_lstm_cell_forward.json");
auto gpu_f = make_function(file_name); auto gpu_f = make_function_from_file(file_name);
auto int_f = make_function(file_name); auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-10.0f, 10.0f); test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args; vector<vector<float>> args;
...@@ -454,8 +445,8 @@ TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1lstm_cell) ...@@ -454,8 +445,8 @@ TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1lstm_cell)
TEST(DISABLED_gpu_fusion, rnn_fusion_inter_vs_gpu_1rnn_layer_3lstm_cell) TEST(DISABLED_gpu_fusion, rnn_fusion_inter_vs_gpu_1rnn_layer_3lstm_cell)
{ {
const std::string file_name("mxnet/1rnn_layer_3lstm_cell.json"); const std::string file_name("mxnet/1rnn_layer_3lstm_cell.json");
auto gpu_f = make_function(file_name); auto gpu_f = make_function_from_file(file_name);
auto int_f = make_function(file_name); auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-10.0f, 10.0f); test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args; vector<vector<float>> args;
...@@ -476,8 +467,8 @@ TEST(DISABLED_gpu_fusion, rnn_fusion_inter_vs_gpu_1rnn_layer_3lstm_cell) ...@@ -476,8 +467,8 @@ TEST(DISABLED_gpu_fusion, rnn_fusion_inter_vs_gpu_1rnn_layer_3lstm_cell)
TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_2rnn_layer_3lstm_cell) TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_2rnn_layer_3lstm_cell)
{ {
const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json"); const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json");
auto gpu_f = make_function(file_name); auto gpu_f = make_function_from_file(file_name);
auto int_f = make_function(file_name); auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-10.0f, 10.0f); test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args; vector<vector<float>> args;
...@@ -516,8 +507,8 @@ TEST(gpu_fusion, fuse_rnn_across_layer) ...@@ -516,8 +507,8 @@ TEST(gpu_fusion, fuse_rnn_across_layer)
TEST(gpu_fusion, fuse_rnn_across_2layer_1timestep) TEST(gpu_fusion, fuse_rnn_across_2layer_1timestep)
{ {
const std::string file_name("mxnet/2rnn_layer_1timestep.json"); const std::string file_name("mxnet/2rnn_layer_1timestep.json");
auto gpu_f = make_function(file_name); auto gpu_f = make_function_from_file(file_name);
auto int_f = make_function(file_name); auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-10.0f, 10.0f); test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args; vector<vector<float>> args;
......
...@@ -301,3 +301,12 @@ string ...@@ -301,3 +301,12 @@ string
return ss.str(); return ss.str();
} }
std::shared_ptr<Function> make_function_from_file(const std::string& file_name)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
return func;
}
...@@ -38,6 +38,7 @@ namespace ngraph ...@@ -38,6 +38,7 @@ namespace ngraph
bool validate_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes); bool validate_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes);
std::shared_ptr<ngraph::Function> make_test_graph(); std::shared_ptr<ngraph::Function> make_test_graph();
std::shared_ptr<ngraph::Function> make_function_from_file(const std::string& file_name);
template <typename T> template <typename T>
void copy_data(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& data) void copy_data(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment