Commit b77fd922 authored by Nick Korovaiko, committed by Scott Cyphers

Reshape SoftMax Reshape (#2188)

* reshape softmax reshape

* add new line

* add new line

* fix style errors
parent c8bc3edc
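
In brief: this commit adds a CoreFusion matcher that recognizes a Reshape → Softmax → Reshape chain in which both reshapes are pure transposes (dimshuffles) and the outer reshape restores the original input shape. The chain is replaced by a single Softmax applied directly to the untransposed input, with the softmax axes remapped through the outer reshape's input order. A unit test compares the fused graph against an unfused baseline on the INTERPRETER backend.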
@@ -36,6 +36,7 @@
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
@@ -622,3 +623,50 @@ void pass::CoreFusion::construct_optimized_strided_conv()
        make_shared<pattern::Matcher>(eltwise_conv, callback, "CoreFusion.OptimizedStridedConv");
    this->add_matcher(m);
}

void ngraph::pass::CoreFusion::construct_reshape_softmax_reshape()
{
    Shape input_shape{10, 20};
    AxisVector io{1, 0};
    auto input = make_shared<pattern::op::Label>(element::f32, input_shape);
    auto reshape1 = make_shared<op::Reshape>(input, io, Shape{20, 10});
    auto softmax = make_shared<op::Softmax>(reshape1, AxisSet{1});
    auto reshape2 = make_shared<op::Reshape>(softmax, io, input_shape);

    pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In a callback for construct_reshape_softmax_reshape against "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();
        auto reshape2_m = std::static_pointer_cast<op::Reshape>(m.get_match_root());
        auto softmax_m = std::static_pointer_cast<op::Softmax>(reshape2_m->get_argument(0));
        auto reshape1_m = std::static_pointer_cast<op::Reshape>(softmax_m->get_argument(0));
        auto input_m = m.get_pattern_map()[input];

        if (!reshape2_m->get_is_transpose() || !reshape1_m->get_is_transpose())
        {
            NGRAPH_DEBUG << "we expect reshape2 and reshape1 to both be dimshuffles";
            return false;
        }

        if (input_m->get_shape() != reshape2_m->get_shape())
        {
            NGRAPH_DEBUG << "input and reshape2 have different shapes";
            return false;
        }

        AxisSet new_axes;
        const auto& axis_order = reshape2_m->get_input_order();
        for (auto axis : softmax_m->get_axes())
        {
            new_axes.insert(axis_order.at(axis));
        }

        auto new_softmax = make_shared<op::Softmax>(input_m, new_axes);
        ngraph::replace_node(m.get_match_root(), new_softmax);
        return true;
    };

    auto m = make_shared<pattern::Matcher>(reshape2, callback, "CoreFusion.ReshapeSoftmaxReshape");
    this->add_matcher(m);
}
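
To make the axis bookkeeping in the callback concrete: in the 4-D case exercised by the test further down, reshape2's input order is {0, 3, 1, 2} and the matched softmax reduces over axes {1, 2, 3} of the transposed tensor, so the remap lands back on axes {1, 2, 3} of the original nchw input. Below is a minimal standalone sketch of that remapping step, using plain STL types in place of nGraph's AxisVector and AxisSet (illustration only, not part of this commit):

// Standalone sketch of the new_axes computation in the callback above,
// with the permutation and softmax axes taken from the 4-D test case.
#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

int main()
{
    // reshape2's input order (nhwc -> nchw in the test)
    const std::vector<std::size_t> axis_order{0, 3, 1, 2};
    // softmax axes in the transposed (nhwc) layout
    const std::set<std::size_t> softmax_axes{1, 2, 3};

    std::set<std::size_t> new_axes;
    for (auto axis : softmax_axes)
    {
        new_axes.insert(axis_order.at(axis)); // 1 -> 3, 2 -> 1, 3 -> 2
    }

    for (auto a : new_axes)
    {
        std::cout << a << ' '; // prints "1 2 3": C, H, W of the nchw input
    }
    std::cout << '\n';
}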
@@ -38,6 +38,7 @@ public:
        construct_sigmoid();
        construct_sigmoid_bprop();
        construct_optimized_strided_conv();
        construct_reshape_softmax_reshape();
    }
    void construct_relu();
    void construct_folded_batch_norm();
@@ -45,4 +46,5 @@ public:
    void construct_sigmoid();
    void construct_sigmoid_bprop();
    void construct_optimized_strided_conv();
    void construct_reshape_softmax_reshape();
};
@@ -27,6 +27,8 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
@@ -36,8 +38,10 @@
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/matcher.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;

@@ -138,3 +142,39 @@ TEST(core_fusion, sparsity_opt_56x56)
    ASSERT_EQ(t_eltwise_conv1->get_window_movement_strides(), stride_1);
    ASSERT_EQ(t_eltwise_conv2->get_window_movement_strides(), stride_1);
}

static std::shared_ptr<Function> generate_reshape_softmax_reshape()
{
    Shape shape_nchw{10, 20, 30, 40};
    Shape shape_nhwc{10, 30, 40, 20};
    AxisVector to_nhwc{0, 2, 3, 1};
    AxisVector to_nchw{0, 3, 1, 2};
    auto input = make_shared<op::Parameter>(element::f32, shape_nchw);
    auto reshape1 = make_shared<op::Reshape>(input, to_nhwc, shape_nhwc);
    auto softmax = make_shared<op::Softmax>(reshape1, AxisSet{1, 2, 3});
    auto reshape2 = make_shared<op::Reshape>(softmax, to_nchw, shape_nchw);
    auto f = make_shared<Function>(reshape2, ParameterVector{input});
    return f;
}
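
As a sanity check on the helper's permutations (an illustrative aside, not part of the commit): to_nchw is the inverse of to_nhwc, so reshape2 exactly undoes reshape1's transpose and the pattern's output shape matches the input shape, which is the guard the fusion callback checks.

// Illustrative check that to_nchw inverts to_nhwc: composing the two
// permutations round-trips every axis index.
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const std::vector<std::size_t> to_nhwc{0, 2, 3, 1};
    const std::vector<std::size_t> to_nchw{0, 3, 1, 2};
    for (std::size_t i = 0; i < to_nhwc.size(); ++i)
    {
        assert(to_nhwc.at(to_nchw.at(i)) == i);
    }
}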

TEST(core_fusion, reshape_softmax_reshape)
{
    auto baseline_f = generate_reshape_softmax_reshape();
    auto optimized_f = generate_reshape_softmax_reshape();
    auto baseline_input = baseline_f->get_parameters().at(0);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::CoreFusion>();
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input->get_shape()));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
}
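
Beyond value equality, the test could also assert that the fusion actually fired. A hedged sketch of two extra assertions for the TEST body above, assuming the count_ops_of_type helper from util/test_tools.hpp:

    // After fusion, the optimized graph should keep one Softmax and drop both Reshapes.
    ASSERT_EQ(count_ops_of_type<op::Reshape>(optimized_f), 0);
    ASSERT_EQ(count_ops_of_type<op::Softmax>(optimized_f), 1);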