Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
6bca3efd
Commit
6bca3efd
authored
Aug 01, 2018
by
Nick Korovaiko
Committed by
Scott Cyphers
Aug 01, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactoring MMB (#1224)
* rank3xrank2 cpu_emitter version 1 * refactoring matmulbias * add comment
parent
5927bbe4
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
241 additions
and
183 deletions
+241
-183
matmul_bias.cpp
src/ngraph/runtime/cpu/builder/matmul_bias.cpp
+4
-4
cpu_emitter.cpp
src/ngraph/runtime/cpu/cpu_emitter.cpp
+227
-171
batch_dot.hpp
src/ngraph/runtime/cpu/op/batch_dot.hpp
+2
-0
matmul_bias.hpp
src/ngraph/runtime/cpu/op/matmul_bias.hpp
+4
-4
cpu_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
+4
-4
No files found.
src/ngraph/runtime/cpu/builder/matmul_bias.cpp
View file @
6bca3efd
...
...
@@ -39,8 +39,8 @@ namespace ngraph
const
ngraph
::
op
::
MatmulBias
*
mm
=
static_cast
<
const
ngraph
::
op
::
MatmulBias
*>
(
node
);
const
auto
&
arg0_shape
=
mm
->
get_a
rg0
_shape
();
const
auto
&
arg1_shape
=
mm
->
get_
arg1
_shape
();
const
auto
&
arg0_shape
=
mm
->
get_a_shape
();
const
auto
&
arg1_shape
=
mm
->
get_
b
_shape
();
const
auto
&
arg2_shape
=
node
->
get_shape
();
auto
m
=
arg0_shape
[
0
];
...
...
@@ -51,14 +51,14 @@ namespace ngraph
auto
lda
=
arg0_shape
[
1
];
auto
ldb
=
arg1_shape
[
1
];
if
(
mm
->
get_is_a
rg0
_transposed
())
if
(
mm
->
get_is_a_transposed
())
{
transpose_A
=
true
;
m
=
arg0_shape
[
1
];
k
=
arg0_shape
[
0
];
}
if
(
mm
->
get_is_
arg1
_transposed
())
if
(
mm
->
get_is_
b
_transposed
())
{
transpose_B
=
true
;
n
=
arg1_shape
[
0
];
...
...
src/ngraph/runtime/cpu/cpu_emitter.cpp
View file @
6bca3efd
...
...
@@ -233,159 +233,19 @@ namespace ngraph
}
#endif
template
<>
void
CPU_Emitter
::
EMITTER_DECL
(
ngraph
::
op
::
MatmulBias
)
{
const
ngraph
::
op
::
MatmulBias
*
cg
=
static_cast
<
const
ngraph
::
op
::
MatmulBias
*>
(
node
);
const
Shape
&
arg0_shape
=
cg
->
get_arg0_shape
();
//W
const
Shape
&
arg1_shape
=
cg
->
get_arg1_shape
();
//x
const
Shape
&
arg2_shape
=
node
->
get_shape
();
//bias (C)
static
const
char
*
ctranspose
=
"cblas::Transpose::Transpose, "
;
static
const
char
*
cnotranspose
=
"cblas::Transpose::None, "
;
size_t
m
=
arg0_shape
[
0
];
size_t
n
=
arg1_shape
[
1
];
size_t
k
=
arg0_shape
[
1
];
//
const
char
*
tranpose_a
=
cnotranspose
;
const
char
*
tranpose_b
=
cnotranspose
;
size_t
lda
=
arg0_shape
[
1
];
size_t
ldb
=
arg1_shape
[
1
];
if
(
cg
->
get_is_arg0_transposed
())
{
tranpose_a
=
ctranspose
;
m
=
arg0_shape
[
1
];
k
=
arg0_shape
[
0
];
}
if
(
cg
->
get_is_arg1_transposed
())
{
tranpose_b
=
ctranspose
;
n
=
arg1_shape
[
0
];
}
writer
.
block_begin
();
const
char
*
cbeta
=
"0.0f"
;
writer
<<
"cblas::cblas_sgemm("
<<
"cblas::Layout::RowMajor, "
<<
tranpose_a
<<
tranpose_b
<<
m
<<
", "
<<
n
<<
", "
<<
k
<<
",
\n
"
<<
" 1.0f, "
<<
args
[
0
].
get_name
()
<<
", "
<<
max
(
1UL
,
lda
)
<<
", "
<<
args
[
1
].
get_name
()
<<
", "
<<
max
(
1UL
,
ldb
)
<<
", "
<<
cbeta
<<
",
\n
"
<<
" "
<<
out
[
0
].
get_name
()
<<
", "
<<
max
(
1UL
,
arg2_shape
[
1
])
<<
");
\n
"
;
if
(
args
.
size
()
>
2
)
void
emitCblasSgemmBatch
(
codegen
::
CodeWriter
&
writer
,
const
Shape
&
shape_a
,
const
Shape
&
shape_b
,
const
Shape
&
shape_c
,
bool
transpose_a
,
bool
transpose_b
,
const
std
::
string
&
data_a
,
const
std
::
string
&
data_b
,
const
std
::
string
&
data_c
,
const
std
::
string
&
alpha
,
const
std
::
string
&
beta
,
size_t
group_size
)
{
auto
axes
=
cg
->
get_broadcast_axes
();
if
(
axes
.
size
()
==
1
)
{
if
(
*
(
axes
.
begin
())
==
0
)
{
writer
<<
"static "
<<
out
[
0
].
get_element_type
().
c_type_string
()
<<
" ones_row["
<<
arg2_shape
[
0
]
<<
"]"
<<
" = { 1.0f"
;
for
(
size_t
i
=
1
;
i
<
arg2_shape
[
0
];
++
i
)
{
writer
<<
", 1.0f"
;
}
writer
<<
"};
\n
"
;
writer
<<
"cblas::cblas_sgemm("
<<
"cblas::Layout::RowMajor, "
<<
cnotranspose
<<
cnotranspose
<<
arg2_shape
[
0
]
<<
", "
<<
arg2_shape
[
1
]
<<
", 1"
<<
",
\n
"
<<
" 1.0f, ones_row, "
<<
"1"
<<
", "
<<
args
[
2
].
get_name
()
<<
", "
<<
max
(
1UL
,
arg2_shape
[
1
])
<<
", "
<<
"1.0f"
<<
",
\n
"
<<
" "
<<
out
[
0
].
get_name
()
<<
", "
<<
max
(
1UL
,
arg2_shape
[
1
])
<<
");
\n
"
;
}
else
{
writer
<<
"static "
<<
out
[
0
].
get_element_type
().
c_type_string
()
<<
" ones_col["
<<
arg2_shape
[
1
]
<<
"]"
<<
" = { 1.0f"
;
for
(
size_t
i
=
1
;
i
<
arg2_shape
[
1
];
++
i
)
{
writer
<<
", 1.0f"
;
}
writer
<<
"};
\n
"
;
writer
<<
"cblas::cblas_sgemm("
<<
"cblas::Layout::RowMajor, "
<<
cnotranspose
<<
cnotranspose
<<
arg2_shape
[
0
]
<<
", "
<<
arg2_shape
[
1
]
<<
", 1,
\n
"
<<
"1.0f, "
<<
args
[
2
].
get_name
()
<<
", 1, "
<<
"ones_col, "
<<
max
(
1UL
,
arg2_shape
[
1
])
<<
", "
<<
"1.0f"
<<
",
\n
"
<<
" "
<<
out
[
0
].
get_name
()
<<
", "
<<
max
(
1UL
,
arg2_shape
[
1
])
<<
");
\n
"
;
}
}
else
{
if
(
axes
.
size
()
!=
2
)
{
throw
ngraph_error
(
"unexpected broadcast rank"
);
}
writer
<<
out
[
0
].
get_element_type
().
c_type_string
()
<<
" bias["
<<
arg2_shape
[
1
]
<<
"]"
<<
" = { "
<<
args
[
2
].
get_name
()
<<
"[0]"
;
for
(
size_t
i
=
1
;
i
<
arg2_shape
[
1
];
++
i
)
{
writer
<<
","
<<
args
[
2
].
get_name
()
<<
"[0]"
;
}
writer
<<
"};
\n
"
;
writer
<<
"static "
<<
out
[
0
].
get_element_type
().
c_type_string
()
<<
" ones_scalar["
<<
arg2_shape
[
0
]
<<
"]"
<<
" = { 1.0f"
;
for
(
size_t
i
=
1
;
i
<
arg2_shape
[
0
];
++
i
)
{
writer
<<
", 1.0f"
;
}
writer
<<
"};
\n
"
;
writer
<<
"cblas::cblas_sgemm("
<<
"cblas::Layout::RowMajor, "
<<
cnotranspose
<<
cnotranspose
<<
arg2_shape
[
0
]
<<
", "
<<
arg2_shape
[
1
]
<<
", 1"
<<
",
\n
"
<<
" 1.0f, ones_scalar, "
<<
"1"
<<
", "
<<
"bias"
<<
", "
<<
max
(
1UL
,
arg2_shape
[
1
])
<<
", "
<<
"1.0f"
<<
",
\n
"
<<
" "
<<
out
[
0
].
get_name
()
<<
", "
<<
max
(
1UL
,
arg2_shape
[
1
])
<<
");
\n
"
;
}
}
writer
.
block_end
();
}
template
<>
void
CPU_Emitter
::
EMITTER_DECL
(
ngraph
::
op
::
BatchDot
)
{
const
ngraph
::
op
::
BatchDot
*
batch_dot
=
static_cast
<
const
ngraph
::
op
::
BatchDot
*>
(
node
);
auto
mat_a
=
args
[
0
];
auto
mat_b
=
args
[
1
];
auto
mat_c
=
out
[
0
];
const
Shape
&
shape_a
=
mat_a
.
get_shape
();
const
Shape
&
shape_b
=
mat_b
.
get_shape
();
static
const
char
*
cblas_transpose
=
"cblas::Transpose::Transpose"
;
static
const
char
*
cblas_no_transpose
=
"cblas::Transpose::None"
;
...
...
@@ -394,31 +254,30 @@ namespace ngraph
size_t
n
=
shape_b
[
2
];
size_t
lda
=
std
::
max
(
1UL
,
k
);
size_t
ldb
=
std
::
max
(
1UL
,
n
);
const
char
*
transpose_a
=
cblas_no_transpose
;
const
char
*
transpose_b
=
cblas_no_transpose
;
if
(
batch_dot
->
get_is_a_transposed
()
)
const
char
*
c
transpose_a
=
cblas_no_transpose
;
const
char
*
c
transpose_b
=
cblas_no_transpose
;
if
(
transpose_a
)
{
transpose_a
=
cblas_transpose
;
c
transpose_a
=
cblas_transpose
;
m
=
shape_a
[
2
];
k
=
shape_a
[
1
];
lda
=
std
::
max
(
1UL
,
m
);
}
if
(
batch_dot
->
get_is_b_transposed
()
)
if
(
transpose_b
)
{
transpose_b
=
cblas_transpose
;
c
transpose_b
=
cblas_transpose
;
n
=
shape_b
[
1
];
ldb
=
std
::
max
(
1UL
,
k
);
}
size_t
ldc
=
std
::
max
(
1UL
,
n
);
const
size_t
offset_a
=
m
*
k
;
const
size_t
offset_b
=
k
*
n
;
const
size_t
offset_c
=
m
*
n
;
const
size_t
offset_a
=
(
shape_a
.
at
(
0
)
>
1
)
?
m
*
k
:
0
;
const
size_t
offset_b
=
(
shape_b
.
at
(
0
)
>
1
)
?
k
*
n
:
0
;
const
size_t
offset_c
=
(
shape_c
.
at
(
0
)
>
1
)
?
m
*
n
:
0
;
writer
.
block_begin
();
const
size_t
group_count
=
1
;
const
size_t
group_size
=
shape_a
[
0
];
auto
populate_array
=
[
&
writer
](
const
std
::
string
&
var
,
size_t
size
,
size_t
offset
)
{
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
...
...
@@ -426,25 +285,23 @@ namespace ngraph
writer
<<
var
<<
"+"
<<
i
*
offset
<<
((
i
<
size
-
1
)
?
", "
:
""
);
}
};
writer
<<
"cblas::Transpose transa_array[] = {"
<<
transpose_a
<<
"};
\n
"
;
writer
<<
"cblas::Transpose transb_array[] = {"
<<
transpose_b
<<
"};
\n
"
;
writer
<<
"cblas::Transpose transa_array[] = {"
<<
c
transpose_a
<<
"};
\n
"
;
writer
<<
"cblas::Transpose transb_array[] = {"
<<
c
transpose_b
<<
"};
\n
"
;
writer
<<
"int64_t m_array[] = {"
<<
m
<<
"};
\n
"
;
writer
<<
"int64_t n_array[] = {"
<<
n
<<
"};
\n
"
;
writer
<<
"int64_t k_array[] = {"
<<
k
<<
"};
\n
"
;
writer
<<
"float alpha_array[] = {1.0f};
\n
"
;
writer
<<
"std::vector<const float*> a{"
;
populate_array
(
mat_a
.
get_name
()
,
group_size
,
offset_a
);
populate_array
(
data_a
,
group_size
,
offset_a
);
writer
<<
"};
\n
"
;
writer
<<
"const float** a_array = &a[0];
\n
"
;
writer
<<
"int64_t lda_array[] = {"
<<
lda
<<
"};
\n
"
;
writer
<<
"std::vector<const float*> b{"
;
populate_array
(
mat_b
.
get_name
()
,
group_size
,
offset_b
);
populate_array
(
data_b
,
group_size
,
offset_b
);
writer
<<
"};
\n
"
;
writer
<<
"const float** b_array = &b[0];
\n
"
;
writer
<<
"int64_t ldb_array[] = {"
<<
ldb
<<
"};
\n
"
;
writer
<<
"float beta_array[] = {0.0f};
\n
"
;
writer
<<
"std::vector<float*> c{"
;
populate_array
(
mat_c
.
get_name
()
,
group_size
,
offset_c
);
populate_array
(
data_c
,
group_size
,
offset_c
);
writer
<<
"};
\n
"
;
writer
<<
"float** c_array = &c[0];
\n
"
;
writer
<<
"int64_t ldc_array[] = {"
<<
ldc
<<
"};
\n
"
;
...
...
@@ -452,11 +309,210 @@ namespace ngraph
writer
<<
"cblas_sgemm_batch(cblas::Layout::RowMajor, "
;
writer
<<
"transa_array, transb_array, m_array, n_array, k_array,
\n
"
;
writer
<<
"alpha_array, a_array, lda_array, b_array, ldb_array, beta_array
,
\n
"
;
writer
<<
alpha
<<
", a_array, lda_array, b_array, ldb_array, "
<<
beta
<<
"
,
\n
"
;
writer
<<
"c_array, ldc_array, "
<<
group_count
<<
", group_size);
\n
"
;
writer
.
block_end
();
}
// Emits one batched GEMM for `node`, which is treated as an op of type T.
//
// T must provide get_is_a_transposed() and get_is_b_transposed(). The two inputs are
// args[0] and args[1]; the product is written to out[0]. The emitted code is wrapped in
// its own block and declares `alpha_array` = {1.0f} and `beta_array` = {0.0f}, which are
// handed to emitCblasSgemmBatch by name. The group size is the leading dimension of
// shape_a.
template <typename T>
static void emitBatchDot(const ngraph::Node* node,
                         const Shape& shape_a,
                         const Shape& shape_b,
                         const Shape& shape_c,
                         const std::vector<TensorViewWrapper>& args,
                         const std::vector<TensorViewWrapper>& out,
                         codegen::CodeWriter& writer)
{
    const T* typed_node = static_cast<const T*>(node);

    writer.block_begin();

    // Scaling factors referenced by name from the emitted cblas call.
    writer << "float alpha_array[] = {1.0f};\n";
    writer << "float beta_array[] = {0.0f};\n";

    const size_t batch_count = shape_a[0];
    emitCblasSgemmBatch(writer,
                        shape_a,
                        shape_b,
                        shape_c,
                        typed_node->get_is_a_transposed(),
                        typed_node->get_is_b_transposed(),
                        args[0].get_name(),
                        args[1].get_name(),
                        out[0].get_name(),
                        "alpha_array",
                        "beta_array",
                        batch_count);

    writer.block_end();
}
// Left-pads shape `v` with copies of `val` until it has `length` dimensions.
// A shape that already has `length` or more dimensions is returned unchanged.
static Shape pad_with(Shape v, size_t val, size_t length)
{
    const size_t rank = v.size();
    if (rank >= length)
    {
        return v;
    }

    // Start from the padding and append the original dimensions after it.
    Shape padded(length - rank, val);
    padded.insert(padded.end(), v.begin(), v.end());
    return padded;
}
// Builds the source text of a static array definition whose elements all share one
// initializer expression:
//
//     static <type> <name>[<size>] = { <val>, <val>, ...};\n
//
// type: element type spelled as it should appear in the emitted code
// name: identifier of the emitted array
// val:  initializer expression repeated for every element
// size: number of elements; also written as the array bound
//
// `val` is always written at least once, so size == 0 yields a zero-length array with
// a single initializer.
static std::string emit_constant_array(const std::string& type,
                                       const std::string& name,
                                       const std::string& val,
                                       size_t size)
{
    std::stringstream decl;
    decl << "static " << type << " " << name << "[" << size << "]"
         << " = { " << val;
    // The first element is already written; append the remaining size - 1.
    for (size_t i = 1; i < size; ++i)
    {
        decl << ", " << val;
    }
    decl << "};\n";
    return decl.str();
}
// Emits code for MatmulBias in two steps:
//   1. out = dot(A, B), via emitBatchDot;
//   2. out += broadcast(bias), only when a third argument is present.
//
// The broadcast in step 2 is itself expressed as a GEMM against a vector of ones, with
// alpha = beta = 1 so the product is accumulated onto the result of step 1:
//
//   - broadcast axis 0:  ones(rows x 1) * C(1 x cols)
//
//         [1              [1 2 3
//          1   * [1 2 3] = 1 2 3
//          1               1 2 3
//          1]              1 2 3]
//
//   - the other single broadcast axis:  C(rows x 1) * ones(1 x cols)
//
//         [1                 [1 1 1 1
//          2   * [1 1 1 1] =  2 2 2 2
//          3]                 3 3 3 3]
//
//   - two broadcast axes (scalar bias):  bias_vector(rows x 1) * ones(1 x cols),
//     where every element of bias_vector is the expression `<bias>[0]`.
//
// Throws ngraph_error when the broadcast axis count is neither 1 nor 2.
//
// NOTE(review): the step-2 declarations (`alpha_beta_array`, `ones`, `bias_vector`) are
// written without a surrounding block_begin()/block_end() in this function -- confirm
// the caller scopes each op's emitted code.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::MatmulBias)
{
    const auto* mmb = static_cast<const ngraph::op::MatmulBias*>(node);

    // Operand and result shapes are left-padded with 1s up to rank 3 for the batched
    // emitter; the unpadded result shape is kept for sizing the broadcast operands.
    const Shape a_shape = pad_with(mmb->get_a_shape(), 1, 3);
    const Shape b_shape = pad_with(mmb->get_b_shape(), 1, 3);
    const Shape& result_shape = node->get_shape();
    const Shape result_shape_padded = pad_with(node->get_shape(), 1, 3);

    // Step 1: dot(A, B)
    emitBatchDot<ngraph::op::MatmulBias>(
        node, a_shape, b_shape, result_shape_padded, args, out, writer);

    // Step 2: add the bias, if there is one.
    if (args.size() < 3)
    {
        return;
    }
    const auto& bias = args[2];

    // Used as both alpha and beta below, so each GEMM accumulates onto dot(A, B).
    writer << "float alpha_beta_array[] = {1.0f};\n";
    const char* const unit_scale = "alpha_beta_array";

    const size_t group_size = 1;
    const std::string elem_type = out[0].get_element_type().c_type_string();
    const auto& axes = mmb->get_broadcast_axes();

    if (axes.size() == 1)
    {
        if (*axes.begin() == 0)
        {
            // ones(rows x 1) * C(1 x cols)
            writer << emit_constant_array(elem_type, "ones", "1.0f", result_shape.at(0));
            emitCblasSgemmBatch(writer,
                                Shape{1, result_shape.at(0), 1}, // ones shape
                                Shape{1, 1, result_shape.at(1)}, // C shape
                                result_shape,
                                false,             // ones transpose
                                false,             // C transpose
                                "ones",            // ones
                                bias.get_name(),   // C
                                out[0].get_name(), // dot(A,B)
                                unit_scale,
                                unit_scale,
                                group_size);
        }
        else
        {
            // C(rows x 1) * ones(1 x cols)
            writer << emit_constant_array(elem_type, "ones", "1.0f", result_shape.at(1));
            emitCblasSgemmBatch(writer,
                                Shape{1, result_shape.at(0), 1}, // C shape
                                Shape{1, 1, result_shape.at(1)}, // ones shape
                                result_shape,
                                false,             // C transpose
                                false,             // ones transpose
                                bias.get_name(),   // C
                                "ones",            // ones
                                out[0].get_name(), // dot(A,B)
                                unit_scale,
                                unit_scale,
                                group_size);
        }
        return;
    }

    if (axes.size() != 2)
    {
        throw ngraph_error("unexpected broadcast rank");
    }

    // Scalar bias: replicate `<bias>[0]` down a column vector, then spread that column
    // across the row of ones.
    writer << emit_constant_array(elem_type, "ones", "1.0f", result_shape.at(1));
    const std::string bias_scalar = bias.get_name() + "[0]";
    writer << emit_constant_array(elem_type, "bias_vector", bias_scalar, result_shape.at(0));
    emitCblasSgemmBatch(writer,
                        Shape{1, result_shape.at(0), 1}, // bias_vector shape
                        Shape{1, 1, result_shape.at(1)}, // ones shape
                        result_shape,
                        false,             // bias_vector transpose
                        false,             // ones transpose
                        "bias_vector",
                        "ones",
                        out[0].get_name(), // dot(A,B)
                        unit_scale,
                        unit_scale,
                        group_size);
}
// Emits code for BatchDot by delegating to emitBatchDot. The operand shapes come from
// the op itself; the result shape comes from the first output tensor.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchDot)
{
    const auto* batch_dot = static_cast<const ngraph::op::BatchDot*>(node);

    const Shape a_shape = batch_dot->get_a_shape();
    const Shape b_shape = batch_dot->get_b_shape();
    const Shape& c_shape = out[0].get_shape();

    emitBatchDot<ngraph::op::BatchDot>(node, a_shape, b_shape, c_shape, args, out, writer);
}
template
<>
void
CPU_Emitter
::
EMITTER_DECL
(
ngraph
::
op
::
Lstm
)
{
...
...
src/ngraph/runtime/cpu/op/batch_dot.hpp
View file @
6bca3efd
...
...
@@ -32,6 +32,8 @@ namespace ngraph
// Returns the stored transpose flag for operand A (m_transpose_a).
bool get_is_a_transposed() const
{
    return m_transpose_a;
}
// Returns the stored transpose flag for operand B (m_transpose_b).
bool get_is_b_transposed() const
{
    return m_transpose_b;
}
// Returns the shape of operand A, read from this node's argument 0.
Shape get_a_shape() const
{
    const auto operand_a = get_argument(0);
    return operand_a->get_shape();
}
// Returns the shape of operand B, read from this node's argument 1.
Shape get_b_shape() const
{
    const auto operand_b = get_argument(1);
    return operand_b->get_shape();
}
// Declaration only; the body is not part of this view.
// Marked `override`, so it replaces a virtual declared in a base class.
// Takes a list of replacement arguments and returns a std::shared_ptr<Node>.
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/runtime/cpu/op/matmul_bias.hpp
View file @
6bca3efd
...
...
@@ -35,10 +35,10 @@ namespace ngraph
bool
transpose_x
,
AxisSet
axes
=
AxisSet
{});
bool
get_is_a
rg0
_transposed
()
const
{
return
m_transpose_w
;
}
bool
get_is_
arg1
_transposed
()
const
{
return
m_transpose_x
;
}
Shape
get_a
rg0
_shape
()
const
{
return
m_shape_w
;
}
Shape
get_
arg1
_shape
()
const
{
return
m_shape_x
;
}
// Returns the stored transpose flag for operand A (kept in m_transpose_w).
bool get_is_a_transposed() const
{
    return m_transpose_w;
}
bool
get_is_
b
_transposed
()
const
{
return
m_transpose_x
;
}
// Returns a copy of the stored shape of operand A (kept in m_shape_w).
Shape get_a_shape() const
{
    return m_shape_w;
}
// Returns a copy of the stored shape of operand B (kept in m_shape_x).
Shape get_b_shape() const
{
    return m_shape_x;
}
// Returns a reference to the stored set of broadcast axes (m_broadcast_axes).
const AxisSet& get_broadcast_axes() const
{
    return m_broadcast_axes;
}
// Declaration only; the body is not part of this view.
// Marked `override`, so it replaces a virtual declared in a base class.
// Takes a list of replacement arguments and returns a std::shared_ptr<Node>.
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
View file @
6bca3efd
...
...
@@ -153,10 +153,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
auto
mmb
=
std
::
make_shared
<
op
::
MatmulBias
>
(
pattern_map
[
W
],
pattern_map
[
x
],
m_bias
,
m_matmul
->
get_a
rg0
_shape
(),
m_matmul
->
get_
arg1
_shape
(),
m_matmul
->
get_is_a
rg0
_transposed
(),
m_matmul
->
get_is_
arg1
_transposed
(),
m_matmul
->
get_a_shape
(),
m_matmul
->
get_
b
_shape
(),
m_matmul
->
get_is_a_transposed
(),
m_matmul
->
get_is_
b
_transposed
(),
m_broadcast
->
get_broadcast_axes
());
ngraph
::
replace_node
(
m
.
get_match_root
(),
mmb
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment