Commit d4a12feb, authored Dec 20, 2019 by baojun; committed by Sang Ik Lee, Dec 20, 2019
Update fluid matmul fprop (#4099)

* update matmul fprop
* clean up
parent d2c6c27e
Showing 1 changed file with 110 additions and 127 deletions.
src/ngraph/frontend/fluid/operators/matmul.cpp (+110, -127)
...
@@ -16,8 +16,6 @@
#include <memory>
#include <numeric>
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/frontend/fluid/operators/matmul.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/dot.hpp"
...
@@ -29,131 +27,6 @@
using namespace std;
using namespace ngraph::fluid;

constexpr NodeTypeInfo MatMul::type_info;

MatMul::MatMul(const Output<Node>& A,
               const Output<Node>& B,
               const bool transpose_a,
               const bool transpose_b)
    : FusedOp(OutputVector{A, B})
    , m_transpose_a{transpose_a}
    , m_transpose_b{transpose_b}
{
    constructor_validate_and_infer_types();
}

void decompose_logic(Output<Node>& input, bool transpose, bool reverse = false)
{
    auto rank = input.get_shape().size();
    if (rank < 2)
    {
        if (rank)
        {
            if (reverse)
            {
                input = make_shared<op::Reshape>(
                    input, AxisVector{0}, Shape{input.get_shape()[0], 1});
            }
            else
            {
                input = make_shared<op::Reshape>(
                    input, AxisVector{0}, Shape{1, input.get_shape()[0]});
            }
        }
        else
        {
            input = make_shared<op::Reshape>(input, AxisVector{}, Shape{1, 1});
        }
        rank = 2;
    }
    if (transpose)
    {
        vector<size_t> axes_order(rank);
        iota(axes_order.begin(), axes_order.end(), 0);
        swap(axes_order[rank - 1], axes_order[rank - 2]);
        input = builder::reorder_axes(input, axes_order);
    }
}

NodeVector remove_1(shared_ptr<Node> input_node)
{
    auto input_shape = input_node->get_shape();
    AxisVector axis(input_shape.size());
    iota(axis.begin(), axis.end(), 0);
    Shape shape(input_shape.begin(), input_shape.end());
    auto b_remove = remove(shape.begin(), shape.end(), 1);
    shape.erase(b_remove, shape.end());
    Output<Node> node(input_node);
    auto reshape = make_shared<op::Reshape>(node, axis, shape);
    NodeVector final_vector{reshape};
    return final_vector;
}

void MatMul::pre_validate_and_infer_types()
{
    element::Type input_element_type = get_input_element_type(0);
    PartialShape pshape_A = get_input_partial_shape(0);
    PartialShape pshape_B = get_input_partial_shape(1);

    NODE_VALIDATION_CHECK(this,
                          input_element_type.is_dynamic() || input_element_type.is_real(),
                          "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
                          input_element_type,
                          ").");

    if (pshape_A.is_dynamic() || pshape_B.is_dynamic())
    {
        set_output_type(0, input_element_type, PartialShape::dynamic());
    }
}

NodeVector MatMul::decompose_op() const
{
    auto A = input_value(0);
    auto B = input_value(1);

    decompose_logic(A, m_transpose_a);
    decompose_logic(B, m_transpose_b, true);

    builder::MatmulFactory factory({A, B});
    auto node_vector_matmul = factory.make_matmul_op();
    auto first_item_node_vector = node_vector_matmul[0];
    auto b = first_item_node_vector->get_shape().begin();
    auto e = first_item_node_vector->get_shape().end();
    auto it = find(b, e, 1);

    if (it != e)
    {
        node_vector_matmul = remove_1(first_item_node_vector);
    }

    return node_vector_matmul;
}

shared_ptr<Node> MatMul::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<MatMul>(new_args.at(0), new_args.at(1), m_transpose_a, m_transpose_b);
}

constexpr NodeTypeInfo MatMulGrad::type_info;

MatMulGrad::MatMulGrad(const Output<Node>& A,
                       const Output<Node>& B,
                       const Output<Node>& Out,
                       const bool transpose_a,
                       const bool transpose_b)
    : FusedOp(OutputVector{A, B, Out})
    , m_transpose_a{transpose_a}
    , m_transpose_b{transpose_b}
{
    constructor_validate_and_infer_types();
}

shared_ptr<Node> broadcast_to_3d(const shared_ptr<Node>& input, size_t axis0)
{
    auto shape = input->get_shape();
...
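For context, the deleted implementation above normalized each operand to at least rank 2 (promoting a vector or scalar with op::Reshape), then expressed transposition as an axis permutation that swaps the last two axes before handing the operands to builder::MatmulFactory. Below is a minimal standalone sketch of that permutation step, using only the standard library; the helper name transpose_axes_order is ours, not from the source.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

// Sketch of the axis-order computation in the removed decompose_logic:
// start from the identity permutation [0, 1, ..., rank-1] and swap the
// last two entries; builder::reorder_axes then applies this permutation.
std::vector<std::size_t> transpose_axes_order(std::size_t rank)
{
    // Precondition mirrored from the source: rank was promoted to >= 2.
    std::vector<std::size_t> axes_order(rank);
    std::iota(axes_order.begin(), axes_order.end(), 0);
    std::swap(axes_order[rank - 1], axes_order[rank - 2]);
    return axes_order;
}

int main()
{
    // For rank 3 this prints "0 2 1": the batch axis stays put and the
    // two matrix axes are exchanged.
    for (std::size_t axis : transpose_axes_order(3))
    {
        std::cout << axis << ' ';
    }
    std::cout << '\n';
}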
@@ -249,6 +122,116 @@ shared_ptr<Node> reshape_to_original(shared_ptr<Node> input, const Shape& shape)
    return make_shared<op::Reshape>(input, get_default_order(input_shape), shape);
}

constexpr NodeTypeInfo MatMul::type_info;

MatMul::MatMul(const Output<Node>& A,
               const Output<Node>& B,
               const bool transpose_a,
               const bool transpose_b)
    : FusedOp(OutputVector{A, B})
    , m_transpose_a{transpose_a}
    , m_transpose_b{transpose_b}
{
    constructor_validate_and_infer_types();
}

void MatMul::pre_validate_and_infer_types()
{
    element::Type input_element_type = get_input_element_type(0);
    PartialShape pshape_A = get_input_partial_shape(0);
    PartialShape pshape_B = get_input_partial_shape(1);

    NODE_VALIDATION_CHECK(this,
                          input_element_type.is_dynamic() || input_element_type.is_real(),
                          "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
                          input_element_type,
                          ").");

    if (pshape_A.is_dynamic() || pshape_B.is_dynamic())
    {
        set_output_type(0, input_element_type, PartialShape::dynamic());
    }
}

NodeVector MatMul::decompose_op() const
{
    auto x = input_value(0).get_node_shared_ptr();
    auto y = input_value(1).get_node_shared_ptr();

    auto x_shape = x->get_shape();
    auto y_shape = y->get_shape();

    size_t nx = x_shape.size();
    size_t ny = y_shape.size();

    x = transpose_and_flatten3d(x, m_transpose_a, true);
    y = transpose_and_flatten3d(y, m_transpose_b, false);

    auto y_shape3 = y->get_shape();
    auto x_shape3 = x->get_shape();

    shared_ptr<Node> out;
    Shape out_shape;

    if (nx > 2 || ny > 2)
    {
        Shape out_shape = x_shape;
        if (nx != 3)
        {
            x = broadcast_to_3d(x, y_shape3[0]);
            out_shape = y_shape;
        }
        if (ny != 3)
        {
            y = broadcast_to_3d(y, x_shape3[0]);
            out_shape = x_shape;
        }
        auto nout = out_shape.size();
        auto out3 = make_shared<op::BatchMatMul>(x, y);
        auto out3_shape = out3->get_shape();
        out_shape[nout - 1] = out3_shape[2];
        out_shape[nout - 2] = out3_shape[1];
        out = make_shared<op::Reshape>(out3, AxisVector{0, 1, 2}, out_shape);
    }
    else
    {
        out = make_shared<op::Dot>(x, y);
    }

    out_shape = out->get_shape();
    auto axis_vector = get_default_order(out_shape);
    for (size_t i = out_shape.size() - 1; i > 0; i--)
    {
        if (out_shape[i] == 1)
        {
            out_shape.erase(out_shape.begin() + i);
        }
    }
    auto out_reshaped = make_shared<op::Reshape>(out, axis_vector, out_shape);

    return {out_reshaped};
}

shared_ptr<Node> MatMul::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<MatMul>(new_args.at(0), new_args.at(1), m_transpose_a, m_transpose_b);
}

constexpr NodeTypeInfo MatMulGrad::type_info;

MatMulGrad::MatMulGrad(const Output<Node>& A,
                       const Output<Node>& B,
                       const Output<Node>& Out,
                       const bool transpose_a,
                       const bool transpose_b)
    : FusedOp(OutputVector{A, B, Out})
    , m_transpose_a{transpose_a}
    , m_transpose_b{transpose_b}
{
    constructor_validate_and_infer_types();
}

void MatMulGrad::pre_validate_and_infer_types()
{
    element::Type input_element_type = get_input_element_type(0);
...
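The new forward path replaces the MatmulFactory call with explicit shape handling: when either operand has rank above 2, the rank-2 operand is broadcast to 3-D and the product is computed with op::BatchMatMul (op::Dot covers the purely 2-D case), after which every size-1 dimension except the first is squeezed out of the result. Below is a shape-only sketch of that flow, assuming plain [batch, rows, cols] operands; the helpers to_3d, batch_matmul_shape, and squeeze_ones are ours, not part of the source.

#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<std::size_t>;

// Promote a rank-2 shape to rank 3 by prepending a batch dimension,
// mirroring what broadcast_to_3d does to the corresponding tensor.
Shape to_3d(Shape s, std::size_t batch)
{
    if (s.size() == 2)
    {
        s.insert(s.begin(), batch);
    }
    return s;
}

// Output shape of a batched matmul on [b, m, k] and [b, k, n]: [b, m, n].
Shape batch_matmul_shape(const Shape& x, const Shape& y)
{
    return {x[0], x[1], y[2]};
}

// Mirror of the final loop in decompose_op: erase size-1 dimensions,
// scanning from the back and leaving dimension 0 alone.
Shape squeeze_ones(Shape s)
{
    for (std::size_t i = s.size() - 1; i > 0; i--)
    {
        if (s[i] == 1)
        {
            s.erase(s.begin() + i);
        }
    }
    return s;
}

int main()
{
    Shape x = {8, 4, 16};         // rank-3 operand
    Shape y = to_3d({16, 1}, 8);  // rank-2 operand broadcast to [8, 16, 1]
    Shape out = squeeze_ones(batch_matmul_shape(x, y)); // [8, 4, 1] -> [8, 4]
    for (std::size_t d : out)
    {
        std::cout << d << ' ';
    }
    std::cout << '\n';
}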