Commit 6c676d2d authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Merge fixes

parent 2659d5be
......@@ -226,20 +226,13 @@ void runtime::cpu::CPU_ExternalFunction::compile()
string function_name = m_function->get_name();
<<<<<<< HEAD
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPULayout>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(MemoryPoolAlignment);
=======
pass::Manager pass_manager;
// For now, just make everyone row-major.
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(64);
>>>>>>> master
pass_manager.run_passes(m_function);
codegen::CodeWriter writer;
......
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "cpu_fusion.hpp"
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
std::shared_ptr<ngraph::Node> arg,
bool& transpose_w,
ngraph::Shape& shape_w)
{
auto r_w = std::dynamic_pointer_cast<ngraph::op::Reshape>(reshape);
if (!r_w)
{
return true; //nth to do; reshape isn't a reshape
}
if (r_w->get_shape().size() != 2)
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " doesn't reshape into matrix"
<< ngraph::vector_to_string(r_w->get_shape());
return false;
}
auto io = r_w->get_input_order();
if (r_w->get_shape().size() != arg->get_shape().size()) //reshape
{
ngraph::AxisVector dio(io.size());
std::iota(begin(dio), end(dio), 0);
if (io != dio) //we can't reshape and transpose at the same time
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order "
<< ngraph::vector_to_string(io);
NGRAPH_DEBUG << "r_w shape = " << ngraph::vector_to_string(r_w->get_shape());
NGRAPH_DEBUG << "arg shape = " << ngraph::vector_to_string(arg->get_shape());
return false;
}
shape_w = r_w->get_shape();
}
else
{
if (io == ngraph::AxisVector{1, 0})
{
transpose_w = true;
}
//otherwise no-op reshape
}
return true;
}
template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
void ngraph::pass::CPUFusion::construct_gemm_pattern()
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto shape_dot = Shape{2, 1};
auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};
auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);
auto pbroadcast = std::make_shared<op::Broadcast>(b, shape_dot, AxisSet{0});
auto padd = pdot + pbroadcast;
ngraph::pattern::gr_callback_fn callback = [W, x, b](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_gemm_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn = nullptr;
auto mpattern = m.match_root();
if (mpattern->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
return nn;
}
auto dot = mpattern->get_input_op(0);
if (dot->get_shape().size() != 2)
{
NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!";
return nn;
}
bool transpose_w = false;
Shape shape_arg0{pattern_map[W]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
{
return nn;
}
bool transpose_x = false;
Shape shape_arg1{pattern_map[x]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1))
{
return nn;
}
auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
pattern_map[x],
mpattern->get_input_op(1),
shape_arg0,
shape_arg1,
transpose_w,
transpose_x));
return cg;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
this->add_matcher(m);
}
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "cpu_fusion.hpp"
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
std::shared_ptr<ngraph::Node> arg,
bool& transpose_w,
ngraph::Shape& shape_w)
{
auto r_w = std::dynamic_pointer_cast<ngraph::op::Reshape>(reshape);
if (!r_w)
{
return true; //nth to do; reshape isn't a reshape
}
if (r_w->get_shape().size() != 2)
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " doesn't reshape into matrix"
<< ngraph::vector_to_string(r_w->get_shape());
return false;
}
auto io = r_w->get_input_order();
if (r_w->get_shape().size() != arg->get_shape().size()) //reshape
{
ngraph::AxisVector dio(io.size());
std::iota(begin(dio), end(dio), 0);
if (io != dio) //we can't reshape and transpose at the same time
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order "
<< ngraph::vector_to_string(io);
NGRAPH_DEBUG << "r_w shape = " << ngraph::vector_to_string(r_w->get_shape());
NGRAPH_DEBUG << "arg shape = " << ngraph::vector_to_string(arg->get_shape());
return false;
}
shape_w = r_w->get_shape();
}
else
{
if (io == ngraph::AxisVector{1, 0})
{
transpose_w = true;
}
//otherwise no-op reshape
}
return true;
}
template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto shape_dot = Shape{2, 1};
auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};
auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);
auto pbroadcast = std::make_shared<op::Broadcast>(b, shape_dot, AxisSet{0});
auto padd = pdot + pbroadcast;
ngraph::pattern::gr_callback_fn callback = [W, x, b](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_gemm_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn = nullptr;
auto mpattern = m.match_root();
if (mpattern->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
return nn;
}
auto dot = mpattern->get_input_op(0);
if (dot->get_shape().size() != 2)
{
NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!";
return nn;
}
bool transpose_w = false;
Shape shape_arg0{pattern_map[W]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
{
return nn;
}
bool transpose_x = false;
Shape shape_arg1{pattern_map[x]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1))
{
return nn;
}
auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
pattern_map[x],
mpattern->get_input_op(1),
shape_arg0,
shape_arg1,
transpose_w,
transpose_x));
return cg;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
this->add_matcher(m);
}
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace pass
{
class CPUFusion;
}
}
class ngraph::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
: GraphRewrite()
{
construct_gemm_pattern();
}
private:
void construct_gemm_pattern();
};
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPUFusion;
}
}
}
}
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
: GraphRewrite()
{
construct_gemm_pattern();
}
private:
void construct_gemm_pattern();
};
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
//
#include "ngraph/file_util.hpp"
#include "ngraph/json.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(cpu_fusion, gemm_pattern)
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto W = std::make_shared<pattern::op::Label>(A);
auto x = std::make_shared<pattern::op::Label>(B);
auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};
auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(C);
auto pbroadcast = make_shared<op::Broadcast>(b, dot->get_shape(), AxisSet{0});
auto padd = pdot + pbroadcast;
TestMatcher n(nullptr);
ASSERT_TRUE(n.match(padd, add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, W->get_shape());
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, x->get_shape());
auto re_dot = make_shared<op::Dot>(reshape_w, reshape_x);
auto re_add = re_dot + broadcast;
ASSERT_TRUE(n.match(padd, re_add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);
auto cg =
make_shared<op::MatmulBias>(W, x, broadcast, W->get_shape(), x->get_shape(), false, false);
}
TEST(cpu_fusion, gemm_cpu)
{
auto shapeA = Shape{3, 2};
auto shapeB = Shape{2, 3};
auto shapeC = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeB);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{2, 3});
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{3, 2});
auto one = op::Constant::create<float>(element::f32, Shape{}, std::vector<float>{1.0f});
auto broadcast = make_shared<op::Broadcast>(one, shapeC, AxisSet{0, 1});
auto cg =
make_shared<op::MatmulBias>(A, B, broadcast, A->get_shape(), B->get_shape(), true, true);
auto f = make_shared<Function>(cg, op::Parameters{A, B});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shapeA);
shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shapeB);
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, shapeC);
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{3.0f, 3.0f, 3.0f, 9.0f, 9.0f, 9.0f};
copy_data(a, dataA);
copy_data(b, dataB);
cf->call({a, b}, {result});
vector<float> expected{10, 28, 37, 109};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
TEST(cpu_fusion, cpu_fusion_pass_basic)
{
auto shape = Shape{};
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::Parameters{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
}
TEST(cpu_fusion, gemm_mlp)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/mnist_mlp_forward.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3);
}
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
//
#include "ngraph/file_util.hpp"
#include "ngraph/json.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(cpu_fusion, gemm_pattern)
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto W = std::make_shared<pattern::op::Label>(A);
auto x = std::make_shared<pattern::op::Label>(B);
auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};
auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(C);
auto pbroadcast = make_shared<op::Broadcast>(b, dot->get_shape(), AxisSet{0});
auto padd = pdot + pbroadcast;
TestMatcher n(nullptr);
ASSERT_TRUE(n.match(padd, add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, W->get_shape());
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, x->get_shape());
auto re_dot = make_shared<op::Dot>(reshape_w, reshape_x);
auto re_add = re_dot + broadcast;
ASSERT_TRUE(n.match(padd, re_add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);
auto cg =
make_shared<op::MatmulBias>(W, x, broadcast, W->get_shape(), x->get_shape(), false, false);
}
TEST(cpu_fusion, gemm_cpu)
{
auto shapeA = Shape{3, 2};
auto shapeB = Shape{2, 3};
auto shapeC = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeB);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{2, 3});
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{3, 2});
auto one = op::Constant::create<float>(element::f32, Shape{}, std::vector<float>{1.0f});
auto broadcast = make_shared<op::Broadcast>(one, shapeC, AxisSet{0, 1});
auto cg =
make_shared<op::MatmulBias>(A, B, broadcast, A->get_shape(), B->get_shape(), true, true);
auto f = make_shared<Function>(cg, op::Parameters{A, B});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shapeA);
shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shapeB);
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, shapeC);
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{3.0f, 3.0f, 3.0f, 9.0f, 9.0f, 9.0f};
copy_data(a, dataA);
copy_data(b, dataB);
cf->call({a, b}, {result});
vector<float> expected{10, 28, 37, 109};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
TEST(cpu_fusion, cpu_fusion_pass_basic)
{
auto shape = Shape{};
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::Parameters{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
}
TEST(cpu_fusion, gemm_mlp)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/mnist_mlp_forward.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment