Commit 518bba03 authored by Jayaram Bobba's avatar Jayaram Bobba

Merge remote-tracking branch 'origin/master' into jbobba/maxpool-layouts

parents eed7b313 e46184a1
...@@ -272,13 +272,6 @@ namespace ngraph
const char* cbeta = "0.0f";
if (args.size() > 2)
{
writer << "memcpy(" << out[0].get_name() << ", " << args[2].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
cbeta = "1.0f";
}
writer << "cblas::cblas_sgemm(" writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << tranpose_a << tranpose_b << m << ", " << n << "cblas::Layout::RowMajor, " << tranpose_a << tranpose_b << m << ", " << n
<< ", " << k << ",\n" << ", " << k << ",\n"
...@@ -287,6 +280,101 @@ namespace ngraph ...@@ -287,6 +280,101 @@ namespace ngraph
<< " " << out[0].get_name() << ", " << max(1UL, arg2_shape[1]) << " " << out[0].get_name() << ", " << max(1UL, arg2_shape[1])
<< ");\n"; << ");\n";
if (args.size() > 2)
{
auto axes = cg->get_broadcast_axes();
if (axes.size() == 1)
{
if (*(axes.begin()) == 0)
{
writer << "static " << out[0].get_element_type().c_type_string()
<< " ones_row[" << arg2_shape[0] << "]"
<< " = { 1.0f";
for (size_t i = 1; i < arg2_shape[0]; ++i)
{
writer << ", 1.0f";
}
writer << "};\n";
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << cnotranspose << cnotranspose
<< arg2_shape[0] << ", " << arg2_shape[1] << ", 1"
<< ",\n"
<< " 1.0f, ones_row, "
<< "1"
<< ", " << args[2].get_name() << ", " << max(1UL, arg2_shape[1])
<< ", "
<< "1.0f"
<< ",\n"
<< " " << out[0].get_name() << ", "
<< max(1UL, arg2_shape[1]) << ");\n";
}
else
{
writer << "static " << out[0].get_element_type().c_type_string()
<< " ones_col[" << arg2_shape[1] << "]"
<< " = { 1.0f";
for (size_t i = 1; i < arg2_shape[1]; ++i)
{
writer << ", 1.0f";
}
writer << "};\n";
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << cnotranspose << ctranspose
<< arg2_shape[0] << ", " << arg2_shape[1] << ", 1"
<< ",\n"
<< " 1.0f, ones_col," << max(1UL, arg2_shape[1]) << ", "
<< args[2].get_name() << ", "
<< "1"
<< ", "
<< "1.0f"
<< ",\n"
<< " " << out[0].get_name() << ", "
<< max(1UL, arg2_shape[1]) << ");\n";
}
}
else
{
if (axes.size() != 2)
{
throw ngraph_error("unexpected broadcast rank");
}
writer << out[0].get_element_type().c_type_string() << " bias["
<< arg2_shape[1] << "]"
<< " = { " << args[2].get_name() << "[0]";
for (size_t i = 1; i < arg2_shape[1]; ++i)
{
writer << "," << args[2].get_name() << "[0]";
}
writer << "};\n";
writer << "static " << out[0].get_element_type().c_type_string()
<< " ones_scalar[" << arg2_shape[0] << "]"
<< " = { 1.0f";
for (size_t i = 1; i < arg2_shape[0]; ++i)
{
writer << ", 1.0f";
}
writer << "};\n";
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << cnotranspose << cnotranspose
<< arg2_shape[0] << ", " << arg2_shape[1] << ", 1"
<< ",\n"
<< " 1.0f, ones_scalar, "
<< "1"
<< ", "
<< "bias"
<< ", " << max(1UL, arg2_shape[1]) << ", "
<< "1.0f"
<< ",\n"
<< " " << out[0].get_name() << ", " << max(1UL, arg2_shape[1])
<< ");\n";
}
}
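The branch above folds the broadcast bias into the GEMM result with a second sgemm: a column (or row) of ones times the bias vector is a rank-1 update, and passing beta = 1.0f accumulates it onto the product already in the output buffer. A minimal standalone sketch of the same trick, assuming a standard CBLAS implementation (e.g. MKL or OpenBLAS) is linked; the ones/bias names are illustrative and not part of the emitter:

#include <cblas.h>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 3;
    std::vector<float> C = {1, 2, 3, 4, 5, 6}; // result of the main sgemm, row-major M x N
    std::vector<float> bias = {10, 20, 30};    // bias to broadcast across rows (axes == {0})
    std::vector<float> ones(M, 1.0f);          // M x 1 column of ones

    // C = 1.0 * ones(M x 1) * bias(1 x N) + 1.0 * C, i.e. bias is added to every row
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                M, N, 1,
                1.0f, ones.data(), 1,
                bias.data(), N,
                1.0f, C.data(), N);

    for (int i = 0; i < M * N; ++i)
        std::printf("%g ", C[i]); // prints 11 22 33 14 25 36
    std::printf("\n");
    return 0;
}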
writer.indent--;
writer << "}\n";
}
...@@ -3022,6 +3110,19 @@ namespace ngraph
auto output_format =
dynamic_cast<runtime::cpu::LayoutDescriptor&>(*output_tvl).get_mkldnn_format();
// MKLDNN relies on format names for selecting optimized kernel implementations
// Hacky way to deal with this until they move to using canonicalized layouts
if (input_format == mkldnn::memory::format::nchw &&
runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(output_format))
{
input_format = mkldnn::memory::format::oihw;
}
if (output_format == mkldnn::memory::format::nchw &&
runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(input_format))
{
output_format = mkldnn::memory::format::oihw;
}
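The two if-blocks above reinterpret a plain nchw data layout as oihw whenever the other side of the conversion is a weight (filter) format: both names describe the same contiguous 4-D layout, but MKLDNN picks its reorder kernels by format name. A hypothetical helper expressing the same rule, written only for illustration (maybe_as_filter is not part of this change; it assumes the mkldnn and ngraph mkldnn_utils headers are included):

static mkldnn::memory::format maybe_as_filter(mkldnn::memory::format fmt,
                                              mkldnn::memory::format other)
{
    // nchw and oihw have identical strides for a 4-D tensor, so a data tensor can be
    // relabeled as a filter when the opposite end of the reorder expects a filter format.
    if (fmt == mkldnn::memory::format::nchw &&
        runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(other))
    {
        return mkldnn::memory::format::oihw;
    }
    return fmt;
}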
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[0], input_format);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], output_format);
...
...@@ -110,6 +110,23 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::OhIw16o4i, "memory::format::OhIw16o4i"},
};
static const std::set<memory::format> s_filter_formats{
memory::format::oihw,
memory::format::ihwo,
memory::format::hwio,
// memory::format::oIhw8i, // These currently map to nChw8c and nChw16c
// memory::format::oIhw16i,
memory::format::OIhw8i8o,
memory::format::OIhw16i16o,
memory::format::IOhw16o16i,
memory::format::OIhw8o8i,
memory::format::OIhw16o16i,
memory::format::Oihw8o,
memory::format::Oihw16o,
memory::format::Ohwi8o,
memory::format::Ohwi16o,
memory::format::OhIw16o4i};
bool runtime::cpu::mkldnn_utils::IsMKLDNNOp(ngraph::Node& op)
{
return (s_op_registry.find(TI(op)) != s_op_registry.end());
...@@ -157,16 +174,16 @@ const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string(memory::
}
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_input_mkldnn_format(const Node* node,
int index)
size_t index)
{
auto tvl = node->get_inputs()[index].get_output().get_tensor_view()->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
}
mkldnn::memory::format runtime::cpu::mkldnn_utils::get_output_mkldnn_format(const Node* node,
int index)
size_t index)
{
auto tvl = node->get_output_tensor_view(0)->get_tensor_view_layout();
auto tvl = node->get_output_tensor_view(index)->get_tensor_view_layout();
return dynamic_cast<runtime::cpu::LayoutDescriptor&>(*tvl).get_mkldnn_format();
}
...@@ -181,7 +198,7 @@ bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node)
bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format fmt1,
mkldnn::memory::format fmt2)
{
set<mkldnn::memory::format> similar_4d_formats{mkldnn::memory::format::nchw,
std::set<mkldnn::memory::format> similar_4d_formats{mkldnn::memory::format::nchw,
mkldnn::memory::format::oihw};
if ((fmt1 == fmt2) || (similar_4d_formats.find(fmt1) != similar_4d_formats.end() &&
similar_4d_formats.find(fmt2) != similar_4d_formats.end()))
...@@ -190,3 +207,12 @@ bool runtime::cpu::mkldnn_utils::compare_mkldnn_formats(mkldnn::memory::format f
}
return false;
}
bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format(mkldnn::memory::format fmt)
{
if (s_filter_formats.find(fmt) != s_filter_formats.end())
{
return true;
}
return false;
}
...@@ -39,11 +39,12 @@ namespace ngraph
mkldnn::memory::data_type get_mkldnn_data_type(const ngraph::element::Type& type);
const std::string& get_mkldnn_format_string(mkldnn::memory::format fmt);
mkldnn::memory::format get_input_mkldnn_format(const Node* node, int index);
mkldnn::memory::format get_input_mkldnn_format(const Node* node, size_t index);
mkldnn::memory::format get_output_mkldnn_format(const Node* node, int index);
mkldnn::memory::format get_output_mkldnn_format(const Node* node, size_t index);
bool use_mkldnn_kernel(const ngraph::Node* node);
bool compare_mkldnn_formats(mkldnn::memory::format fmt1,
mkldnn::memory::format fmt2);
bool is_mkldnn_filter_format(mkldnn::memory::format fmt);
}
}
}
...
...@@ -32,7 +32,8 @@ std::shared_ptr<ngraph::Node>
m_shape_w,
m_shape_x,
m_transpose_w,
m_transpose_x);
m_transpose_x,
m_broadcast_axes);
}
ngraph::op::MatmulBias::MatmulBias(std::shared_ptr<ngraph::Node> W,
...@@ -41,7 +42,8 @@ ngraph::op::MatmulBias::MatmulBias(std::shared_ptr<ngraph::Node> W,
Shape shape_w,
Shape shape_x,
bool transpose_w,
bool transpose_x)
bool transpose_x,
AxisSet axes)
: RequiresTensorViewArgs("MatMulBias",
b == nullptr ? std::vector<std::shared_ptr<Node>>{W, x}
: std::vector<std::shared_ptr<Node>>{W, x, b})
...@@ -49,8 +51,24 @@ ngraph::op::MatmulBias::MatmulBias(std::shared_ptr<ngraph::Node> W,
, m_shape_x(shape_x)
, m_transpose_w(transpose_w)
, m_transpose_x(transpose_x)
, m_broadcast_axes(axes)
{
if (axes.size() == 0 && b != nullptr)
{
throw ngraph_error("Bias but no broadcast axes");
}
if (b == nullptr && axes.size() != 0)
{
throw ngraph_error("Broadcast axes but no bias");
}
if (axes.size() > 2)
{
throw ngraph_error("Broadcasting to > 2D tensor");
}
if (shape_w.size() != 2)
{
NGRAPH_DEBUG << "W shape = " << vector_to_string(shape_w);
...
...@@ -16,6 +16,7 @@
#pragma once
#include "ngraph/axis_set.hpp"
#include "ngraph/ops/util/requires_tensor_view_args.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
...@@ -31,12 +32,14 @@ namespace ngraph ...@@ -31,12 +32,14 @@ namespace ngraph
Shape shape_w, Shape shape_w,
Shape shape_x, Shape shape_x,
bool transpose_w, bool transpose_w,
bool transpose_x); bool transpose_x,
AxisSet axes = AxisSet{});
bool get_is_arg0_transposed() const { return m_transpose_w; }
bool get_is_arg1_transposed() const { return m_transpose_x; }
Shape get_arg0_shape() const { return m_shape_w; }
Shape get_arg1_shape() const { return m_shape_x; }
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
...@@ -45,6 +48,7 @@ namespace ngraph
Shape m_shape_x;
bool m_transpose_w;
bool m_transpose_x;
AxisSet m_broadcast_axes;
};
}
}
...@@ -27,6 +27,7 @@
#include "ngraph/descriptor/output.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/max_pool.hpp" #include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/relu.hpp" #include "ngraph/ops/relu.hpp"
...@@ -265,6 +266,16 @@ namespace ngraph ...@@ -265,6 +266,16 @@ namespace ngraph
relu_bprop->set_op_annotations(op_annotations); relu_bprop->set_op_annotations(op_annotations);
} }
} }
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
{
auto batchnorm = static_cast<op::BatchNorm*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
}
}
}
...@@ -277,6 +288,7 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::Convolution),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropData),
...
...@@ -134,12 +134,21 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern()
<< m.match_root()->get_name();
auto mpattern = m.match_root(); //add
auto m_matmul = mpattern->get_input_op(0);
auto m_matmul = std::dynamic_pointer_cast<op::MatmulBias>(mpattern->get_input_op(0));
auto m_broadcast = mpattern->get_input_op(1);
auto m_broadcast = std::dynamic_pointer_cast<op::Broadcast>(mpattern->get_input_op(1));
auto m_bias = m_broadcast->get_input_op(0);
auto pattern_map = m.get_pattern_map();
return m_matmul->copy_with_new_args(
NodeVector{pattern_map[W], pattern_map[x], m_broadcast});
auto mmb = std::make_shared<op::MatmulBias>(pattern_map[W],
pattern_map[x],
m_bias,
m_matmul->get_arg0_shape(),
m_matmul->get_arg1_shape(),
m_matmul->get_is_arg0_transposed(),
m_matmul->get_is_arg1_transposed(),
m_broadcast->get_broadcast_axes());
return mmb;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
...
...@@ -28,8 +28,10 @@
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/max_pool.hpp" #include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/get_output_element.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/op.hpp"
#include "ngraph/ops/relu.hpp" #include "ngraph/ops/relu.hpp"
#include "ngraph/ops/result.hpp" #include "ngraph/ops/result.hpp"
...@@ -796,6 +798,17 @@ namespace ngraph ...@@ -796,6 +798,17 @@ namespace ngraph
set_output_layouts(node, prim_output_formats); set_output_layouts(node, prim_output_formats);
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::GetOutputElement)
{
auto goe = static_cast<const ngraph::op::GetOutputElement*>(node.get());
auto input_layout = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(
node.get(), goe->get_n());
vector<memory::format> prim_output_formats;
prim_output_formats.push_back(input_layout);
set_output_layouts(node, prim_output_formats);
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Relu)
{
...@@ -836,6 +849,32 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNorm)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 2);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
throw ngraph_error("Batchnorm only supported in MKLDNN for now");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{
...@@ -878,6 +917,9 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::MaxPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPool>},
{TI(ngraph::op::MaxPoolBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::MaxPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
{TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>},
{TI(ngraph::op::ReluBackprop),
...
...@@ -91,11 +91,89 @@ TEST(cpu_fusion, gemm_pattern)
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);
auto cg =
make_shared<op::MatmulBias>(W, x, broadcast, W->get_shape(), x->get_shape(), false, false);
auto cg = make_shared<op::MatmulBias>(
W, x, C, W->get_shape(), x->get_shape(), false, false, AxisSet{0});
}
TEST(cpu_fusion, gemm_cpu_broadcast_row)
{
Shape shapeA{3, 2};
Shape shapeB{2, 3};
Shape shapeC{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeB);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{2, 3});
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{3, 2});
auto one = op::Constant::create<float>(element::f32, Shape{2}, std::vector<float>{1.0f, 1.0f});
auto broadcast = make_shared<op::Broadcast>(one, shapeC, AxisSet{0});
auto cg = make_shared<op::MatmulBias>(
A, B, one, A->get_shape(), B->get_shape(), true, true, AxisSet{0});
auto f = make_shared<Function>(cg, op::ParameterVector{A, B});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shapeA);
shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shapeB);
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, shapeC);
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{3.0f, 3.0f, 3.0f, 9.0f, 9.0f, 9.0f};
copy_data(a, dataA);
copy_data(b, dataB);
cf->call({a, b}, {result});
vector<float> expected{10, 28, 37, 109};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
TEST(cpu_fusion, gemm_cpu)
TEST(cpu_fusion, gemm_cpu_broadcast_column)
{
Shape shapeA{3, 2};
Shape shapeB{2, 3};
Shape shapeC{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeB);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{2, 3});
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{3, 2});
auto one = op::Constant::create<float>(element::f32, Shape{2}, std::vector<float>{1.0f, 1.0f});
auto broadcast = make_shared<op::Broadcast>(one, shapeC, AxisSet{1});
auto cg = make_shared<op::MatmulBias>(
A, B, one, A->get_shape(), B->get_shape(), true, true, AxisSet{1});
auto f = make_shared<Function>(cg, op::ParameterVector{A, B});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shapeA);
shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shapeB);
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, shapeC);
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{3.0f, 3.0f, 3.0f, 9.0f, 9.0f, 9.0f};
copy_data(a, dataA);
copy_data(b, dataB);
cf->call({a, b}, {result});
vector<float> expected{10, 28, 37, 109};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
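The expected vector {10, 28, 37, 109} in both tests above is op(A) * op(B) plus the all-ones bias: with both inputs transposed the product is {{9, 27}, {36, 108}}, and adding 1.0f to every element gives the checked values (row and column broadcast coincide here because the bias is all ones). A small standalone check, not part of the test file, that recomputes them:

#include <array>
#include <cstdio>

int main()
{
    // Same data as dataA (3x2, row-major) and dataB (2x3, row-major) in the tests above.
    std::array<std::array<float, 2>, 3> A = {{{1, 4}, {1, 4}, {1, 4}}};
    std::array<std::array<float, 3>, 2> B = {{{3, 3, 3}, {9, 9, 9}}};
    float bias[2] = {1.0f, 1.0f};

    // C = A^T (2x3) * B^T (3x2) + broadcast(bias); C is 2x2.
    float C[2][2] = {{0, 0}, {0, 0}};
    for (int i = 0; i < 2; ++i)
    {
        for (int j = 0; j < 2; ++j)
        {
            for (int k = 0; k < 3; ++k)
            {
                C[i][j] += A[k][i] * B[j][k];
            }
            C[i][j] += bias[j]; // all-ones bias: row vs. column broadcast is identical here
        }
    }
    std::printf("%g %g %g %g\n", C[0][0], C[0][1], C[1][0], C[1][1]); // 10 28 37 109
    return 0;
}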
TEST(cpu_fusion, gemm_cpu_broadcast_matrix)
{
Shape shapeA{3, 2};
Shape shapeB{2, 3};
...@@ -109,8 +187,8 @@ TEST(cpu_fusion, gemm_cpu)
auto one = op::Constant::create<float>(element::f32, Shape{}, std::vector<float>{1.0f});
auto broadcast = make_shared<op::Broadcast>(one, shapeC, AxisSet{0, 1});
auto cg =
make_shared<op::MatmulBias>(A, B, broadcast, A->get_shape(), B->get_shape(), true, true);
auto cg = make_shared<op::MatmulBias>(
A, B, one, A->get_shape(), B->get_shape(), true, true, AxisSet{0, 1});
auto f = make_shared<Function>(cg, op::ParameterVector{A, B});
...@@ -212,7 +290,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
pass_manager.run_passes(func);
auto gmm = graph->get_input_op(0);
ASSERT_TRUE(std::dynamic_pointer_cast<op::MatmulBias>(gmm));
ASSERT_EQ(gmm->get_input_op(2), broadcast);
ASSERT_EQ(gmm->get_input_op(2), b);
}
TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias)
...