Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
bff65fe3
Commit
bff65fe3
authored
6 years ago
by
Nick Korovaiko
Committed by
Robert Kimball
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Any op (#1036)
* add any op
parent
05a4fbef
master
v0.29.0-rc.0
v0.28.0-rc.1
v0.28.0-rc.0
v0.27.1-rc.3
v0.27.1-rc.2
v0.27.1-rc.1
v0.27.1-rc.0
v0.27.0-rc.1
v0.27.0-rc.0
v0.26.1-rc.0
v0.26.0
v0.26.0-rc.8
v0.26.0-rc.7
v0.26.0-rc.6
v0.26.0-rc.5
v0.26.0-rc.4
v0.26.0-rc.3
v0.26.0-rc.2
v0.26.0-rc.0
v0.25.1-rc.11
v0.25.1-rc.10
v0.25.1-rc.9
v0.25.1-rc.8
v0.25.1-rc.7
v0.25.1-rc.6
v0.25.1-rc.5
v0.25.1-rc.4
v0.25.1-rc.3
v0.25.1-rc.2
v0.25.1-rc.1
v0.25.1-rc.0
v0.25.0
v0.25.0-rc.3
v0.25.0-rc.2
v0.25.0-rc.1
v0.25.0-rc.0
v0.25.0-dev.0
v0.24.0
v0.24.0-rc.3
v0.24.0-rc.2
v0.24.0-rc.1
v0.24.0-rc.0
v0.23.0-rc.7
v0.23.0-rc.6
v0.23.0-rc.5
v0.23.0-rc.4
v0.23.0-rc.3
v0.23.0-rc.2
v0.23.0-rc.1
v0.23.0-rc.0
v0.22.2-rc.0
v0.22.1
v0.22.1-rc.0
v0.22.0
v0.22.0-rc.2
v0.22.0-rc.0
v0.21.0
v0.21.0-rc.1
v0.21.0-rc.0
v0.20.1-rc.4
v0.20.1-rc.3
v0.20.1-rc.2
v0.20.1-rc.1
v0.20.1-rc.0
v0.20.0-rc.2
v0.20.0-rc.1
v0.20.0-rc.0
v0.20.0-dev.0
v0.19.1
v0.19.1-rc.0
v0.19.0
v0.19.0-rc.5
v0.19.0-rc.4
v0.19.0-rc.3
v0.19.0-rc.2
v0.19.0-rc.1
v0.19.0-rc.0
v0.18.1
v0.18.1-rc.1
v0.18.1-rc.0
v0.18.0
v0.18.0-rc.2
v0.18.0-rc.1
v0.18.0-rc.0
v0.17.0-rc.1
v0.17.0-rc.0
v0.16.0-rc.3
v0.16.0-rc.2
v0.16.0-rc.1
v0.16.0-rc.0
v0.15.1-rc.2
v0.15.1-rc.1
v0.15.0
v0.15.0-rc.2
v0.15.0-rc.1
v0.15.0-rc.0
v0.14.0
v0.14.0-rc.1
v0.14.0-rc.0
v0.13.0
v0.12.0
v0.12.0-rc.2
v0.12.0-rc.1
v0.12.0-rc.0
v0.11.1
v0.11.0
v0.11.0-rc.1
v0.11.0-rc.0
v0.10.1
v0.10.0
v0.10.0-rc.6
v0.10.0-rc.5
v0.10.0-rc.4
v0.10.0-rc.3
v0.10.0-rc.2
v0.10.0-rc.1
v0.10.0-rc.0
v0.9.1
v0.9.1-rc.0
v0.9.0
v0.9.0-rc.5
v0.9.0-rc.4
v0.9.0-rc.3
v0.9.0-rc.2
v0.9.0-rc.1
v0.9.0-rc.0
v0.8.2-rc.0
v0.8.1
v0.8.1-rc.0
v0.8.0
v0.8.0-rc.2
v0.8.0-rc.1
v0.8.0-rc.0
v0.7.0
v0.6.0
v0.6.0rc0
v0.6.0-rc.0
v0.6.0-rc0
v0.5.0
v0.4.0
No related merge requests found
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
119 additions
and
11 deletions
+119
-11
matcher.cpp
src/ngraph/pattern/matcher.cpp
+32
-7
matcher.hpp
src/ngraph/pattern/matcher.hpp
+16
-0
any.hpp
src/ngraph/pattern/op/any.hpp
+54
-0
pattern.cpp
test/pattern.cpp
+17
-4
No files found.
src/ngraph/pattern/matcher.cpp
View file @
bff65fe3
...
...
@@ -71,19 +71,19 @@ namespace ngraph
return
is_match
;
}
bool
Matcher
::
match_skip
(
const
std
::
shared_ptr
<
op
::
Skip
>&
any
,
bool
Matcher
::
match_skip
(
const
std
::
shared_ptr
<
op
::
Skip
>&
skip
,
const
std
::
shared_ptr
<
Node
>&
graph_node
,
PatternMap
&
pattern_map
)
{
auto
predicate
=
any
->
get_predicate
();
auto
predicate
=
skip
->
get_predicate
();
if
(
!
predicate
||
any
->
get_predicate
()
(
graph_node
))
if
(
!
predicate
||
predicate
(
graph_node
))
{
return
match_arguments
(
any
,
graph_node
,
pattern_map
);
return
match_arguments
(
skip
,
graph_node
,
pattern_map
);
}
else
{
auto
args
=
any
->
get_arguments
();
auto
args
=
skip
->
get_arguments
();
if
(
args
.
size
()
!=
1
)
{
throw
ngraph_error
(
"Skip can only take one argument"
);
...
...
@@ -93,6 +93,26 @@ namespace ngraph
}
}
bool
Matcher
::
match_any
(
const
std
::
shared_ptr
<
op
::
Any
>&
any
,
const
std
::
shared_ptr
<
Node
>&
graph_node
,
PatternMap
&
pattern_map
)
{
auto
predicate
=
any
->
get_predicate
();
if
(
!
predicate
)
{
throw
ngraph_error
(
"predicate is required"
);
}
if
(
predicate
(
graph_node
))
{
return
match_arguments
(
any
,
graph_node
,
pattern_map
);
}
else
{
return
false
;
}
}
bool
Matcher
::
match_node
(
const
std
::
shared_ptr
<
Node
>&
pattern_node
,
const
std
::
shared_ptr
<
Node
>&
graph_node
,
PatternMap
&
pattern_map
)
...
...
@@ -111,10 +131,15 @@ namespace ngraph
return
match_pattern
(
label_node
,
graph_node
,
pattern_map
);
}
if
(
auto
any
_node
=
std
::
dynamic_pointer_cast
<
op
::
Skip
>
(
if
(
auto
skip
_node
=
std
::
dynamic_pointer_cast
<
op
::
Skip
>
(
pattern_node
))
//matches PatternSkipOp semantics
{
return
match_skip
(
any_node
,
graph_node
,
pattern_map
);
return
match_skip
(
skip_node
,
graph_node
,
pattern_map
);
}
if
(
auto
any_node
=
std
::
dynamic_pointer_cast
<
op
::
Any
>
(
pattern_node
))
{
return
match_any
(
any_node
,
graph_node
,
pattern_map
);
}
auto
p_pattern_node
=
pattern_node
.
get
();
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pattern/matcher.hpp
View file @
bff65fe3
...
...
@@ -17,9 +17,12 @@
#pragma once
#include <cassert>
#include <functional>
#include <memory.h>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
...
...
@@ -36,6 +39,16 @@ namespace ngraph
using
recurrent_graph_rewrite_callback
=
std
::
function
<
bool
(
class
RecurrentMatcher
&
m
)
>
;
using
RPatternMap
=
std
::
map
<
std
::
shared_ptr
<
op
::
Label
>
,
NodeVector
>
;
template
<
typename
T
>
std
::
function
<
bool
(
std
::
shared_ptr
<
Node
>
)
>
has_class
()
{
auto
pred
=
[](
std
::
shared_ptr
<
Node
>
node
)
->
bool
{
return
std
::
dynamic_pointer_cast
<
T
>
(
node
)
!=
nullptr
;
};
return
pred
;
}
namespace
op
{
class
Label
;
...
...
@@ -130,6 +143,9 @@ namespace ngraph
bool
match_skip
(
const
std
::
shared_ptr
<
op
::
Skip
>&
pattern_node
,
const
std
::
shared_ptr
<
Node
>&
graph_node
,
PatternMap
&
pattern_map
);
bool
match_any
(
const
std
::
shared_ptr
<
op
::
Any
>&
pattern_node
,
const
std
::
shared_ptr
<
Node
>&
graph_node
,
PatternMap
&
pattern_map
);
graph_rewrite_callback
m_callback
;
size_t
m_depth
;
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/pattern/op/any.hpp
0 → 100644
View file @
bff65fe3
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace
ngraph
{
namespace
pattern
{
namespace
op
{
/// \brief Anys are used in patterns to express arbitrary queries on a node
class
Any
:
public
Pattern
{
public
:
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa shape.
Any
(
const
element
::
Type
&
type
,
const
Shape
s
,
Predicate
pred
,
const
NodeVector
&
wrapped_nodes
)
:
Pattern
(
"Any"
,
wrapped_nodes
,
pred
)
{
if
(
!
pred
)
{
throw
ngraph_error
(
"predicate is required"
);
}
add_output
(
type
,
s
);
}
/// \brief creates a Any node containing a sub-pattern described by the type and shape of \sa node.
Any
(
std
::
shared_ptr
<
Node
>
node
,
Predicate
pred
,
const
NodeVector
&
wrapped_nodes
)
:
Any
(
node
->
get_element_type
(),
node
->
get_shape
(),
pred
,
wrapped_nodes
)
{
}
};
}
}
}
This diff is collapsed.
Click to expand it.
test/pattern.cpp
View file @
bff65fe3
...
...
@@ -402,19 +402,32 @@ TEST(pattern, matcher)
auto
any
=
std
::
make_shared
<
pattern
::
op
::
Skip
>
(
a
);
ASSERT_TRUE
(
n
.
match
(
any
,
abs
));
auto
any_false
=
std
::
make_shared
<
pattern
::
op
::
Skip
>
(
a
,
[](
std
::
shared_ptr
<
Node
>
no
)
{
return
false
;
}
);
auto
false_pred
=
[](
std
::
shared_ptr
<
Node
>
no
)
{
return
false
;
};
auto
any_false
=
std
::
make_shared
<
pattern
::
op
::
Skip
>
(
a
,
false_pred
);
ASSERT_TRUE
(
n
.
match
(
any_false
,
a
));
auto
pattern
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
a
);
ASSERT_TRUE
(
n
.
match
(
pattern
,
a
));
ASSERT_EQ
(
n
.
get_pattern_map
()[
pattern
],
a
);
auto
pattern_false
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
a
,
[](
std
::
shared_ptr
<
Node
>
no
)
{
return
false
;
});
auto
pattern_false
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
a
,
false_pred
);
ASSERT_FALSE
(
n
.
match
(
pattern_false
,
a
));
auto
b
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
shape
);
auto
is_bea
=
pattern
::
has_class
<
op
::
util
::
BinaryElementwiseArithmetic
>
();
auto
bea
=
std
::
make_shared
<
pattern
::
op
::
Any
>
(
a
,
is_bea
,
NodeVector
{
a
,
b
});
ASSERT_TRUE
(
n
.
match
(
bea
,
a
+
b
));
ASSERT_TRUE
(
n
.
match
(
bea
,
b
+
a
));
auto
bea_false
=
std
::
make_shared
<
pattern
::
op
::
Any
>
(
a
,
false_pred
,
NodeVector
{
a
,
b
});
ASSERT_FALSE
(
n
.
match
(
bea_false
,
a
+
b
));
auto
bea_label
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
a
,
nullptr
,
NodeVector
{
bea
});
auto
ab
=
a
+
b
;
ASSERT_TRUE
(
n
.
match
(
bea_label
,
ab
));
ASSERT_EQ
(
n
.
get_pattern_map
()[
bea_label
],
ab
);
auto
d
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
shape
);
ASSERT_FALSE
(
n
.
match
(
d
,
b
));
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment