Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
c386da90
Unverified
Commit
c386da90
authored
Aug 31, 2018
by
Adam Procter
Committed by
GitHub
Aug 31, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Better node validation error messages (#1533)
parent
132b5305
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
77 changed files
with
350 additions
and
582 deletions
+350
-582
node.cpp
src/ngraph/node.cpp
+21
-14
node.hpp
src/ngraph/node.hpp
+15
-8
abs.cpp
src/ngraph/op/abs.cpp
+1
-4
acos.cpp
src/ngraph/op/acos.cpp
+1
-4
add.cpp
src/ngraph/op/add.cpp
+1
-4
allreduce.cpp
src/ngraph/op/allreduce.cpp
+7
-9
and.cpp
src/ngraph/op/and.cpp
+1
-4
argmax.cpp
src/ngraph/op/argmax.cpp
+1
-4
argmin.cpp
src/ngraph/op/argmin.cpp
+1
-4
asin.cpp
src/ngraph/op/asin.cpp
+1
-4
atan.cpp
src/ngraph/op/atan.cpp
+1
-4
avg_pool.cpp
src/ngraph/op/avg_pool.cpp
+0
-0
broadcast.cpp
src/ngraph/op/broadcast.cpp
+19
-12
ceiling.cpp
src/ngraph/op/ceiling.cpp
+1
-4
concat.cpp
src/ngraph/op/concat.cpp
+33
-31
constant.cpp
src/ngraph/op/constant.cpp
+1
-4
constant.hpp
src/ngraph/op/constant.hpp
+12
-9
convert.cpp
src/ngraph/op/convert.cpp
+1
-4
convolution.cpp
src/ngraph/op/convolution.cpp
+3
-12
cos.cpp
src/ngraph/op/cos.cpp
+1
-4
cosh.cpp
src/ngraph/op/cosh.cpp
+1
-4
divide.cpp
src/ngraph/op/divide.cpp
+1
-4
dot.cpp
src/ngraph/op/dot.cpp
+20
-17
dot.hpp
src/ngraph/op/dot.hpp
+1
-4
equal.cpp
src/ngraph/op/equal.cpp
+1
-4
exp.cpp
src/ngraph/op/exp.cpp
+1
-4
floor.cpp
src/ngraph/op/floor.cpp
+1
-4
function_call.cpp
src/ngraph/op/function_call.cpp
+14
-9
get_output_element.cpp
src/ngraph/op/get_output_element.cpp
+4
-8
greater.cpp
src/ngraph/op/greater.cpp
+1
-4
greater_eq.cpp
src/ngraph/op/greater_eq.cpp
+1
-4
less.cpp
src/ngraph/op/less.cpp
+1
-4
less_eq.cpp
src/ngraph/op/less_eq.cpp
+1
-4
log.cpp
src/ngraph/op/log.cpp
+1
-4
lrn.cpp
src/ngraph/op/lrn.cpp
+3
-8
max.cpp
src/ngraph/op/max.cpp
+1
-4
max_pool.cpp
src/ngraph/op/max_pool.cpp
+3
-11
maximum.cpp
src/ngraph/op/maximum.cpp
+1
-4
min.cpp
src/ngraph/op/min.cpp
+1
-4
minimum.cpp
src/ngraph/op/minimum.cpp
+1
-4
multiply.cpp
src/ngraph/op/multiply.cpp
+1
-4
negative.cpp
src/ngraph/op/negative.cpp
+1
-4
not.cpp
src/ngraph/op/not.cpp
+1
-4
not_equal.cpp
src/ngraph/op/not_equal.cpp
+1
-4
one_hot.cpp
src/ngraph/op/one_hot.cpp
+7
-12
or.cpp
src/ngraph/op/or.cpp
+1
-4
pad.cpp
src/ngraph/op/pad.cpp
+16
-24
parameter.cpp
src/ngraph/op/parameter.cpp
+1
-4
power.cpp
src/ngraph/op/power.cpp
+1
-4
product.cpp
src/ngraph/op/product.cpp
+1
-4
reduce.cpp
src/ngraph/op/reduce.cpp
+1
-4
reduce_window.cpp
src/ngraph/op/reduce_window.cpp
+1
-4
relu.cpp
src/ngraph/op/relu.cpp
+4
-19
relu.hpp
src/ngraph/op/relu.hpp
+2
-2
remainder.cpp
src/ngraph/op/remainder.cpp
+1
-4
replace_slice.cpp
src/ngraph/op/replace_slice.cpp
+30
-43
reshape.cpp
src/ngraph/op/reshape.cpp
+10
-19
result.cpp
src/ngraph/op/result.cpp
+3
-13
reverse.cpp
src/ngraph/op/reverse.cpp
+3
-10
reverse_sequence.cpp
src/ngraph/op/reverse_sequence.cpp
+15
-21
select.cpp
src/ngraph/op/select.cpp
+15
-16
select_and_scatter.cpp
src/ngraph/op/select_and_scatter.cpp
+1
-4
sigmoid.cpp
src/ngraph/op/sigmoid.cpp
+10
-17
sign.cpp
src/ngraph/op/sign.cpp
+1
-4
sin.cpp
src/ngraph/op/sin.cpp
+1
-4
sinh.cpp
src/ngraph/op/sinh.cpp
+1
-4
slice.cpp
src/ngraph/op/slice.cpp
+21
-31
softmax.cpp
src/ngraph/op/softmax.cpp
+4
-8
sqrt.cpp
src/ngraph/op/sqrt.cpp
+1
-4
stop_gradient.cpp
src/ngraph/op/stop_gradient.cpp
+1
-4
subtract.cpp
src/ngraph/op/subtract.cpp
+1
-4
sum.cpp
src/ngraph/op/sum.cpp
+1
-4
tan.cpp
src/ngraph/op/tan.cpp
+1
-4
tanh.cpp
src/ngraph/op/tanh.cpp
+1
-4
arithmetic_reduction.cpp
src/ngraph/op/util/arithmetic_reduction.cpp
+4
-4
index_reduction.cpp
src/ngraph/op/util/index_reduction.cpp
+4
-3
type_prop.cpp
test/type_prop.cpp
+0
-0
No files found.
src/ngraph/node.cpp
View file @
c386da90
...
@@ -344,6 +344,20 @@ NodeVector Node::get_users() const
...
@@ -344,6 +344,20 @@ NodeVector Node::get_users() const
return
result
;
return
result
;
}
}
std
::
string
ngraph
::
node_validation_assertion_string
(
const
Node
*
node
)
{
std
::
stringstream
ss
;
ss
<<
"While validating node '"
<<
*
node
<<
"' of type '"
<<
node
->
description
()
<<
"'"
;
return
ss
.
str
();
}
void
ngraph
::
check_new_args_count
(
const
Node
*
node
,
const
NodeVector
&
new_args
)
{
NODE_VALIDATION_ASSERT
(
node
,
new_args
.
size
()
==
node
->
get_arguments
().
size
())
<<
"copy_with_new_args() expected "
<<
node
->
get_arguments
().
size
()
<<
" argument"
<<
(
node
->
get_arguments
().
size
()
==
1
?
""
:
"s"
)
<<
" but got "
<<
new_args
.
size
();
}
const
std
::
shared_ptr
<
Node
>&
ngraph
::
check_single_output_arg
(
const
std
::
shared_ptr
<
Node
>&
node
,
const
std
::
shared_ptr
<
Node
>&
ngraph
::
check_single_output_arg
(
const
std
::
shared_ptr
<
Node
>&
node
,
size_t
i
)
size_t
i
)
{
{
...
@@ -361,13 +375,6 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
...
@@ -361,13 +375,6 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
return
args
;
return
args
;
}
}
std
::
string
ngraph
::
type_check_assert_string
(
const
Node
*
node
)
{
std
::
stringstream
ss
;
ss
<<
"While type-checking node "
<<
*
node
;
return
ss
.
str
();
}
void
Node
::
validate_and_infer_elementwise
(
element
::
Type
result_type
)
void
Node
::
validate_and_infer_elementwise
(
element
::
Type
result_type
)
{
{
const
element
::
Type
&
element_type
=
get_input_element_type
(
0
);
const
element
::
Type
&
element_type
=
get_input_element_type
(
0
);
...
@@ -376,12 +383,12 @@ void Node::validate_and_infer_elementwise(element::Type result_type)
...
@@ -376,12 +383,12 @@ void Node::validate_and_infer_elementwise(element::Type result_type)
{
{
for
(
size_t
i
=
1
;
i
<
get_input_size
();
++
i
)
for
(
size_t
i
=
1
;
i
<
get_input_size
();
++
i
)
{
{
TYPE_CHECK
_ASSERT
(
this
,
get_input_element_type
(
i
)
==
element_type
)
NODE_VALIDATION
_ASSERT
(
this
,
get_input_element_type
(
i
)
==
element_type
)
<<
"Argument 0 element type "
<<
element_type
<<
"Argument 0 element type "
<<
element_type
<<
" differs in element type from argument "
<<
i
<<
" "
<<
*
get_argument
(
i
)
<<
" differs in element type from argument "
<<
i
<<
" "
<<
*
get_argument
(
i
)
<<
" element type "
<<
get_input_element_type
(
i
);
<<
" element type "
<<
get_input_element_type
(
i
);
TYPE_CHECK
_ASSERT
(
this
,
get_input_shape
(
i
)
==
shape
)
NODE_VALIDATION
_ASSERT
(
this
,
get_input_shape
(
i
)
==
shape
)
<<
"Argument 0 shape "
<<
shape
<<
" differs in shape from argument "
<<
i
<<
" "
<<
"Argument 0 shape "
<<
shape
<<
" differs in shape from argument "
<<
i
<<
" "
<<
*
get_argument
(
i
)
<<
" shape "
<<
get_input_shape
(
i
);
<<
*
get_argument
(
i
)
<<
" shape "
<<
get_input_shape
(
i
);
}
}
...
@@ -391,16 +398,16 @@ void Node::validate_and_infer_elementwise(element::Type result_type)
...
@@ -391,16 +398,16 @@ void Node::validate_and_infer_elementwise(element::Type result_type)
void
Node
::
validate_and_infer_elementwise_arithmetic
()
void
Node
::
validate_and_infer_elementwise_arithmetic
()
{
{
TYPE_CHECK
_ASSERT
(
this
,
get_input_element_type
(
0
)
!=
element
::
boolean
)
NODE_VALIDATION
_ASSERT
(
this
,
get_input_element_type
(
0
)
!=
element
::
boolean
)
<<
"
Operands for arithmetic operators must have numeric element type but have element type
"
<<
"
Arguments cannot have boolean element type (argument element type:
"
<<
get_input_element_type
(
0
);
<<
get_input_element_type
(
0
)
<<
")."
;
validate_and_infer_elementwise
(
get_input_element_type
(
0
));
validate_and_infer_elementwise
(
get_input_element_type
(
0
));
}
}
void
Node
::
validate_and_infer_elementwise_logical
()
void
Node
::
validate_and_infer_elementwise_logical
()
{
{
TYPE_CHECK
_ASSERT
(
this
,
get_input_element_type
(
0
)
==
element
::
boolean
)
NODE_VALIDATION
_ASSERT
(
this
,
get_input_element_type
(
0
)
==
element
::
boolean
)
<<
"Operands for logical operators must have boolean element type but have element type "
<<
"Operands for logical operators must have boolean element type but have element type "
<<
get_input_element_type
(
0
);
<<
get_input_element_type
(
0
)
<<
"."
;
validate_and_infer_elementwise
(
get_input_element_type
(
0
));
validate_and_infer_elementwise
(
get_input_element_type
(
0
));
}
}
src/ngraph/node.hpp
View file @
c386da90
...
@@ -58,7 +58,11 @@ namespace ngraph
...
@@ -58,7 +58,11 @@ namespace ngraph
const
std
::
shared_ptr
<
Node
>&
dst_node
,
const
std
::
shared_ptr
<
Node
>&
dst_node
,
const
std
::
shared_ptr
<
Node
>&
new_node
);
const
std
::
shared_ptr
<
Node
>&
new_node
);
std
::
string
type_check_assert_string
(
const
Node
*
node
);
std
::
string
node_validation_assertion_string
(
const
Node
*
node
);
const
std
::
shared_ptr
<
Node
>&
check_single_output_arg
(
const
std
::
shared_ptr
<
Node
>&
node
,
size_t
i
);
const
NodeVector
&
check_single_output_args
(
const
NodeVector
&
args
);
const
std
::
shared_ptr
<
Node
>&
check_single_output_arg
(
const
std
::
shared_ptr
<
Node
>&
node
,
const
std
::
shared_ptr
<
Node
>&
check_single_output_arg
(
const
std
::
shared_ptr
<
Node
>&
node
,
size_t
i
);
size_t
i
);
...
@@ -223,22 +227,25 @@ namespace ngraph
...
@@ -223,22 +227,25 @@ namespace ngraph
Placement
m_placement
=
Placement
::
DEFAULT
;
Placement
m_placement
=
Placement
::
DEFAULT
;
};
};
class
TypeCheck
Error
:
public
AssertionFailure
class
NodeValidation
Error
:
public
AssertionFailure
{
{
public
:
public
:
TypeCheck
Error
(
std
::
string
what
)
NodeValidation
Error
(
std
::
string
what
)
:
AssertionFailure
(
what
)
:
AssertionFailure
(
what
)
{
{
}
}
TypeCheck
Error
(
const
char
*
what
)
NodeValidation
Error
(
const
char
*
what
)
:
AssertionFailure
(
what
)
:
AssertionFailure
(
what
)
{
{
}
}
};
};
void
check_new_args_count
(
const
Node
*
node
,
const
NodeVector
&
new_args
);
}
}
#define
TYPE_CHECK_ASSERT(node, cond)
\
#define
NODE_VALIDATION_ASSERT(node, cond)
\
NGRAPH_ASSERT_STREAM_WITH_LOC( \
NGRAPH_ASSERT_STREAM_WITH_LOC( \
::ngraph::TypeCheckError, cond, ::ngraph::type_check_assert_string(node))
::ngraph::NodeValidationError, cond, ::ngraph::node_validation_assertion_string(node))
#define TYPE_CHECK_FAIL(node) \
#define NODE_VALIDATION_FAIL(node) \
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::TypeCheckError, ::ngraph::type_check_assert_string(node))
NGRAPH_FAIL_STREAM_WITH_LOC(::ngraph::NodeValidationError, \
::ngraph::node_validation_assertion_string(node))
src/ngraph/op/abs.cpp
View file @
c386da90
...
@@ -29,10 +29,7 @@ op::Abs::Abs(const shared_ptr<Node>& arg)
...
@@ -29,10 +29,7 @@ op::Abs::Abs(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Abs
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Abs
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Abs
>
(
new_args
.
at
(
0
));
return
make_shared
<
Abs
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/acos.cpp
View file @
c386da90
...
@@ -40,10 +40,7 @@ op::Acos::Acos(const shared_ptr<Node>& arg)
...
@@ -40,10 +40,7 @@ op::Acos::Acos(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Acos
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Acos
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Acos
>
(
new_args
.
at
(
0
));
return
make_shared
<
Acos
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/add.cpp
View file @
c386da90
...
@@ -27,10 +27,7 @@ op::Add::Add(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -27,10 +27,7 @@ op::Add::Add(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
Add
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Add
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Add
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Add
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
...
...
src/ngraph/op/allreduce.cpp
View file @
c386da90
...
@@ -27,19 +27,17 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
...
@@ -27,19 +27,17 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
void
op
::
AllReduce
::
validate_and_infer_types
()
void
op
::
AllReduce
::
validate_and_infer_types
()
{
{
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_shape
(
0
));
NODE_VALIDATION_ASSERT
(
this
,
get_input_element_type
(
0
)
==
element
::
f32
||
get_input_element_type
(
0
)
==
element
::
f64
)
<<
"Only element types f32 and f64 are supported (argument element type: "
<<
get_input_element_type
(
0
)
<<
")."
;
if
((
get_input_element_type
(
0
)
!=
element
::
f32
)
&&
(
get_input_element_type
(
0
)
!=
element
::
f64
))
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_shape
(
0
));
{
throw
ngraph_error
(
"Unsupported data type for AllReduce"
);
}
}
}
shared_ptr
<
Node
>
op
::
AllReduce
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
AllReduce
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
AllReduce
>
(
new_args
.
at
(
0
));
return
make_shared
<
AllReduce
>
(
new_args
.
at
(
0
));
}
}
src/ngraph/op/and.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::And::And(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -27,9 +27,6 @@ op::And::And(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
And
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
And
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
And
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
And
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
src/ngraph/op/argmax.cpp
View file @
c386da90
...
@@ -21,9 +21,6 @@ using namespace ngraph;
...
@@ -21,9 +21,6 @@ using namespace ngraph;
shared_ptr
<
Node
>
op
::
ArgMax
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
ArgMax
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
ArgMax
>
(
new_args
.
at
(
0
),
m_axis
,
this
->
get_element_type
());
return
make_shared
<
ArgMax
>
(
new_args
.
at
(
0
),
m_axis
,
this
->
get_element_type
());
}
}
src/ngraph/op/argmin.cpp
View file @
c386da90
...
@@ -21,9 +21,6 @@ using namespace ngraph;
...
@@ -21,9 +21,6 @@ using namespace ngraph;
shared_ptr
<
Node
>
op
::
ArgMin
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
ArgMin
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
ArgMin
>
(
new_args
.
at
(
0
),
m_axis
,
this
->
get_element_type
());
return
make_shared
<
ArgMin
>
(
new_args
.
at
(
0
),
m_axis
,
this
->
get_element_type
());
}
}
src/ngraph/op/asin.cpp
View file @
c386da90
...
@@ -39,10 +39,7 @@ op::Asin::Asin(const shared_ptr<Node>& arg)
...
@@ -39,10 +39,7 @@ op::Asin::Asin(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Asin
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Asin
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Asin
>
(
new_args
.
at
(
0
));
return
make_shared
<
Asin
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/atan.cpp
View file @
c386da90
...
@@ -38,10 +38,7 @@ op::Atan::Atan(const shared_ptr<Node>& arg)
...
@@ -38,10 +38,7 @@ op::Atan::Atan(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Atan
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Atan
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Atan
>
(
new_args
.
at
(
0
));
return
make_shared
<
Atan
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/avg_pool.cpp
View file @
c386da90
This diff is collapsed.
Click to expand it.
src/ngraph/op/broadcast.cpp
View file @
c386da90
...
@@ -44,25 +44,32 @@ void op::Broadcast::validate_and_infer_types()
...
@@ -44,25 +44,32 @@ void op::Broadcast::validate_and_infer_types()
Shape
target_shape
=
m_shape
;
Shape
target_shape
=
m_shape
;
for
(
auto
i
=
m_broadcast_axes
.
rbegin
();
i
!=
m_broadcast_axes
.
rend
();
++
i
)
for
(
auto
i
=
m_broadcast_axes
.
rbegin
();
i
!=
m_broadcast_axes
.
rend
();
++
i
)
{
{
if
(
*
i
>=
target_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
*
i
<
target_shape
.
size
())
{
<<
"Broadcast axis index ("
<<
*
i
<<
") exceeds target shape rank "
throw
ngraph_error
(
"Broadcast axis exceeds target shape rank"
);
<<
"(broadcast axes: "
<<
m_broadcast_axes
<<
", target shape: "
<<
target_shape
}
<<
")."
;
target_shape
.
erase
(
target_shape
.
begin
()
+
*
i
);
target_shape
.
erase
(
target_shape
.
begin
()
+
*
i
);
}
}
if
(
Shape
{
target_shape
}
!=
get_input_shape
(
0
))
{
// TODO(amprocte): We can probably have a more helpful error message here.
throw
ngraph_error
(
"Broadcast arg, shape, and axes are incompatible"
);
// There are two things that can go wrong, which are being picked up in
}
// one fell swoop by this check: either the number of broadcast axes is not
// enough (arg->get_shape().size() + broadcast_axes.size() != shape.size())
// or there is a mismatch with one of the pre-broadcast axis lengths
// (i.e. target_shape.size() == arg->get_shape.size() but there is some i
// where target_shape[i] != arg->get_shape[i]).
NODE_VALIDATION_ASSERT
(
this
,
target_shape
==
get_input_shape
(
0
))
<<
"Broadcast argument shape, target shape, and axes are incompatible "
<<
"(argument shape: "
<<
get_input_shape
(
0
)
<<
", target shape: "
<<
m_shape
<<
", broadcast axes: "
<<
m_broadcast_axes
<<
")."
;
set_output_type
(
0
,
get_input_element_type
(
0
),
m_shape
);
set_output_type
(
0
,
get_input_element_type
(
0
),
m_shape
);
}
}
shared_ptr
<
Node
>
op
::
Broadcast
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Broadcast
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Broadcast
>
(
new_args
.
at
(
0
),
m_shape
,
m_broadcast_axes
);
return
make_shared
<
Broadcast
>
(
new_args
.
at
(
0
),
m_shape
,
m_broadcast_axes
);
}
}
...
...
src/ngraph/op/ceiling.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Ceiling::Ceiling(const shared_ptr<Node>& arg)
...
@@ -27,9 +27,6 @@ op::Ceiling::Ceiling(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Ceiling
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Ceiling
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Ceiling
>
(
new_args
.
at
(
0
));
return
make_shared
<
Ceiling
>
(
new_args
.
at
(
0
));
}
}
src/ngraph/op/concat.cpp
View file @
c386da90
...
@@ -32,56 +32,58 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
...
@@ -32,56 +32,58 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void
op
::
Concat
::
validate_and_infer_types
()
void
op
::
Concat
::
validate_and_infer_types
()
{
{
if
(
m_inputs
.
size
()
<
1
)
NODE_VALIDATION_ASSERT
(
this
,
m_inputs
.
size
()
>=
1
)
<<
"At least one argument required."
;
{
throw
ngraph_error
(
"At least one argument required"
);
}
auto
&
input_0
=
get_inputs
().
at
(
0
);
auto
input_0_shape
=
input_0
.
get_shape
();
if
(
m_concatenation_axis
>=
input_0_shape
.
size
())
{
throw
ngraph_error
(
"Concatenation axis is out of bounds"
);
}
size_t
concatenation_axis_length
=
input_0_shape
.
at
(
m_concatenation_axis
);
Shape
first_input_shape
=
get_input_shape
(
0
);
auto
&
input_0_element_type
=
input_0
.
get_element_type
();
size_t
expected_rank
=
first_input_shape
.
size
();
element
::
Type
expected_et
=
get_input_element_type
(
0
);
for
(
auto
i
=
1
;
i
<
get_inputs
().
size
();
i
++
)
for
(
auto
i
=
1
;
i
<
get_inputs
().
size
();
i
++
)
{
{
auto
&
input_i
=
get_inputs
().
at
(
i
);
NODE_VALIDATION_ASSERT
(
this
,
get_input_shape
(
i
).
size
()
==
expected_rank
)
auto
input_i_shape
=
input_i
.
get_shape
();
<<
"Not all arguments have the same rank: argument 0 has shape "
<<
first_input_shape
if
(
input_i_shape
.
size
()
!=
input_0_shape
.
size
())
<<
" of rank "
<<
expected_rank
<<
" but argument "
<<
i
<<
" has shape "
{
<<
get_input_shape
(
i
)
<<
" of rank "
<<
get_input_shape
(
i
).
size
()
<<
"."
;
throw
ngraph_error
(
"Arguments to concat do not have same rank"
);
NODE_VALIDATION_ASSERT
(
this
,
get_input_element_type
(
i
)
==
expected_et
)
<<
"Not all arguments have the same element type: argument 0 has element type "
<<
expected_et
<<
" but argument "
<<
i
<<
" has element type "
<<
get_input_element_type
(
i
)
<<
"."
;
}
}
if
(
input_i
.
get_element_type
()
!=
input_0_element_type
)
NODE_VALIDATION_ASSERT
(
this
,
m_concatenation_axis
<
expected_rank
)
{
<<
"Concatenation axis ("
<<
m_concatenation_axis
<<
") is out of bounds (inputs have rank "
throw
ngraph_error
(
"Argument element types do not match"
);
<<
expected_rank
<<
")."
;
}
size_t
concatenation_axis_output_length
=
first_input_shape
.
at
(
m_concatenation_axis
);
for
(
auto
j
=
0
;
j
<
input_i_shape
.
size
();
j
++
)
for
(
auto
i
=
1
;
i
<
get_inputs
().
size
();
i
++
)
{
for
(
auto
j
=
0
;
j
<
get_input_shape
(
i
).
size
();
j
++
)
{
{
if
(
j
!=
m_concatenation_axis
&&
input_0_shape
.
at
(
j
)
!=
input_i_shape
.
at
(
j
)
)
if
(
j
!=
m_concatenation_axis
)
{
{
throw
ngraph_error
(
NODE_VALIDATION_ASSERT
(
this
,
first_input_shape
[
j
]
==
get_input_shape
(
i
)[
j
])
"Arguments to concat do not have same dimension on a non-concatenation axis"
);
<<
"Dimensions of argument "
<<
i
<<
" do not match for axis "
<<
j
<<
" (expected "
<<
first_input_shape
[
j
]
<<
", got "
<<
get_input_shape
(
i
)[
j
]
<<
")."
;
}
}
else
if
(
j
==
m_concatenation_axis
)
else
{
{
concatenation_axis_
length
+=
input_i_shape
.
at
(
j
)
;
concatenation_axis_
output_length
+=
get_input_shape
(
i
)[
j
]
;
}
}
}
}
}
}
vector
<
size_t
>
concatenated_shape
=
input_0_shape
;
concatenated_shape
.
at
(
m_concatenation_axis
)
=
concatenation_axis_length
;
set_output_type
(
0
,
input_0_element_type
,
concatenated_shape
);
Shape
concatenated_shape
=
first_input_shape
;
concatenated_shape
[
m_concatenation_axis
]
=
concatenation_axis_output_length
;
set_output_type
(
0
,
expected_et
,
concatenated_shape
);
}
}
shared_ptr
<
Node
>
op
::
Concat
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Concat
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
// TODO(amprocte): Should we check the new_args count here?
return
make_shared
<
Concat
>
(
new_args
,
m_concatenation_axis
);
return
make_shared
<
Concat
>
(
new_args
,
m_concatenation_axis
);
}
}
...
...
src/ngraph/op/constant.cpp
View file @
c386da90
...
@@ -151,10 +151,7 @@ vector<string> op::Constant::get_value_strings() const
...
@@ -151,10 +151,7 @@ vector<string> op::Constant::get_value_strings() const
shared_ptr
<
Node
>
op
::
Constant
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Constant
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
0
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Constant
>
(
m_element_type
,
m_shape
,
m_data
);
return
make_shared
<
Constant
>
(
m_element_type
,
m_shape
,
m_data
);
}
}
...
...
src/ngraph/op/constant.hpp
View file @
c386da90
...
@@ -46,17 +46,19 @@ namespace ngraph
...
@@ -46,17 +46,19 @@ namespace ngraph
,
m_data
(
ngraph
::
aligned_alloc
(
m_element_type
.
size
(),
,
m_data
(
ngraph
::
aligned_alloc
(
m_element_type
.
size
(),
shape_size
(
m_shape
)
*
m_element_type
.
size
()))
shape_size
(
m_shape
)
*
m_element_type
.
size
()))
{
{
NODE_VALIDATION_ASSERT
(
this
,
values
.
size
()
==
1
||
values
.
size
()
==
shape_size
(
m_shape
))
<<
"Did not get the expected number of literals for a constant of shape "
<<
m_shape
<<
" (got "
<<
values
.
size
()
<<
", expected "
<<
(
shape_size
(
m_shape
)
==
1
?
""
:
"1 or "
)
<<
shape_size
(
m_shape
)
<<
")."
;
if
(
values
.
size
()
==
1
)
if
(
values
.
size
()
==
1
)
{
{
write_values
(
std
::
vector
<
T
>
(
shape_size
(
m_shape
),
values
[
0
]));
write_values
(
std
::
vector
<
T
>
(
shape_size
(
m_shape
),
values
[
0
]));
}
}
else
if
(
values
.
size
()
==
shape_size
(
m_shape
))
{
write_values
(
values
);
}
else
else
{
{
throw
ngraph_error
(
"Constant does not have the expected number of literals"
);
write_values
(
values
);
}
}
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
@@ -74,10 +76,11 @@ namespace ngraph
...
@@ -74,10 +76,11 @@ namespace ngraph
,
m_data
(
ngraph
::
aligned_alloc
(
m_element_type
.
size
(),
,
m_data
(
ngraph
::
aligned_alloc
(
m_element_type
.
size
(),
shape_size
(
m_shape
)
*
m_element_type
.
size
()))
shape_size
(
m_shape
)
*
m_element_type
.
size
()))
{
{
if
(
values
.
size
()
!=
shape_size
(
m_shape
))
NODE_VALIDATION_ASSERT
(
this
,
values
.
size
()
==
shape_size
(
m_shape
))
{
<<
"Did not get the expected number of literals for a constant of shape "
throw
ngraph_error
(
"Constant does not have the expected number of literals"
);
<<
m_shape
<<
" (got "
<<
values
.
size
()
<<
", expected "
<<
shape_size
(
m_shape
)
}
<<
"."
;
std
::
vector
<
double
>
dvalues
=
parse_string
<
double
>
(
values
);
std
::
vector
<
double
>
dvalues
=
parse_string
<
double
>
(
values
);
write_values
(
dvalues
);
write_values
(
dvalues
);
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
...
...
src/ngraph/op/convert.cpp
View file @
c386da90
...
@@ -35,10 +35,7 @@ void op::Convert::validate_and_infer_types()
...
@@ -35,10 +35,7 @@ void op::Convert::validate_and_infer_types()
shared_ptr
<
Node
>
op
::
Convert
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Convert
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Convert
>
(
new_args
.
at
(
0
),
m_element_type
);
return
make_shared
<
Convert
>
(
new_args
.
at
(
0
),
m_element_type
);
}
}
...
...
src/ngraph/op/convolution.cpp
View file @
c386da90
...
@@ -379,10 +379,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const shared_pt
...
@@ -379,10 +379,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const shared_pt
shared_ptr
<
Node
>
op
::
Convolution
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Convolution
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Convolution
>
(
new_args
.
at
(
0
),
return
make_shared
<
Convolution
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
1
),
m_window_movement_strides
,
m_window_movement_strides
,
...
@@ -584,10 +581,7 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
...
@@ -584,10 +581,7 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
shared_ptr
<
Node
>
op
::
ConvolutionBackpropData
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
ConvolutionBackpropData
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
ConvolutionBackpropData
>
(
m_data_batch_shape
,
return
make_shared
<
ConvolutionBackpropData
>
(
m_data_batch_shape
,
new_args
.
at
(
0
),
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
1
),
...
@@ -687,10 +681,7 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
...
@@ -687,10 +681,7 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
shared_ptr
<
Node
>
shared_ptr
<
Node
>
op
::
ConvolutionBackpropFilters
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
op
::
ConvolutionBackpropFilters
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
ConvolutionBackpropFilters
>
(
new_args
.
at
(
0
),
return
make_shared
<
ConvolutionBackpropFilters
>
(
new_args
.
at
(
0
),
m_filters_shape
,
m_filters_shape
,
new_args
.
at
(
1
),
new_args
.
at
(
1
),
...
...
src/ngraph/op/cos.cpp
View file @
c386da90
...
@@ -30,10 +30,7 @@ op::Cos::Cos(const shared_ptr<Node>& arg)
...
@@ -30,10 +30,7 @@ op::Cos::Cos(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Cos
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Cos
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Cos
>
(
new_args
.
at
(
0
));
return
make_shared
<
Cos
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/cosh.cpp
View file @
c386da90
...
@@ -29,10 +29,7 @@ op::Cosh::Cosh(const shared_ptr<Node>& arg)
...
@@ -29,10 +29,7 @@ op::Cosh::Cosh(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Cosh
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Cosh
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Cosh
>
(
new_args
.
at
(
0
));
return
make_shared
<
Cosh
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/divide.cpp
View file @
c386da90
...
@@ -29,10 +29,7 @@ op::Divide::Divide(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -29,10 +29,7 @@ op::Divide::Divide(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
Divide
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Divide
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Divide
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Divide
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
...
...
src/ngraph/op/dot.cpp
View file @
c386da90
...
@@ -56,30 +56,33 @@ void op::Dot::validate_and_infer_types()
...
@@ -56,30 +56,33 @@ void op::Dot::validate_and_infer_types()
(
input_0
.
get_shape
().
size
()
==
0
||
input_1
.
get_shape
().
size
()
==
0
)
?
0
:
1
;
(
input_0
.
get_shape
().
size
()
==
0
||
input_1
.
get_shape
().
size
()
==
0
)
?
0
:
1
;
}
}
if
(
input_0
.
get_element_type
()
!
=
input_1
.
get_element_type
())
NODE_VALIDATION_ASSERT
(
this
,
input_0
.
get_element_type
()
=
=
input_1
.
get_element_type
())
{
<<
"Arguments do not have the same element type (arg0 element type: "
throw
ngraph_error
(
"Arguments to dot must have the same element type"
);
<<
input_0
.
get_element_type
()
<<
", arg1 element type: "
<<
input_1
.
get_element_type
()
}
<<
")."
;
Shape
input_0_shape
=
input_0
.
get_shape
();
Shape
input_0_shape
=
input_0
.
get_shape
();
Shape
input_1_shape
=
input_1
.
get_shape
();
Shape
input_1_shape
=
input_1
.
get_shape
();
if
(
m_reduction_axes_count
>
input_0_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
{
m_reduction_axes_count
<=
input_0_shape
.
size
()
&&
throw
ngraph_error
(
"Dot has too many axes for arg0"
);
m_reduction_axes_count
<=
input_1_shape
.
size
())
}
<<
"Reduction axes count ("
<<
m_reduction_axes_count
<<
") is too large (arg0 shape: "
<<
input_0_shape
<<
", arg1 shape: "
<<
input_1_shape
if
(
m_reduction_axes_count
>
input_1_shape
.
size
())
<<
")."
;
{
throw
ngraph_error
(
"Dot has too many axes for arg1"
);
}
for
(
size_t
i
=
0
;
i
<
m_reduction_axes_count
;
i
++
)
for
(
size_t
i
=
0
;
i
<
m_reduction_axes_count
;
i
++
)
{
{
if
(
input_0_shape
[
input_0_shape
.
size
()
-
m_reduction_axes_count
+
i
]
!=
input_1_shape
[
i
])
size_t
axis_index_arg0
=
input_0_shape
.
size
()
-
m_reduction_axes_count
+
i
;
{
size_t
axis_index_arg1
=
i
;
throw
ngraph_error
(
"Dot axes do not have same length"
);
}
NODE_VALIDATION_ASSERT
(
this
,
input_0_shape
[
axis_index_arg0
]
==
input_1_shape
[
axis_index_arg1
])
<<
"Paired axes (axis "
<<
axis_index_arg0
<<
" from arg0, axis "
<<
axis_index_arg1
<<
" from arg1) "
<<
"do not have same length (arg0 shape: "
<<
input_0_shape
<<
", arg1 shape: "
<<
input_1_shape
<<
", "
<<
"reduction axes count: "
<<
m_reduction_axes_count
<<
")."
;
}
}
Shape
result_shape
(
input_0_shape
.
size
()
+
input_1_shape
.
size
()
-
2
*
m_reduction_axes_count
);
Shape
result_shape
(
input_0_shape
.
size
()
+
input_1_shape
.
size
()
-
2
*
m_reduction_axes_count
);
...
...
src/ngraph/op/dot.hpp
View file @
c386da90
...
@@ -56,10 +56,7 @@ namespace ngraph
...
@@ -56,10 +56,7 @@ namespace ngraph
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
std
::
make_shared
<
Dot
>
(
return
std
::
make_shared
<
Dot
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_reduction_axes_count
);
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_reduction_axes_count
);
}
}
...
...
src/ngraph/op/equal.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Equal::Equal(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -27,9 +27,6 @@ op::Equal::Equal(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
Equal
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Equal
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Equal
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Equal
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
src/ngraph/op/exp.cpp
View file @
c386da90
...
@@ -28,10 +28,7 @@ op::Exp::Exp(const shared_ptr<Node>& arg)
...
@@ -28,10 +28,7 @@ op::Exp::Exp(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Exp
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Exp
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Exp
>
(
new_args
.
at
(
0
));
return
make_shared
<
Exp
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/floor.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Floor::Floor(const shared_ptr<Node>& arg)
...
@@ -27,9 +27,6 @@ op::Floor::Floor(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Floor
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Floor
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Floor
>
(
new_args
.
at
(
0
));
return
make_shared
<
Floor
>
(
new_args
.
at
(
0
));
}
}
src/ngraph/op/function_call.cpp
View file @
c386da90
...
@@ -30,18 +30,22 @@ op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector&
...
@@ -30,18 +30,22 @@ op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector&
// TODO : [nikolayk] this needs to be rewritten as follows
// TODO : [nikolayk] this needs to be rewritten as follows
// for each i : FunctionCall->get_inputs.at(i).get_tensor_view_type ==
// for each i : FunctionCall->get_inputs.at(i).get_tensor_view_type ==
// flatten(function_parms).at(i)
// flatten(function_parms).at(i)
if
(
get_input_size
()
!=
function_params
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
get_input_size
()
==
function_params
.
size
())
{
<<
"Number of arguments ("
<<
get_input_size
()
<<
") does not match "
throw
ngraph_error
(
"Wrong number of arguments."
);
<<
"number of function parameters ("
<<
function_params
.
size
()
<<
")."
;
}
for
(
size_t
i
=
0
;
i
<
get_input_size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
get_input_size
();
i
++
)
{
{
if
(
get_input_element_type
(
i
)
!=
function
->
get_parameters
().
at
(
i
)
->
get_element_type
()
||
NODE_VALIDATION_ASSERT
(
get_input_shape
(
i
)
!=
function
->
get_parameters
().
at
(
i
)
->
get_shape
())
this
,
get_input_element_type
(
i
)
==
function
->
get_parameters
()[
i
]
->
get_element_type
())
{
<<
"Element type mismatch for argument "
<<
i
<<
" (argument has type "
throw
ngraph_error
(
"Function argument type mismatch."
);
<<
get_input_element_type
(
i
)
<<
", function expects type "
}
<<
function
->
get_parameters
()[
i
]
->
get_element_type
();
NODE_VALIDATION_ASSERT
(
this
,
get_input_shape
(
i
)
==
function
->
get_parameters
()[
i
]
->
get_shape
())
<<
"Shape mismatch for argument "
<<
i
<<
" (argument has shape "
<<
get_input_shape
(
i
)
<<
", function expects shape "
<<
function
->
get_parameters
()[
i
]
->
get_shape
();
}
}
set_output_size
(
m_function
->
get_output_size
());
set_output_size
(
m_function
->
get_output_size
());
...
@@ -53,6 +57,7 @@ op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector&
...
@@ -53,6 +57,7 @@ op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector&
shared_ptr
<
Node
>
op
::
FunctionCall
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
FunctionCall
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
shared_ptr
<
FunctionCall
>
fc
=
make_shared
<
FunctionCall
>
(
m_function
,
new_args
);
shared_ptr
<
FunctionCall
>
fc
=
make_shared
<
FunctionCall
>
(
m_function
,
new_args
);
fc
->
m_function
=
clone_function
(
*
m_function
);
fc
->
m_function
=
clone_function
(
*
m_function
);
return
fc
;
return
fc
;
...
...
src/ngraph/op/get_output_element.cpp
View file @
c386da90
...
@@ -25,20 +25,16 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
...
@@ -25,20 +25,16 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
:
Node
(
"GetOutputElement"
,
{
arg
})
:
Node
(
"GetOutputElement"
,
{
arg
})
,
m_n
{
n
}
,
m_n
{
n
}
{
{
if
(
m_n
>=
arg
->
get_output_size
())
NODE_VALIDATION_ASSERT
(
this
,
m_n
<
arg
->
get_output_size
())
{
<<
"Output at index "
<<
m_n
<<
" requested, but argument has only "
throw
ngraph_error
(
"Indexing tuple beyond its size"
);
<<
arg
->
get_output_size
()
<<
" outputs."
;
}
set_output_type
(
0
,
arg
->
get_output_element_type
(
n
),
arg
->
get_output_shape
(
n
));
set_output_type
(
0
,
arg
->
get_output_element_type
(
n
),
arg
->
get_output_shape
(
n
));
}
}
shared_ptr
<
Node
>
op
::
GetOutputElement
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
GetOutputElement
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
GetOutputElement
>
(
new_args
.
at
(
0
),
m_n
);
return
make_shared
<
GetOutputElement
>
(
new_args
.
at
(
0
),
m_n
);
}
}
...
...
src/ngraph/op/greater.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Greater::Greater(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -27,9 +27,6 @@ op::Greater::Greater(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
Greater
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Greater
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Greater
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Greater
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
src/ngraph/op/greater_eq.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& a
...
@@ -27,9 +27,6 @@ op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& a
shared_ptr
<
Node
>
op
::
GreaterEq
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
GreaterEq
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
GreaterEq
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
GreaterEq
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
src/ngraph/op/less.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Less::Less(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -27,9 +27,6 @@ op::Less::Less(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
Less
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Less
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Less
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Less
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
src/ngraph/op/less_eq.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::LessEq::LessEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -27,9 +27,6 @@ op::LessEq::LessEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
LessEq
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
LessEq
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
LessEq
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
LessEq
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
src/ngraph/op/log.cpp
View file @
c386da90
...
@@ -28,10 +28,7 @@ op::Log::Log(const shared_ptr<Node>& arg)
...
@@ -28,10 +28,7 @@ op::Log::Log(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Log
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Log
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Log
>
(
new_args
.
at
(
0
));
return
make_shared
<
Log
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/lrn.cpp
View file @
c386da90
...
@@ -28,18 +28,13 @@ op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double
...
@@ -28,18 +28,13 @@ op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double
,
m_size
(
nsize
)
,
m_size
(
nsize
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
if
(
arg
->
get_shape
().
size
()
<
3
)
NODE_VALIDATION_ASSERT
(
this
,
arg
->
get_shape
().
size
()
>=
3
)
{
<<
"Argument must have rank >= 3 (argument shape: "
<<
arg
->
get_shape
()
<<
")."
;
throw
ngraph_error
(
"LRN expects a tensor at least of rank of 3"
);
}
}
}
shared_ptr
<
Node
>
op
::
LRN
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
LRN
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
op
::
LRN
>
(
new_args
.
at
(
0
),
m_alpha
,
m_beta
,
m_bias
,
m_size
);
return
make_shared
<
op
::
LRN
>
(
new_args
.
at
(
0
),
m_alpha
,
m_beta
,
m_bias
,
m_size
);
}
}
...
...
src/ngraph/op/max.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
...
@@ -27,9 +27,6 @@ op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr
<
Node
>
op
::
Max
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Max
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Max
>
(
new_args
.
at
(
0
),
m_reduction_axes
);
return
make_shared
<
Max
>
(
new_args
.
at
(
0
),
m_reduction_axes
);
}
}
src/ngraph/op/max_pool.cpp
View file @
c386da90
...
@@ -201,10 +201,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape)
...
@@ -201,10 +201,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape)
shared_ptr
<
Node
>
op
::
MaxPool
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
MaxPool
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
MaxPool
>
(
new_args
.
at
(
0
),
return
make_shared
<
MaxPool
>
(
new_args
.
at
(
0
),
m_window_shape
,
m_window_shape
,
m_window_movement_strides
,
m_window_movement_strides
,
...
@@ -378,18 +375,13 @@ shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const
...
@@ -378,18 +375,13 @@ shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const
shared_ptr
<
Node
>
op
::
MaxPoolBackprop
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
MaxPoolBackprop
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
return
make_shared
<
op
::
MaxPoolBackprop
>
(
new_args
.
at
(
0
),
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
MaxPoolBackprop
*
mpbp
=
new
MaxPoolBackprop
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
1
),
m_window_shape
,
m_window_shape
,
m_window_movement_strides
,
m_window_movement_strides
,
m_padding_below
,
m_padding_below
,
m_padding_above
);
m_padding_above
);
return
shared_ptr
<
op
::
MaxPoolBackprop
>
(
mpbp
);
}
}
void
op
::
MaxPool
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
void
op
::
MaxPool
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
...
...
src/ngraph/op/maximum.cpp
View file @
c386da90
...
@@ -33,10 +33,7 @@ op::Maximum::Maximum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -33,10 +33,7 @@ op::Maximum::Maximum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
Maximum
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Maximum
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Maximum
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Maximum
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
...
...
src/ngraph/op/min.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
...
@@ -27,9 +27,6 @@ op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr
<
Node
>
op
::
Min
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Min
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Min
>
(
new_args
.
at
(
0
),
m_reduction_axes
);
return
make_shared
<
Min
>
(
new_args
.
at
(
0
),
m_reduction_axes
);
}
}
src/ngraph/op/minimum.cpp
View file @
c386da90
...
@@ -33,10 +33,7 @@ op::Minimum::Minimum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -33,10 +33,7 @@ op::Minimum::Minimum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
Minimum
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Minimum
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Minimum
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Minimum
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
...
...
src/ngraph/op/multiply.cpp
View file @
c386da90
...
@@ -27,10 +27,7 @@ op::Multiply::Multiply(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg
...
@@ -27,10 +27,7 @@ op::Multiply::Multiply(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg
shared_ptr
<
Node
>
op
::
Multiply
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Multiply
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Multiply
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Multiply
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
...
...
src/ngraph/op/negative.cpp
View file @
c386da90
...
@@ -27,10 +27,7 @@ op::Negative::Negative(const shared_ptr<Node>& arg)
...
@@ -27,10 +27,7 @@ op::Negative::Negative(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Negative
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Negative
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Negative
>
(
new_args
.
at
(
0
));
return
make_shared
<
Negative
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/not.cpp
View file @
c386da90
...
@@ -33,9 +33,6 @@ void op::Not::validate_and_infer_types()
...
@@ -33,9 +33,6 @@ void op::Not::validate_and_infer_types()
shared_ptr
<
Node
>
op
::
Not
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Not
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Not
>
(
new_args
.
at
(
0
));
return
make_shared
<
Not
>
(
new_args
.
at
(
0
));
}
}
src/ngraph/op/not_equal.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::NotEqual::NotEqual(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg
...
@@ -27,9 +27,6 @@ op::NotEqual::NotEqual(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg
shared_ptr
<
Node
>
op
::
NotEqual
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
NotEqual
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
NotEqual
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
NotEqual
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
src/ngraph/op/one_hot.cpp
View file @
c386da90
...
@@ -30,27 +30,22 @@ op::OneHot::OneHot(const shared_ptr<Node>& arg, const Shape& shape, size_t one_h
...
@@ -30,27 +30,22 @@ op::OneHot::OneHot(const shared_ptr<Node>& arg, const Shape& shape, size_t one_h
auto
&
input
=
m_inputs
.
at
(
0
);
auto
&
input
=
m_inputs
.
at
(
0
);
auto
&
input_element_type
=
input
.
get_element_type
();
auto
&
input_element_type
=
input
.
get_element_type
();
if
(
one_hot_axis
>=
shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
one_hot_axis
<
shape
.
size
())
{
<<
"One-hot axis ("
<<
one_hot_axis
throw
ngraph_error
(
"One-hot axis is out of bounds"
);
<<
") is out of bounds (requested result shape: "
<<
shape
<<
")."
;
}
auto
expected_input_shape
=
shape
;
auto
expected_input_shape
=
shape
;
expected_input_shape
.
erase
(
expected_input_shape
.
begin
()
+
one_hot_axis
);
expected_input_shape
.
erase
(
expected_input_shape
.
begin
()
+
one_hot_axis
);
if
(
input
.
get_shape
()
!=
expected_input_shape
)
NODE_VALIDATION_ASSERT
(
this
,
input
.
get_shape
()
==
expected_input_shape
)
{
<<
"Argument shape "
<<
input
.
get_shape
()
<<
" does not match the expected shape of "
throw
ngraph_error
(
"One-hot argument shape is not compatible with desired output shape"
);
<<
expected_input_shape
<<
"."
;
}
set_output_type
(
0
,
input_element_type
,
shape
);
set_output_type
(
0
,
input_element_type
,
shape
);
}
}
shared_ptr
<
Node
>
op
::
OneHot
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
OneHot
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
OneHot
>
(
new_args
.
at
(
0
),
m_shape
,
m_one_hot_axis
);
return
make_shared
<
OneHot
>
(
new_args
.
at
(
0
),
m_shape
,
m_one_hot_axis
);
}
}
src/ngraph/op/or.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Or::Or(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -27,9 +27,6 @@ op::Or::Or(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
Or
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Or
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Or
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Or
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
src/ngraph/op/pad.cpp
View file @
c386da90
...
@@ -33,32 +33,27 @@ op::Pad::Pad(const shared_ptr<Node>& arg,
...
@@ -33,32 +33,27 @@ op::Pad::Pad(const shared_ptr<Node>& arg,
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
if
(
get_input_element_type
(
0
)
!=
get_input_element_type
(
1
))
NODE_VALIDATION_ASSERT
(
this
,
get_input_element_type
(
0
)
==
get_input_element_type
(
1
))
{
<<
"Argument element types do not match (arg0 element type: "
<<
get_input_element_type
(
0
)
throw
ngraph_error
(
"Pad argument tensor and padding value element types do not match"
);
<<
", arg1 element type: "
<<
get_input_element_type
(
1
)
<<
")."
;
}
if
(
get_input_shape
(
1
)
!=
Shape
{})
NODE_VALIDATION_ASSERT
(
this
,
get_input_shape
(
1
)
==
Shape
{})
{
<<
"Argument for padding value is not a scalar (shape: "
<<
get_input_shape
(
1
)
<<
")."
;
throw
ngraph_error
(
"Padding value for pad is not a scalar"
);
}
auto
arg_shape
=
get_input_shape
(
0
);
auto
arg_shape
=
get_input_shape
(
0
);
if
(
arg_shape
.
size
()
!=
padding_below
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
arg_shape
.
size
()
==
padding_below
.
size
())
{
<<
"Rank for padding below does not match the rank of the data argument (padding below: "
throw
ngraph_error
(
"Pad rank for below-padding does not match rank of argument tensor"
);
<<
padding_below
<<
", data argument shape: "
<<
arg_shape
<<
")."
;
}
if
(
arg_shape
.
size
()
!=
padding_above
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
arg_shape
.
size
()
==
padding_above
.
size
())
{
<<
"Rank for padding above does not match the rank of the data argument (padding above: "
throw
ngraph_error
(
"Pad rank for above-padding does not match rank of argument tensor"
);
<<
padding_above
<<
", data argument shape: "
<<
arg_shape
<<
")."
;
}
if
(
arg_shape
.
size
()
!
=
padding_interior
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
arg_shape
.
size
()
=
=
padding_interior
.
size
())
{
<<
"Rank for interior padding does not match the rank of the data argument (interior "
throw
ngraph_error
(
"Pad rank for interior padding does not match rank of argument tensor"
);
"padding: "
}
<<
padding_interior
<<
", data argument shape: "
<<
arg_shape
<<
")."
;
Shape
result_shape
;
Shape
result_shape
;
...
@@ -75,10 +70,7 @@ op::Pad::Pad(const shared_ptr<Node>& arg,
...
@@ -75,10 +70,7 @@ op::Pad::Pad(const shared_ptr<Node>& arg,
shared_ptr
<
Node
>
op
::
Pad
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Pad
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Pad
>
(
return
make_shared
<
Pad
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_padding_below
,
m_padding_above
,
m_padding_interior
);
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_padding_below
,
m_padding_above
,
m_padding_interior
);
}
}
...
...
src/ngraph/op/parameter.cpp
View file @
c386da90
...
@@ -40,10 +40,7 @@ void op::Parameter::validate_and_infer_types()
...
@@ -40,10 +40,7 @@ void op::Parameter::validate_and_infer_types()
shared_ptr
<
Node
>
op
::
Parameter
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Parameter
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
0
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Parameter
>
(
m_element_type
,
m_shape
);
return
make_shared
<
Parameter
>
(
m_element_type
,
m_shape
);
}
}
...
...
src/ngraph/op/power.cpp
View file @
c386da90
...
@@ -30,10 +30,7 @@ op::Power::Power(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
...
@@ -30,10 +30,7 @@ op::Power::Power(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr
<
Node
>
op
::
Power
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Power
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Power
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Power
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
...
...
src/ngraph/op/product.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
...
@@ -27,9 +27,6 @@ op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr
<
Node
>
op
::
Product
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Product
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Product
>
(
new_args
.
at
(
0
),
m_reduction_axes
);
return
make_shared
<
Product
>
(
new_args
.
at
(
0
),
m_reduction_axes
);
}
}
src/ngraph/op/reduce.cpp
View file @
c386da90
...
@@ -98,10 +98,7 @@ op::Reduce::Reduce(const shared_ptr<Node>& arg_reductee,
...
@@ -98,10 +98,7 @@ op::Reduce::Reduce(const shared_ptr<Node>& arg_reductee,
shared_ptr
<
Node
>
op
::
Reduce
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Reduce
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
shared_ptr
<
Reduce
>
fc
=
shared_ptr
<
Reduce
>
fc
=
make_shared
<
Reduce
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_reduction_function
,
m_reduction_axes
);
make_shared
<
Reduce
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_reduction_function
,
m_reduction_axes
);
fc
->
m_reduction_function
=
clone_function
(
*
m_reduction_function
);
fc
->
m_reduction_function
=
clone_function
(
*
m_reduction_function
);
...
...
src/ngraph/op/reduce_window.cpp
View file @
c386da90
...
@@ -135,10 +135,7 @@ op::ReduceWindow::ReduceWindow(const shared_ptr<Node>& arg_reductee,
...
@@ -135,10 +135,7 @@ op::ReduceWindow::ReduceWindow(const shared_ptr<Node>& arg_reductee,
shared_ptr
<
Node
>
op
::
ReduceWindow
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
ReduceWindow
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
auto
node
=
make_shared
<
ReduceWindow
>
(
new_args
.
at
(
0
),
auto
node
=
make_shared
<
ReduceWindow
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
1
),
m_reduction_function
,
m_reduction_function
,
...
...
src/ngraph/op/relu.cpp
View file @
c386da90
...
@@ -24,38 +24,23 @@ op::Relu::Relu(shared_ptr<Node> arg)
...
@@ -24,38 +24,23 @@ op::Relu::Relu(shared_ptr<Node> arg)
:
UnaryElementwiseArithmetic
(
"Relu"
,
{
arg
})
:
UnaryElementwiseArithmetic
(
"Relu"
,
{
arg
})
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
set_output_type
(
0
,
arg
->
get_element_type
(),
arg
->
get_shape
());
}
}
shared_ptr
<
Node
>
op
::
Relu
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Relu
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Relu
>
(
new_args
.
at
(
0
));
return
make_shared
<
Relu
>
(
new_args
.
at
(
0
));
}
}
op
::
ReluBackprop
::
ReluBackprop
(
shared_ptr
<
Node
>
arg
,
shared_ptr
<
Node
>
delta
)
op
::
ReluBackprop
::
ReluBackprop
(
shared_ptr
<
Node
>
arg
,
shared_ptr
<
Node
>
delta
)
:
Op
(
"ReluBackprop"
,
check_single_output_args
({
arg
,
delta
})
)
:
BinaryElementwiseArithmetic
(
"ReluBackprop"
,
arg
,
delta
)
{
{
if
(
arg
->
get_element_type
()
!=
delta
->
get_element_type
())
constructor_validate_and_infer_types
();
{
throw
ngraph_error
(
"Argument and delta element types for Relu backprop do not match"
);
}
if
(
arg
->
get_shape
()
!=
delta
->
get_shape
())
{
throw
ngraph_error
(
"Argument and delta shape for Relu backprop do not match"
);
}
set_output_type
(
0
,
delta
->
get_element_type
(),
delta
->
get_shape
());
}
}
shared_ptr
<
Node
>
op
::
ReluBackprop
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
ReluBackprop
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
ReluBackprop
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
ReluBackprop
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
...
...
src/ngraph/op/relu.hpp
View file @
c386da90
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/
op
.hpp"
#include "ngraph/op/
util/binary_elementwise_arithmetic
.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/util.hpp"
#include "ngraph/util.hpp"
...
@@ -47,7 +47,7 @@ namespace ngraph
...
@@ -47,7 +47,7 @@ namespace ngraph
/// \brief Elementwise ReluBackprop operation.
/// \brief Elementwise ReluBackprop operation.
///
///
class
ReluBackprop
:
public
Op
class
ReluBackprop
:
public
ngraph
::
op
::
util
::
BinaryElementwiseArithmetic
{
{
public
:
public
:
/// \brief Constructs a ReluBackprop operation.
/// \brief Constructs a ReluBackprop operation.
...
...
src/ngraph/op/remainder.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Remainder::Remainder(const shared_ptr<Node>& arg0, const shared_ptr<Node>& a
...
@@ -27,9 +27,6 @@ op::Remainder::Remainder(const shared_ptr<Node>& arg0, const shared_ptr<Node>& a
shared_ptr
<
Node
>
op
::
Remainder
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Remainder
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Remainder
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Remainder
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
src/ngraph/op/replace_slice.cpp
View file @
c386da90
...
@@ -60,52 +60,43 @@ void op::ReplaceSlice::check_args()
...
@@ -60,52 +60,43 @@ void op::ReplaceSlice::check_args()
auto
&
input_1_shape
=
input_1
.
get_shape
();
auto
&
input_1_shape
=
input_1
.
get_shape
();
auto
&
input_1_element_type
=
input_1
.
get_element_type
();
auto
&
input_1_element_type
=
input_1
.
get_element_type
();
if
(
input_0_shape
.
size
()
!=
input_1_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
input_0_shape
.
size
()
==
input_1_shape
.
size
())
{
<<
"Argument ranks do not match (arg0 shape: "
<<
input_0_shape
throw
ngraph_error
(
"Replace-slice argument ranks do not match"
);
<<
", arg1 shape: "
<<
input_1_shape
<<
")."
;
}
if
(
input_0_element_type
!=
input_1_element_type
)
NODE_VALIDATION_ASSERT
(
this
,
input_0_element_type
==
input_1_element_type
)
{
<<
"Argument element types do not match (arg0 element type: "
<<
input_0_element_type
throw
ngraph_error
(
"Element types for replace-slice arguments do not match"
);
<<
", arg1 element type: "
<<
input_1_element_type
<<
")."
;
}
if
(
m_lower_bounds
.
size
()
!=
input_0_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
.
size
()
==
input_0_shape
.
size
())
{
<<
"Rank of lower bounds ("
<<
m_lower_bounds
.
size
()
<<
") does not match rank "
throw
ngraph_error
(
<<
"of argument ("
<<
input_0_shape
.
size
()
<<
") (lower bounds: "
<<
m_lower_bounds
"Number of lower bounds provided for slice does not match number of input axes"
);
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
}
if
(
m_upper_bounds
.
size
()
!=
input_0_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
.
size
()
==
input_0_shape
.
size
())
{
<<
"Rank of upper bounds ("
<<
m_upper_bounds
.
size
()
<<
") does not match rank "
throw
ngraph_error
(
<<
"of argument ("
<<
input_0_shape
.
size
()
<<
") (upper bounds: "
<<
m_upper_bounds
"Number of upper bounds provided for slice does not match number of input axes"
);
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
}
if
(
m_strides
.
size
()
!=
input_0_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
m_strides
.
size
()
==
input_0_shape
.
size
())
{
<<
"Rank of strides ("
<<
m_strides
.
size
()
<<
") does not match rank "
throw
ngraph_error
(
<<
"of argument ("
<<
input_0_shape
.
size
()
<<
") (strides: "
<<
m_strides
"Number of strides provided for slice does not match number of input axes"
);
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
}
Shape
slice_shape
;
Shape
slice_shape
;
for
(
size_t
i
=
0
;
i
<
input_0_shape
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
input_0_shape
.
size
();
i
++
)
{
{
if
(
m_upper_bounds
[
i
]
>
input_0_shape
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
[
i
]
<=
input_0_shape
[
i
])
{
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
throw
ngraph_error
(
"Upper bound for slice is out of range"
);
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
}
if
(
m_lower_bounds
[
i
]
>
m_upper_bounds
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
{
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
throw
ngraph_error
(
"Lower bound for slice is greater than upper bound"
);
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
}
if
(
0
==
m_strides
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
{
<<
" (strides: "
<<
m_strides
<<
")."
;
throw
ngraph_error
(
"Stride for slice is zero"
);
}
size_t
slice_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
size_t
slice_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
slice_axis_size
=
slice_axis_size
=
...
@@ -113,20 +104,16 @@ void op::ReplaceSlice::check_args()
...
@@ -113,20 +104,16 @@ void op::ReplaceSlice::check_args()
slice_shape
.
push_back
(
slice_axis_size
);
slice_shape
.
push_back
(
slice_axis_size
);
}
}
if
(
input_1_shape
!=
slice_shape
)
NODE_VALIDATION_ASSERT
(
this
,
input_1_shape
==
slice_shape
)
{
<<
"Shape of replacement tensor ("
<<
input_1_shape
<<
") does not match the slice shape "
throw
ngraph_error
(
"Shape of replacement tensor does not match slice shape"
);
<<
"("
<<
slice_shape
<<
")."
;
}
set_output_type
(
0
,
input_0_element_type
,
input_0_shape
);
set_output_type
(
0
,
input_0_element_type
,
input_0_shape
);
}
}
shared_ptr
<
Node
>
op
::
ReplaceSlice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
ReplaceSlice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
ReplaceSlice
>
(
return
make_shared
<
ReplaceSlice
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_lower_bounds
,
m_upper_bounds
,
m_strides
);
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_lower_bounds
,
m_upper_bounds
,
m_strides
);
}
}
...
...
src/ngraph/op/reshape.cpp
View file @
c386da90
...
@@ -39,19 +39,16 @@ void op::Reshape::validate_and_infer_types()
...
@@ -39,19 +39,16 @@ void op::Reshape::validate_and_infer_types()
auto
input_shape
=
input
.
get_shape
();
auto
input_shape
=
input
.
get_shape
();
auto
input_rank
=
input_shape
.
size
();
auto
input_rank
=
input_shape
.
size
();
if
(
m_input_order
.
size
()
!=
input_rank
)
NODE_VALIDATION_ASSERT
(
this
,
m_input_order
.
size
()
==
input_rank
)
{
<<
"Input axis order is not a permutation of argument's axis indices (axis order: "
throw
ngraph_error
(
"Input axis order for reshape is not a permutation of argument's axes"
);
<<
m_input_order
<<
", argument shape: "
<<
input_shape
<<
")."
;
}
for
(
size_t
i
=
0
;
i
<
input_rank
;
i
++
)
for
(
size_t
i
=
0
;
i
<
input_rank
;
i
++
)
{
{
auto
it
=
find
(
begin
(
m_input_order
),
end
(
m_input_order
),
i
);
auto
it
=
find
(
begin
(
m_input_order
),
end
(
m_input_order
),
i
);
if
(
end
(
m_input_order
)
==
it
)
NODE_VALIDATION_ASSERT
(
this
,
it
!=
end
(
m_input_order
))
{
<<
"Input axis order is not a permutation of argument's axis indices (axis order: "
throw
ngraph_error
(
<<
m_input_order
<<
", argument shape: "
<<
input_shape
<<
")."
;
"Input axis order for reshape is not a permutation of argument's axes"
);
}
}
}
size_t
input_shape_product
=
1
;
size_t
input_shape_product
=
1
;
...
@@ -66,12 +63,9 @@ void op::Reshape::validate_and_infer_types()
...
@@ -66,12 +63,9 @@ void op::Reshape::validate_and_infer_types()
output_shape_product
*=
i
;
output_shape_product
*=
i
;
}
}
if
(
input_shape_product
!=
output_shape_product
)
NODE_VALIDATION_ASSERT
(
this
,
input_shape_product
==
output_shape_product
)
{
<<
"Product of output shape dimensions does not match product of argument shape dimensions "
throw
ngraph_error
(
<<
"(output shape: "
<<
m_output_shape
<<
", argument shape: "
<<
input_shape
<<
")."
;
"Product of output shape dimensions does not match product of argument shape "
"dimensions for reshape"
);
}
if
(
!
std
::
is_sorted
(
m_input_order
.
begin
(),
m_input_order
.
end
()))
if
(
!
std
::
is_sorted
(
m_input_order
.
begin
(),
m_input_order
.
end
()))
{
{
...
@@ -82,10 +76,7 @@ void op::Reshape::validate_and_infer_types()
...
@@ -82,10 +76,7 @@ void op::Reshape::validate_and_infer_types()
shared_ptr
<
Node
>
op
::
Reshape
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Reshape
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Reshape
>
(
new_args
.
at
(
0
),
m_input_order
,
m_output_shape
);
return
make_shared
<
Reshape
>
(
new_args
.
at
(
0
),
m_input_order
,
m_output_shape
);
}
}
...
...
src/ngraph/op/result.cpp
View file @
c386da90
...
@@ -32,10 +32,8 @@ op::Result::Result(const shared_ptr<Node>& arg)
...
@@ -32,10 +32,8 @@ op::Result::Result(const shared_ptr<Node>& arg)
void
op
::
Result
::
validate_and_infer_types
()
void
op
::
Result
::
validate_and_infer_types
()
{
{
if
(
get_input_size
()
!=
1
)
NODE_VALIDATION_ASSERT
(
this
,
get_input_size
()
==
1
)
<<
"Argument has "
<<
get_input_size
()
{
<<
" outputs (1 expected)."
;
throw
ngraph_error
(
"Result expected a single-output argument"
);
}
// always borrow the placement conf even the default one
// always borrow the placement conf even the default one
set_placement
(
get_argument
(
0
)
->
get_placement
());
set_placement
(
get_argument
(
0
)
->
get_placement
());
...
@@ -44,15 +42,7 @@ void op::Result::validate_and_infer_types()
...
@@ -44,15 +42,7 @@ void op::Result::validate_and_infer_types()
shared_ptr
<
Node
>
op
::
Result
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Result
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
if
(
new_args
.
at
(
0
)
->
get_outputs
().
size
()
!=
1
)
{
throw
ngraph_error
(
"Result::copy_with_new_args expected a single-output argument"
);
}
auto
res
=
make_shared
<
Result
>
(
new_args
.
at
(
0
));
auto
res
=
make_shared
<
Result
>
(
new_args
.
at
(
0
));
if
(
res
)
if
(
res
)
...
...
src/ngraph/op/reverse.cpp
View file @
c386da90
...
@@ -38,13 +38,9 @@ void op::Reverse::validate_and_infer_types()
...
@@ -38,13 +38,9 @@ void op::Reverse::validate_and_infer_types()
// Make sure all reversed axis indices are valid.
// Make sure all reversed axis indices are valid.
for
(
size_t
axis
:
m_reversed_axes
)
for
(
size_t
axis
:
m_reversed_axes
)
{
{
if
(
axis
>=
input_rank
)
NODE_VALIDATION_ASSERT
(
this
,
axis
<
input_rank
)
{
<<
"Reverse axis ("
<<
axis
<<
") is out of bounds (argument shape: "
<<
input_shape
stringstream
ss
;
ss
<<
"Reverse axis "
<<
axis
<<
" is out of bounds (input rank is "
<<
input_rank
<<
")."
;
<<
")."
;
throw
ngraph_error
(
ss
.
str
());
}
}
}
set_output_type
(
0
,
get_input_element_type
(
0
),
input_shape
);
set_output_type
(
0
,
get_input_element_type
(
0
),
input_shape
);
...
@@ -52,10 +48,7 @@ void op::Reverse::validate_and_infer_types()
...
@@ -52,10 +48,7 @@ void op::Reverse::validate_and_infer_types()
shared_ptr
<
Node
>
op
::
Reverse
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Reverse
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Reverse
>
(
new_args
.
at
(
0
),
m_reversed_axes
);
return
make_shared
<
Reverse
>
(
new_args
.
at
(
0
),
m_reversed_axes
);
}
}
...
...
src/ngraph/op/reverse_sequence.cpp
View file @
c386da90
...
@@ -38,36 +38,30 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
...
@@ -38,36 +38,30 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
void
op
::
ReverseSequence
::
validate_and_infer_types
()
void
op
::
ReverseSequence
::
validate_and_infer_types
()
{
{
if
(
get_input_shape
(
1
).
size
()
!=
1
)
NODE_VALIDATION_ASSERT
(
this
,
get_input_shape
(
1
).
size
()
==
1
)
{
<<
"Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
throw
ngraph_error
(
"indices should be a 1-dimensional array"
);
<<
get_input_shape
(
1
)
<<
")."
;
}
if
(
m_batch_axis
>=
get_input_shape
(
0
).
size
())
NODE_VALIDATION_ASSERT
(
this
,
m_batch_axis
<
get_input_shape
(
0
).
size
())
{
<<
"Batch axis index ("
<<
m_batch_axis
throw
ngraph_error
(
"batch axis index is out of bounds"
);
<<
") is out of bounds (argument shape: "
<<
get_input_shape
(
0
)
<<
")."
;
}
if
(
m_seq_axis
>=
get_input_shape
(
0
).
size
())
NODE_VALIDATION_ASSERT
(
this
,
m_seq_axis
<
get_input_shape
(
0
).
size
())
{
<<
"Sequence axis index ("
<<
m_seq_axis
throw
ngraph_error
(
"sequence axis index is out of bounds"
);
<<
") is out of bounds (argument shape: "
<<
get_input_shape
(
0
)
<<
")."
;
}
if
(
get_input_shape
(
0
).
at
(
m_batch_axis
)
!=
get_input_shape
(
1
).
at
(
0
))
NODE_VALIDATION_ASSERT
(
this
,
get_input_shape
(
0
)[
m_batch_axis
]
==
get_input_shape
(
1
)[
0
])
{
<<
"Sequence length ("
<<
get_input_shape
(
1
)[
0
]
<<
") is not equal to batch axis "
throw
ngraph_error
(
"Sequence length size should be equal to batch axis dimension"
);
<<
"dimension ("
<<
get_input_shape
(
0
)[
m_batch_axis
]
}
<<
") (argument shape: "
<<
get_input_shape
(
0
)
<<
", sequence indices shape: "
<<
get_input_shape
(
1
)
<<
")."
;
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_shape
(
0
));
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_shape
(
0
));
}
}
shared_ptr
<
Node
>
op
::
ReverseSequence
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
ReverseSequence
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
auto
res
=
auto
res
=
make_shared
<
ReverseSequence
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_batch_axis
,
m_seq_axis
);
make_shared
<
ReverseSequence
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
m_batch_axis
,
m_seq_axis
);
return
res
;
return
res
;
...
...
src/ngraph/op/select.cpp
View file @
c386da90
...
@@ -37,28 +37,27 @@ op::Select::Select(const shared_ptr<Node>& arg0,
...
@@ -37,28 +37,27 @@ op::Select::Select(const shared_ptr<Node>& arg0,
auto
&
input_1
=
get_inputs
().
at
(
1
);
auto
&
input_1
=
get_inputs
().
at
(
1
);
auto
&
input_2
=
get_inputs
().
at
(
2
);
auto
&
input_2
=
get_inputs
().
at
(
2
);
if
(
input_0
.
get_element_type
()
!=
element
::
boolean
)
NODE_VALIDATION_ASSERT
(
this
,
input_0
.
get_element_type
()
==
element
::
boolean
)
{
<<
"Argument 0 does not have boolean element type (element type: "
throw
ngraph_error
(
"Argument 0 for arithmetic operators must have boolean element type"
);
<<
input_0
.
get_element_type
()
<<
")."
;
}
if
(
input_0
.
get_shape
()
!=
input_1
.
get_shape
()
||
input_0
.
get_shape
()
!=
input_2
.
get_shape
())
NODE_VALIDATION_ASSERT
(
this
,
{
input_0
.
get_shape
()
==
input_1
.
get_shape
()
&&
throw
ngraph_error
(
"Arguments must have the same shape"
);
input_0
.
get_shape
()
==
input_2
.
get_shape
())
}
<<
"Arguments do not all have the same shape (arg0 shape: "
<<
input_0
.
get_shape
()
if
(
input_1
.
get_element_type
()
!=
input_2
.
get_element_type
())
<<
", arg1 shape: "
<<
input_1
.
get_shape
()
<<
", arg2 shape: "
<<
input_2
.
get_shape
()
{
<<
")."
;
throw
ngraph_error
(
"Arguments 1 and 2 must have the same element type"
);
}
NODE_VALIDATION_ASSERT
(
this
,
input_1
.
get_element_type
()
==
input_2
.
get_element_type
())
<<
"Arguments 1 and 2 do not have the same element type (arg1 type: "
<<
input_1
.
get_element_type
()
<<
", arg2 type: "
<<
input_2
.
get_element_type
()
<<
")."
;
set_output_type
(
0
,
input_1
.
get_element_type
(),
input_1
.
get_shape
());
set_output_type
(
0
,
input_1
.
get_element_type
(),
input_1
.
get_shape
());
}
}
shared_ptr
<
Node
>
op
::
Select
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Select
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
3
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Select
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
));
return
make_shared
<
Select
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
));
}
}
...
...
src/ngraph/op/select_and_scatter.cpp
View file @
c386da90
...
@@ -222,10 +222,7 @@ op::SelectAndScatter::SelectAndScatter(const shared_ptr<Node>& arg_selectee,
...
@@ -222,10 +222,7 @@ op::SelectAndScatter::SelectAndScatter(const shared_ptr<Node>& arg_selectee,
shared_ptr
<
Node
>
op
::
SelectAndScatter
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
SelectAndScatter
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
3
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
auto
node
=
make_shared
<
SelectAndScatter
>
(
new_args
.
at
(
0
),
auto
node
=
make_shared
<
SelectAndScatter
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
2
),
...
...
src/ngraph/op/sigmoid.cpp
View file @
c386da90
...
@@ -23,11 +23,7 @@ using namespace ngraph;
...
@@ -23,11 +23,7 @@ using namespace ngraph;
shared_ptr
<
Node
>
op
::
Sigmoid
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Sigmoid
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Sigmoid
>
(
new_args
.
at
(
0
));
return
make_shared
<
Sigmoid
>
(
new_args
.
at
(
0
));
}
}
...
@@ -41,23 +37,20 @@ op::Sigmoid::Sigmoid(shared_ptr<Node> arg)
...
@@ -41,23 +37,20 @@ op::Sigmoid::Sigmoid(shared_ptr<Node> arg)
op
::
SigmoidBackprop
::
SigmoidBackprop
(
shared_ptr
<
Node
>
arg
,
shared_ptr
<
Node
>
delta
)
op
::
SigmoidBackprop
::
SigmoidBackprop
(
shared_ptr
<
Node
>
arg
,
shared_ptr
<
Node
>
delta
)
:
Op
(
"SigmoidBackprop"
,
check_single_output_args
({
arg
,
delta
}))
:
Op
(
"SigmoidBackprop"
,
check_single_output_args
({
arg
,
delta
}))
{
{
if
(
arg
->
get_element_type
()
!
=
delta
->
get_element_type
())
NODE_VALIDATION_ASSERT
(
this
,
arg
->
get_element_type
()
=
=
delta
->
get_element_type
())
{
<<
"Argument and delta element types do not match (argument element type: "
throw
ngraph_error
(
"Argument and delta element types for Sigmoid backprop do not match"
)
;
<<
arg
->
get_element_type
()
<<
", delta element type: "
<<
delta
->
get_element_type
()
<<
")."
;
}
if
(
arg
->
get_shape
()
!
=
delta
->
get_shape
())
NODE_VALIDATION_ASSERT
(
this
,
arg
->
get_shape
()
=
=
delta
->
get_shape
())
{
<<
"Argument and delta shapes do not match (argument shape: "
<<
arg
->
get_shape
()
throw
ngraph_error
(
"Argument and delta shape for Sigmoid backprop do not match"
)
;
<<
", delta shape: "
<<
delta
->
get_shape
()
<<
")."
;
}
set_output_type
(
0
,
delta
->
get_element_type
(),
delta
->
get_shape
());
set_output_type
(
0
,
delta
->
get_element_type
(),
delta
->
get_shape
());
}
}
shared_ptr
<
Node
>
op
::
SigmoidBackprop
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
SigmoidBackprop
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
SigmoidBackprop
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
SigmoidBackprop
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
...
...
src/ngraph/op/sign.cpp
View file @
c386da90
...
@@ -27,9 +27,6 @@ op::Sign::Sign(const shared_ptr<Node>& arg)
...
@@ -27,9 +27,6 @@ op::Sign::Sign(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Sign
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Sign
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Sign
>
(
new_args
.
at
(
0
));
return
make_shared
<
Sign
>
(
new_args
.
at
(
0
));
}
}
src/ngraph/op/sin.cpp
View file @
c386da90
...
@@ -29,10 +29,7 @@ op::Sin::Sin(const shared_ptr<Node>& arg)
...
@@ -29,10 +29,7 @@ op::Sin::Sin(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Sin
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Sin
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Sin
>
(
new_args
.
at
(
0
));
return
make_shared
<
Sin
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/sinh.cpp
View file @
c386da90
...
@@ -29,10 +29,7 @@ op::Sinh::Sinh(const shared_ptr<Node>& arg)
...
@@ -29,10 +29,7 @@ op::Sinh::Sinh(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Sinh
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Sinh
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Sinh
>
(
new_args
.
at
(
0
));
return
make_shared
<
Sinh
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/slice.cpp
View file @
c386da90
...
@@ -51,42 +51,35 @@ void op::Slice::validate_and_infer_types()
...
@@ -51,42 +51,35 @@ void op::Slice::validate_and_infer_types()
auto
&
input
=
get_inputs
().
at
(
0
);
auto
&
input
=
get_inputs
().
at
(
0
);
auto
&
input_shape
=
input
.
get_shape
();
auto
&
input_shape
=
input
.
get_shape
();
if
(
m_lower_bounds
.
size
()
!=
input_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
.
size
()
==
input_shape
.
size
())
{
<<
"Rank of lower bounds ("
<<
m_lower_bounds
.
size
()
<<
") does not match rank "
throw
ngraph_error
(
<<
"of argument ("
<<
input_shape
.
size
()
<<
") (lower bounds: "
<<
m_lower_bounds
"Number of lower bounds provided for slice does not match number of input axes"
);
<<
", argument shape: "
<<
input_shape
<<
")."
;
}
if
(
m_upper_bounds
.
size
()
!=
input_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
.
size
()
==
input_shape
.
size
())
{
<<
"Rank of upper bounds ("
<<
m_upper_bounds
.
size
()
<<
") does not match rank "
throw
ngraph_error
(
<<
"of argument ("
<<
input_shape
.
size
()
<<
") (upper bounds: "
<<
m_upper_bounds
"Number of upper bounds provided for slice does not match number of input axes"
);
<<
", argument shape: "
<<
input_shape
<<
")."
;
}
if
(
m_strides
.
size
()
!=
input_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
m_strides
.
size
()
==
input_shape
.
size
())
{
<<
"Rank of strides ("
<<
m_strides
.
size
()
<<
") does not match rank "
throw
ngraph_error
(
<<
"of argument ("
<<
input_shape
.
size
()
<<
") (strides: "
<<
m_strides
"Number of strides provided for slice does not match number of input axes"
);
<<
", argument shape: "
<<
input_shape
<<
")."
;
}
Shape
result_shape
;
Shape
result_shape
;
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
{
if
(
m_upper_bounds
[
i
]
>
input_shape
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
[
i
]
<=
input_shape
[
i
])
{
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
throw
ngraph_error
(
"Upper bound for slice is out of range"
);
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
input_shape
<<
")."
;
}
if
(
m_lower_bounds
[
i
]
>
m_upper_bounds
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
{
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
throw
ngraph_error
(
"Lower bound for slice is greater than upper bound"
);
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
}
if
(
0
==
m_strides
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
{
<<
" (strides: "
<<
m_strides
<<
")."
;
throw
ngraph_error
(
"Strides distance for slice is zero"
);
}
size_t
result_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
size_t
result_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
result_axis_size
=
result_axis_size
=
...
@@ -99,10 +92,7 @@ void op::Slice::validate_and_infer_types()
...
@@ -99,10 +92,7 @@ void op::Slice::validate_and_infer_types()
shared_ptr
<
Node
>
op
::
Slice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Slice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Slice
>
(
new_args
.
at
(
0
),
m_lower_bounds
,
m_upper_bounds
,
m_strides
);
return
make_shared
<
Slice
>
(
new_args
.
at
(
0
),
m_lower_bounds
,
m_upper_bounds
,
m_strides
);
}
}
...
...
src/ngraph/op/softmax.cpp
View file @
c386da90
...
@@ -37,10 +37,9 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
...
@@ -37,10 +37,9 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
for
(
auto
axis
:
m_axes
)
for
(
auto
axis
:
m_axes
)
{
{
if
(
axis
>=
get_shape
().
size
())
NODE_VALIDATION_ASSERT
(
this
,
axis
<
get_shape
().
size
())
{
<<
"Reduction axis ("
<<
axis
<<
") is out of bounds (argument shape: "
<<
get_shape
()
throw
ngraph_error
(
"Axis for softmax reduction operator is out of bounds"
);
<<
")."
;
}
}
}
// empty axes == all axes
// empty axes == all axes
...
@@ -55,10 +54,7 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
...
@@ -55,10 +54,7 @@ op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
shared_ptr
<
Node
>
op
::
Softmax
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Softmax
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Softmax
>
(
new_args
.
at
(
0
),
m_axes
);
return
make_shared
<
Softmax
>
(
new_args
.
at
(
0
),
m_axes
);
}
}
...
...
src/ngraph/op/sqrt.cpp
View file @
c386da90
...
@@ -29,10 +29,7 @@ op::Sqrt::Sqrt(const shared_ptr<Node>& arg)
...
@@ -29,10 +29,7 @@ op::Sqrt::Sqrt(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Sqrt
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Sqrt
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Sqrt
>
(
new_args
.
at
(
0
));
return
make_shared
<
Sqrt
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/stop_gradient.cpp
View file @
c386da90
...
@@ -29,9 +29,6 @@ op::StopGradient::StopGradient(const shared_ptr<Node>& arg)
...
@@ -29,9 +29,6 @@ op::StopGradient::StopGradient(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
StopGradient
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
StopGradient
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
StopGradient
>
(
new_args
.
at
(
0
));
return
make_shared
<
StopGradient
>
(
new_args
.
at
(
0
));
}
}
src/ngraph/op/subtract.cpp
View file @
c386da90
...
@@ -28,10 +28,7 @@ op::Subtract::Subtract(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg
...
@@ -28,10 +28,7 @@ op::Subtract::Subtract(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg
shared_ptr
<
Node
>
op
::
Subtract
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Subtract
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
2
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Subtract
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
return
make_shared
<
Subtract
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
}
...
...
src/ngraph/op/sum.cpp
View file @
c386da90
...
@@ -28,10 +28,7 @@ op::Sum::Sum(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
...
@@ -28,10 +28,7 @@ op::Sum::Sum(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
shared_ptr
<
Node
>
op
::
Sum
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Sum
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Sum
>
(
new_args
.
at
(
0
),
m_reduction_axes
);
return
make_shared
<
Sum
>
(
new_args
.
at
(
0
),
m_reduction_axes
);
}
}
...
...
src/ngraph/op/tan.cpp
View file @
c386da90
...
@@ -30,10 +30,7 @@ op::Tan::Tan(const shared_ptr<Node>& arg)
...
@@ -30,10 +30,7 @@ op::Tan::Tan(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Tan
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Tan
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Tan
>
(
new_args
.
at
(
0
));
return
make_shared
<
Tan
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/tanh.cpp
View file @
c386da90
...
@@ -29,10 +29,7 @@ op::Tanh::Tanh(const shared_ptr<Node>& arg)
...
@@ -29,10 +29,7 @@ op::Tanh::Tanh(const shared_ptr<Node>& arg)
shared_ptr
<
Node
>
op
::
Tanh
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Tanh
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
if
(
new_args
.
size
()
!=
1
)
check_new_args_count
(
this
,
new_args
);
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Tanh
>
(
new_args
.
at
(
0
));
return
make_shared
<
Tanh
>
(
new_args
.
at
(
0
));
}
}
...
...
src/ngraph/op/util/arithmetic_reduction.cpp
View file @
c386da90
...
@@ -33,10 +33,10 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
...
@@ -33,10 +33,10 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
for
(
auto
axis
:
m_reduction_axes
)
for
(
auto
axis
:
m_reduction_axes
)
{
{
if
(
axis
>=
input_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
axis
<
input_shape
.
size
())
{
<<
"Reduction axis ("
<<
axis
<<
") is out of bounds "
throw
ngraph_error
(
"Reduction axis for arithmetic reduction operator is out of bounds"
);
<<
"(argument shape: "
<<
input_shape
<<
", reduction axes: "
<<
m_reduction_axes
}
<<
")"
;
}
}
Shape
result_shape
;
Shape
result_shape
;
...
...
src/ngraph/op/util/index_reduction.cpp
View file @
c386da90
...
@@ -31,9 +31,10 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type,
...
@@ -31,9 +31,10 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type,
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
auto
rank
=
arg
->
get_shape
().
size
();
auto
rank
=
arg
->
get_shape
().
size
();
TYPE_CHECK_ASSERT
(
this
,
rank
>=
1
)
<<
"Tensor's rank must be at least 1"
;
NODE_VALIDATION_ASSERT
(
this
,
rank
>=
1
)
<<
"Argument rank must be at least 1"
;
TYPE_CHECK_ASSERT
(
this
,
axis
<
rank
)
<<
"Axis "
<<
axis
<<
" is greater than rank of "
<<
rank
;
NODE_VALIDATION_ASSERT
(
this
,
axis
<
rank
)
<<
"Axis "
<<
axis
<<
" is greater than rank of "
TYPE_CHECK_ASSERT
(
this
,
<<
rank
;
NODE_VALIDATION_ASSERT
(
this
,
index_element_type
==
element
::
i32
||
index_element_type
==
element
::
i64
)
index_element_type
==
element
::
i32
||
index_element_type
==
element
::
i64
)
<<
"Index element type must be i64 or i32"
;
<<
"Index element type must be i64 or i32"
;
...
...
test/type_prop.cpp
View file @
c386da90
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment