Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
ca955d46
Unverified
Commit
ca955d46
authored
Jan 28, 2020
by
Scott Cyphers
Committed by
GitHub
Jan 28, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
A couple new attribute visitors (#4233)
parent
46c21a0d
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
91 additions
and
31 deletions
+91
-31
atan2.cpp
src/ngraph/op/atan2.cpp
+9
-3
atan2.hpp
src/ngraph/op/atan2.hpp
+26
-21
binary_convolution.cpp
src/ngraph/op/binary_convolution.cpp
+35
-7
binary_convolution.hpp
src/ngraph/op/binary_convolution.hpp
+21
-0
No files found.
src/ngraph/op/atan2.cpp
View file @
ca955d46
...
@@ -27,19 +27,19 @@ using namespace ngraph;
...
@@ -27,19 +27,19 @@ using namespace ngraph;
constexpr
NodeTypeInfo
op
::
Atan2
::
type_info
;
constexpr
NodeTypeInfo
op
::
Atan2
::
type_info
;
op
::
Atan2
::
Atan2
(
const
Output
<
Node
>&
y
,
const
Output
<
Node
>&
x
,
const
AutoBroadcastSpec
&
autob
)
op
::
v0
::
Atan2
::
Atan2
(
const
Output
<
Node
>&
y
,
const
Output
<
Node
>&
x
,
const
AutoBroadcastSpec
&
autob
)
:
BinaryElementwiseArithmetic
(
y
,
x
,
autob
)
:
BinaryElementwiseArithmetic
(
y
,
x
,
autob
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
shared_ptr
<
Node
>
op
::
Atan2
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
v0
::
Atan2
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
Atan2
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
this
->
get_autob
());
return
make_shared
<
Atan2
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
this
->
get_autob
());
}
}
void
op
::
Atan2
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
OutputVector
&
deltas
)
void
op
::
v0
::
Atan2
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
OutputVector
&
deltas
)
{
{
if
(
get_autob
().
m_type
!=
op
::
AutoBroadcastType
::
NONE
)
if
(
get_autob
().
m_type
!=
op
::
AutoBroadcastType
::
NONE
)
{
{
...
@@ -51,3 +51,9 @@ void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVect
...
@@ -51,3 +51,9 @@ void op::Atan2::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVect
adjoints
.
add_delta
(
y
,
x
*
delta_over_r
);
adjoints
.
add_delta
(
y
,
x
*
delta_over_r
);
adjoints
.
add_delta
(
x
,
-
y
*
delta_over_r
);
adjoints
.
add_delta
(
x
,
-
y
*
delta_over_r
);
}
}
bool
op
::
v0
::
Atan2
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
BinaryElementwiseArithmetic
::
visit_attributes
(
visitor
);
return
true
;
}
src/ngraph/op/atan2.hpp
View file @
ca955d46
...
@@ -24,30 +24,35 @@ namespace ngraph
...
@@ -24,30 +24,35 @@ namespace ngraph
{
{
namespace
op
namespace
op
{
{
/// \brief Elementwise full arctan operation
namespace
v0
class
NGRAPH_API
Atan2
:
public
util
::
BinaryElementwiseArithmetic
{
{
public
:
/// \brief Elementwise full arctan operation
static
constexpr
NodeTypeInfo
type_info
{
"Atan2"
,
0
};
class
NGRAPH_API
Atan2
:
public
util
::
BinaryElementwiseArithmetic
const
NodeTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
Atan2
()
:
util
::
BinaryElementwiseArithmetic
(
AutoBroadcastSpec
::
NONE
)
{
{
}
public
:
static
constexpr
NodeTypeInfo
type_info
{
"Atan2"
,
0
};
const
NodeTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
Atan2
()
:
util
::
BinaryElementwiseArithmetic
(
AutoBroadcastSpec
::
NONE
)
{
}
/// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed
/// \brief atan2(y,x) is the angle from the origin to the point (x,y) (note reversed
/// order).
/// order).
///
///
/// \param y
/// \param y
/// \param x
/// \param x
Atan2
(
const
Output
<
Node
>&
y
,
Atan2
(
const
Output
<
Node
>&
y
,
const
Output
<
Node
>&
x
,
const
Output
<
Node
>&
x
,
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
const
AutoBroadcastSpec
&
autob
=
AutoBroadcastSpec
());
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
OutputVector
&
deltas
)
override
;
const
OutputVector
&
deltas
)
override
;
};
};
}
using
v0
::
Atan2
;
}
}
}
}
src/ngraph/op/binary_convolution.cpp
View file @
ca955d46
...
@@ -142,20 +142,48 @@ shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector&
...
@@ -142,20 +142,48 @@ shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector&
m_auto_pad
);
m_auto_pad
);
}
}
bool
op
::
v1
::
BinaryConvolution
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"strides"
,
m_strides
);
visitor
.
on_attribute
(
"pads_begin"
,
m_pads_begin
);
visitor
.
on_attribute
(
"pads_end"
,
m_pads_end
);
visitor
.
on_attribute
(
"dilations"
,
m_dilations
);
visitor
.
on_attribute
(
"mode"
,
m_mode
);
visitor
.
on_attribute
(
"pad_value"
,
m_pad_value
);
visitor
.
on_attribute
(
"auto_pad"
,
m_auto_pad
);
return
true
;
}
void
op
::
v1
::
BinaryConvolution
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
void
op
::
v1
::
BinaryConvolution
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
OutputVector
&
deltas
)
const
OutputVector
&
deltas
)
{
{
throw
ngraph_error
(
"BinaryConvolution generate_adjoints not implemented"
);
throw
ngraph_error
(
"BinaryConvolution generate_adjoints not implemented"
);
}
}
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
namespace
ngraph
op
::
v1
::
BinaryConvolution
::
mode_from_string
(
const
std
::
string
&
mode
)
const
{
{
static
const
std
::
map
<
std
::
string
,
BinaryConvolutionMode
>
allowed_values
=
{
template
<>
{
"xnor-popcount"
,
BinaryConvolutionMode
::
XNOR_POPCOUNT
}};
EnumNames
<
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
>&
EnumNames
<
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
>::
get
()
{
static
auto
enum_names
=
EnumNames
<
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
>
(
"op::v1::BinaryConvolution::BinaryConvolutionMode"
,
{{
"xnor-popcount"
,
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
::
XNOR_POPCOUNT
}});
return
enum_names
;
}
NODE_VALIDATION_CHECK
(
constexpr
DiscreteTypeInfo
this
,
allowed_values
.
count
(
mode
)
>
0
,
"Invalid binary convolution mode value passed in."
)
;
AttributeAdapter
<
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
>::
type_info
;
return
allowed_values
.
at
(
mode
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
&
type
)
{
return
s
<<
as_string
(
type
);
}
}
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
op
::
v1
::
BinaryConvolution
::
mode_from_string
(
const
std
::
string
&
mode
)
const
{
return
as_enum
<
BinaryConvolutionMode
>
(
mode
);
}
}
src/ngraph/op/binary_convolution.hpp
View file @
ca955d46
...
@@ -74,6 +74,8 @@ namespace ngraph
...
@@ -74,6 +74,8 @@ namespace ngraph
size_t
get_version
()
const
override
{
return
1
;
}
size_t
get_version
()
const
override
{
return
1
;
}
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
...
@@ -112,4 +114,23 @@ namespace ngraph
...
@@ -112,4 +114,23 @@ namespace ngraph
};
};
}
}
}
// namespace op
}
// namespace op
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
&
type
);
template
<>
class
NGRAPH_API
AttributeAdapter
<
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
>
:
public
EnumAttributeAdapterBase
<
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
>
{
public
:
AttributeAdapter
(
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
&
value
)
:
EnumAttributeAdapterBase
<
op
::
v1
::
BinaryConvolution
::
BinaryConvolutionMode
>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<op::v1::BinaryConvolution::BinaryConvolutionMode>"
,
0
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
};
}
// namespace ngraph
}
// namespace ngraph
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment