Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
e3b442aa
Commit
e3b442aa
authored
Dec 17, 2019
by
baojun
Committed by
Scott Cyphers
Dec 17, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
set output type for dynshape (#4072)
parent
72493caf
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
13 deletions
+30
-13
matmul.cpp
src/ngraph/frontend/fluid/operators/matmul.cpp
+23
-9
matmul.hpp
src/ngraph/frontend/fluid/operators/matmul.hpp
+7
-4
No files found.
src/ngraph/frontend/fluid/operators/matmul.cpp
View file @
e3b442aa
...
@@ -13,18 +13,19 @@
...
@@ -13,18 +13,19 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
//*****************************************************************************
//*****************************************************************************
#include "ngraph/frontend/fluid/operators/matmul.hpp"
#include <memory>
#include <memory>
#include <numeric>
#include <numeric>
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/frontend/fluid/operators/matmul.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reshape.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
::
fluid
;
constexpr
NodeTypeInfo
fluid
::
MatMul
::
type_info
;
constexpr
NodeTypeInfo
MatMul
::
type_info
;
fluid
::
MatMul
::
MatMul
(
const
Output
<
Node
>&
A
,
MatMul
::
MatMul
(
const
Output
<
Node
>&
A
,
const
Output
<
Node
>&
B
,
const
Output
<
Node
>&
B
,
const
bool
&
transpose_a
,
const
bool
&
transpose_a
,
const
bool
&
transpose_b
)
const
bool
&
transpose_b
)
...
@@ -39,6 +40,7 @@ template <class Input>
...
@@ -39,6 +40,7 @@ template <class Input>
void
DecomposeLogic
(
Input
&
input
,
bool
transpose
,
bool
reverse
=
false
)
void
DecomposeLogic
(
Input
&
input
,
bool
transpose
,
bool
reverse
=
false
)
{
{
auto
rank
=
input
.
get_shape
().
size
();
auto
rank
=
input
.
get_shape
().
size
();
if
(
rank
<
2
)
if
(
rank
<
2
)
{
{
if
(
rank
)
if
(
rank
)
...
@@ -60,6 +62,7 @@ void DecomposeLogic(Input& input, bool transpose, bool reverse = false)
...
@@ -60,6 +62,7 @@ void DecomposeLogic(Input& input, bool transpose, bool reverse = false)
}
}
rank
=
2
;
rank
=
2
;
}
}
if
(
transpose
)
if
(
transpose
)
{
{
vector
<
size_t
>
axes_order
(
rank
);
vector
<
size_t
>
axes_order
(
rank
);
...
@@ -75,48 +78,59 @@ inline NodeVector remove_1(std::shared_ptr<ngraph::Node> input_node)
...
@@ -75,48 +78,59 @@ inline NodeVector remove_1(std::shared_ptr<ngraph::Node> input_node)
AxisVector
axis
(
input_shape
.
size
());
AxisVector
axis
(
input_shape
.
size
());
iota
(
axis
.
begin
(),
axis
.
end
(),
0
);
iota
(
axis
.
begin
(),
axis
.
end
(),
0
);
Shape
shape
(
input_shape
.
begin
(),
input_shape
.
end
());
Shape
shape
(
input_shape
.
begin
(),
input_shape
.
end
());
auto
b_remove
=
std
::
remove
(
shape
.
begin
(),
shape
.
end
(),
1
);
auto
b_remove
=
std
::
remove
(
shape
.
begin
(),
shape
.
end
(),
1
);
shape
.
erase
(
b_remove
,
shape
.
end
());
shape
.
erase
(
b_remove
,
shape
.
end
());
Output
<
Node
>
node
(
input_node
);
Output
<
Node
>
node
(
input_node
);
auto
reshape
=
make_shared
<
op
::
Reshape
>
(
node
,
axis
,
shape
);
auto
reshape
=
make_shared
<
op
::
Reshape
>
(
node
,
axis
,
shape
);
NodeVector
final_vector
{
reshape
};
NodeVector
final_vector
{
reshape
};
return
final_vector
;
return
final_vector
;
}
}
void
fluid
::
MatMul
::
pre_validate_and_infer_types
()
void
MatMul
::
pre_validate_and_infer_types
()
{
{
element
::
Type
input_element_type
=
get_input_element_type
(
0
);
element
::
Type
input_element_type
=
get_input_element_type
(
0
);
PartialShape
pshape_A
=
get_input_partial_shape
(
0
);
PartialShape
pshape_B
=
get_input_partial_shape
(
1
);
NODE_VALIDATION_CHECK
(
this
,
NODE_VALIDATION_CHECK
(
this
,
input_element_type
.
is_dynamic
()
||
input_element_type
.
is_real
(),
input_element_type
.
is_dynamic
()
||
input_element_type
.
is_real
(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got "
,
"Argument element type must be f16, bf16, f32, f64 or dynamic (got "
,
input_element_type
,
input_element_type
,
")."
);
")."
);
if
(
is_dynamic
())
if
(
pshape_A
.
is_dynamic
()
||
pshape_B
.
is_dynamic
())
{
{
set_output_type
(
0
,
get_input_element_type
(
0
)
,
PartialShape
::
dynamic
());
set_output_type
(
0
,
input_element_type
,
PartialShape
::
dynamic
());
}
}
}
}
NodeVector
fluid
::
MatMul
::
decompose_op
()
const
NodeVector
MatMul
::
decompose_op
()
const
{
{
auto
A
=
input_value
(
0
);
auto
A
=
input_value
(
0
);
auto
B
=
input_value
(
1
);
auto
B
=
input_value
(
1
);
DecomposeLogic
(
A
,
m_transpose_a
);
DecomposeLogic
(
A
,
m_transpose_a
);
DecomposeLogic
(
B
,
m_transpose_b
,
true
);
DecomposeLogic
(
B
,
m_transpose_b
,
true
);
builder
::
MatmulFactory
factory
({
A
,
B
});
builder
::
MatmulFactory
factory
({
A
,
B
});
auto
node_vector_matmul
=
factory
.
make_matmul_op
();
auto
node_vector_matmul
=
factory
.
make_matmul_op
();
auto
first_item_node_vector
=
node_vector_matmul
[
0
];
auto
first_item_node_vector
=
node_vector_matmul
[
0
];
auto
b
=
first_item_node_vector
->
get_shape
().
begin
();
auto
b
=
first_item_node_vector
->
get_shape
().
begin
();
auto
e
=
first_item_node_vector
->
get_shape
().
end
();
auto
e
=
first_item_node_vector
->
get_shape
().
end
();
auto
it
=
std
::
find
(
b
,
e
,
1
);
auto
it
=
std
::
find
(
b
,
e
,
1
);
if
(
it
!=
e
)
if
(
it
!=
e
)
{
{
node_vector_matmul
=
remove_1
(
first_item_node_vector
);
node_vector_matmul
=
remove_1
(
first_item_node_vector
);
}
}
return
node_vector_matmul
;
return
node_vector_matmul
;
}
}
shared_ptr
<
Node
>
fluid
::
MatMul
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
MatMul
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
MatMul
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_transpose_a
,
m_transpose_b
);
return
make_shared
<
MatMul
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_transpose_a
,
m_transpose_b
);
...
...
src/ngraph/frontend/fluid/operators/matmul.hpp
View file @
e3b442aa
...
@@ -20,18 +20,21 @@
...
@@ -20,18 +20,21 @@
#include "ngraph/op/op.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/fused_op.hpp"
using
namespace
std
;
using
namespace
ngraph
;
namespace
ngraph
namespace
ngraph
{
{
namespace
fluid
namespace
fluid
{
{
/// \brief Operator performing Matrix Multiplication.
/// \brief Operator performing Matrix Multiplication.
class
NGRAPH_API
MatMul
:
public
ngraph
::
op
::
util
::
FusedOp
class
NGRAPH_API
MatMul
:
public
op
::
util
::
FusedOp
{
{
public
:
public
:
static
constexpr
NodeTypeInfo
type_info
{
"MatMul"
,
0
};
static
constexpr
NodeTypeInfo
type_info
{
"MatMul"
,
0
};
const
NodeTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
const
NodeTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
MatMul
()
=
default
;
MatMul
()
=
default
;
/// \brief Constructs a
n ScaleShift
operation.
/// \brief Constructs a
MatMul
operation.
///
///
/// \param A Matrix A
/// \param A Matrix A
/// \param B Matrix B
/// \param B Matrix B
...
@@ -43,10 +46,10 @@ namespace ngraph
...
@@ -43,10 +46,10 @@ namespace ngraph
const
bool
&
transpose_b
=
0
);
const
bool
&
transpose_b
=
0
);
virtual
NodeVector
decompose_op
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
void
pre_validate_and_infer_types
()
override
;
void
pre_validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
bool
get_transpose_a
()
const
{
return
m_transpose_a
;
}
bool
get_transpose_a
()
const
{
return
m_transpose_a
;
}
bool
get_transpose_b
()
const
{
return
m_transpose_b
;
}
bool
get_transpose_b
()
const
{
return
m_transpose_b
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment