Unverified Commit 4c5dbf07 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/uop (#3903)

* Address op_tbl issues

* fix

* fix

* fix

* Cleanup

* cleanup

* cleanup

* More fixes

* Revert ser changes

* Compiles

* opset conversion fixed

* Fix opset conversion tests

* Deal with Reciprocal and FloorMod movement

* Cleanup

* Remove duplicate enums

* Experiment

* experiment

* Types

* Reorg around clang 3.9 bug
parent 16e256fa
...@@ -105,13 +105,45 @@ std::shared_ptr<Node> ...@@ -105,13 +105,45 @@ std::shared_ptr<Node>
Node::copy_with_new_inputs(const OutputVector& inputs, Node::copy_with_new_inputs(const OutputVector& inputs,
const std::vector<std::shared_ptr<Node>>& control_dependencies) const const std::vector<std::shared_ptr<Node>>& control_dependencies) const
{ {
bool for_get_output_element = is_type<op::GetOutputElement>(this); shared_ptr<Node> clone;
if (is_type<op::GetOutputElement>(this))
{
auto& value = inputs.at(0);
clone = make_shared<op::GetOutputElement>(value.get_node_shared_ptr(), value.get_index());
}
else
{
NodeVector args; NodeVector args;
for (const Output<Node>& input : inputs) for (const Output<Node>& input : inputs)
{ {
args.push_back(get_output_element(input, for_get_output_element)); args.push_back(get_output_element(input, false));
}
for (int i = 0; i < inputs.size(); ++i)
{
auto in_val = inputs.at(i);
if (is_type<op::GetOutputElement>(in_val.get_node()))
{
in_val = as_type_ptr<op::GetOutputElement>(in_val.get_node_shared_ptr())
->get_as_output();
}
auto in_index = in_val.get_index();
auto arg = args.at(i);
size_t out_index = 0;
if (is_type<op::GetOutputElement>(arg))
{
out_index = as_type_ptr<op::GetOutputElement>(arg)->get_n();
}
if (in_index != out_index)
{
cerr << "Mismatch in: " << in_index << " arg: " << out_index << endl;
cerr << "ARG: " << *arg << endl;
cerr << "IN: " << *inputs.at(i).get_node() << endl;
cerr << "INV: " << *in_val.get_node() << endl;
cerr << "In node " << *this << endl;
}
}
clone = copy_with_new_args(args);
} }
shared_ptr<Node> clone = copy_with_new_args(args);
for (auto& cdep : control_dependencies) for (auto& cdep : control_dependencies)
{ {
clone->add_control_dependency(cdep); clone->add_control_dependency(cdep);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each fused op.
//
#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(BatchMatMulTranspose, ngraph::op)
NGRAPH_OP(Clamp, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(CrossEntropy, ngraph::op)
NGRAPH_OP(CrossEntropyBackprop, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(Gelu, ngraph::op)
NGRAPH_OP(GeluBackpropFactor, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(LogSoftmax, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(PartialSlice, ngraph::op)
NGRAPH_OP(PartialSliceBackprop, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(Reciprocal, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(Selu, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(TensorIterator, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
...@@ -65,10 +65,10 @@ NGRAPH_OP(Atan2, ngraph::op) ...@@ -65,10 +65,10 @@ NGRAPH_OP(Atan2, ngraph::op)
NGRAPH_OP(AvgPool, ngraph::op) NGRAPH_OP(AvgPool, ngraph::op)
NGRAPH_OP(AvgPoolBackprop, ngraph::op) NGRAPH_OP(AvgPoolBackprop, ngraph::op)
NGRAPH_OP(BatchMatMul, ngraph::op) NGRAPH_OP(BatchMatMul, ngraph::op)
NGRAPH_OP(BatchMatMulTranspose, ngraph::op)
NGRAPH_OP(BatchNormInference, ngraph::op) NGRAPH_OP(BatchNormInference, ngraph::op)
NGRAPH_OP(BatchNormTraining, ngraph::op) NGRAPH_OP(BatchNormTraining, ngraph::op)
NGRAPH_OP(BatchNormTrainingBackprop, ngraph::op) NGRAPH_OP(BatchNormTrainingBackprop, ngraph::op)
NGRAPH_OP(BinaryConvolution, ngraph::op)
NGRAPH_OP(Broadcast, ngraph::op) NGRAPH_OP(Broadcast, ngraph::op)
NGRAPH_OP(BroadcastDistributed, ngraph::op) NGRAPH_OP(BroadcastDistributed, ngraph::op)
NGRAPH_OP(BroadcastLike, ngraph::op) NGRAPH_OP(BroadcastLike, ngraph::op)
...@@ -79,8 +79,13 @@ NGRAPH_OP(Convert, ngraph::op) ...@@ -79,8 +79,13 @@ NGRAPH_OP(Convert, ngraph::op)
NGRAPH_OP(Convolution, ngraph::op) NGRAPH_OP(Convolution, ngraph::op)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op) NGRAPH_OP(ConvolutionBackpropData, ngraph::op)
NGRAPH_OP(ConvolutionBackpropFilters, ngraph::op) NGRAPH_OP(ConvolutionBackpropFilters, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(Cos, ngraph::op) NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op) NGRAPH_OP(Cosh, ngraph::op)
NGRAPH_OP(CrossEntropy, ngraph::op)
NGRAPH_OP(CrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Dequantize, ngraph::op) NGRAPH_OP(Dequantize, ngraph::op)
NGRAPH_OP(Divide, ngraph::op) NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP(Dot, ngraph::op) NGRAPH_OP(Dot, ngraph::op)
...@@ -94,22 +99,24 @@ NGRAPH_OP(Equal, ngraph::op) ...@@ -94,22 +99,24 @@ NGRAPH_OP(Equal, ngraph::op)
NGRAPH_OP(Erf, ngraph::op) NGRAPH_OP(Erf, ngraph::op)
NGRAPH_OP(Exp, ngraph::op) NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(Floor, ngraph::op) NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(FloorMod, ngraph::op) NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(Gather, ngraph::op) NGRAPH_OP(Gather, ngraph::op)
NGRAPH_OP(GatherND, ngraph::op) NGRAPH_OP(GatherND, ngraph::op)
NGRAPH_OP(Gelu, ngraph::op)
NGRAPH_OP(GeluBackpropFactor, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GenerateMask, ngraph::op) NGRAPH_OP(GenerateMask, ngraph::op)
NGRAPH_OP(GetOutputElement, ngraph::op) NGRAPH_OP(GetOutputElement, ngraph::op)
NGRAPH_OP(Greater, ngraph::op) NGRAPH_OP(Greater, ngraph::op)
NGRAPH_OP(GreaterEq, ngraph::op) NGRAPH_OP(GreaterEq, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(Less, ngraph::op) NGRAPH_OP(Less, ngraph::op)
NGRAPH_OP(LessEq, ngraph::op) NGRAPH_OP(LessEq, ngraph::op)
NGRAPH_OP(LessEqual, ngraph::op)
NGRAPH_OP(Log, ngraph::op) NGRAPH_OP(Log, ngraph::op)
NGRAPH_OP(LogicalAnd, ngraph::op) NGRAPH_OP(LogSoftmax, ngraph::op)
NGRAPH_OP(LogicalNot, ngraph::op)
NGRAPH_OP(LogicalOr, ngraph::op)
NGRAPH_OP(LogicalXor, ngraph::op)
NGRAPH_OP(LRN, ngraph::op)
NGRAPH_OP(Max, ngraph::op) NGRAPH_OP(Max, ngraph::op)
NGRAPH_OP(Maximum, ngraph::op) NGRAPH_OP(Maximum, ngraph::op)
NGRAPH_OP(MaxPool, ngraph::op) NGRAPH_OP(MaxPool, ngraph::op)
...@@ -117,6 +124,7 @@ NGRAPH_OP(MaxPoolBackprop, ngraph::op) ...@@ -117,6 +124,7 @@ NGRAPH_OP(MaxPoolBackprop, ngraph::op)
NGRAPH_OP(Min, ngraph::op) NGRAPH_OP(Min, ngraph::op)
NGRAPH_OP(Minimum, ngraph::op) NGRAPH_OP(Minimum, ngraph::op)
NGRAPH_OP(Multiply, ngraph::op) NGRAPH_OP(Multiply, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Negative, ngraph::op) NGRAPH_OP(Negative, ngraph::op)
NGRAPH_OP(Not, ngraph::op) NGRAPH_OP(Not, ngraph::op)
NGRAPH_OP(NotEqual, ngraph::op) NGRAPH_OP(NotEqual, ngraph::op)
...@@ -124,6 +132,8 @@ NGRAPH_OP(OneHot, ngraph::op) ...@@ -124,6 +132,8 @@ NGRAPH_OP(OneHot, ngraph::op)
NGRAPH_OP(Or, ngraph::op) NGRAPH_OP(Or, ngraph::op)
NGRAPH_OP(Pad, ngraph::op) NGRAPH_OP(Pad, ngraph::op)
NGRAPH_OP(Parameter, ngraph::op) NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(PartialSlice, ngraph::op)
NGRAPH_OP(PartialSliceBackprop, ngraph::op)
NGRAPH_OP(Passthrough, ngraph::op) NGRAPH_OP(Passthrough, ngraph::op)
NGRAPH_OP(Power, ngraph::op) NGRAPH_OP(Power, ngraph::op)
NGRAPH_OP(Product, ngraph::op) NGRAPH_OP(Product, ngraph::op)
...@@ -135,9 +145,10 @@ NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op) ...@@ -135,9 +145,10 @@ NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op) NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op) NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op) NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(RandomUniform, ngraph::op) NGRAPH_OP(RandomUniform, ngraph::op)
NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(Range, ngraph::op) NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Reciprocal, ngraph::op)
NGRAPH_OP(Relu, ngraph::op) NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op) NGRAPH_OP(ReluBackprop, ngraph::op)
NGRAPH_OP(ReplaceSlice, ngraph::op) NGRAPH_OP(ReplaceSlice, ngraph::op)
...@@ -146,9 +157,11 @@ NGRAPH_OP(Result, ngraph::op) ...@@ -146,9 +157,11 @@ NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op) NGRAPH_OP(Reverse, ngraph::op)
NGRAPH_OP(ReverseSequence, ngraph::op) NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(ScalarConstantLike, ngraph::op) NGRAPH_OP(ScalarConstantLike, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(ScatterAdd, ngraph::op) NGRAPH_OP(ScatterAdd, ngraph::op)
NGRAPH_OP(ScatterNDAdd, ngraph::op) NGRAPH_OP(ScatterNDAdd, ngraph::op)
NGRAPH_OP(Select, ngraph::op) NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(Selu, ngraph::op)
NGRAPH_OP(Send, ngraph::op) NGRAPH_OP(Send, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op) NGRAPH_OP(ShapeOf, ngraph::op)
NGRAPH_OP(Sigmoid, ngraph::op) NGRAPH_OP(Sigmoid, ngraph::op)
...@@ -158,6 +171,8 @@ NGRAPH_OP(Sin, ngraph::op) ...@@ -158,6 +171,8 @@ NGRAPH_OP(Sin, ngraph::op)
NGRAPH_OP(Sinh, ngraph::op) NGRAPH_OP(Sinh, ngraph::op)
NGRAPH_OP(Slice, ngraph::op) NGRAPH_OP(Slice, ngraph::op)
NGRAPH_OP(Softmax, ngraph::op) NGRAPH_OP(Softmax, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Sqrt, ngraph::op) NGRAPH_OP(Sqrt, ngraph::op)
NGRAPH_OP(StopGradient, ngraph::op) NGRAPH_OP(StopGradient, ngraph::op)
NGRAPH_OP(Subtract, ngraph::op) NGRAPH_OP(Subtract, ngraph::op)
...@@ -165,7 +180,6 @@ NGRAPH_OP(Sum, ngraph::op) ...@@ -165,7 +180,6 @@ NGRAPH_OP(Sum, ngraph::op)
NGRAPH_OP(Tan, ngraph::op) NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op) NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(Tile, ngraph::op) NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op) NGRAPH_OP(TopK, ngraph::op::v0)
NGRAPH_OP(Transpose, ngraph::op) NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(VariadicSplit, ngraph::op)
NGRAPH_OP(Xor, ngraph::op) NGRAPH_OP(Xor, ngraph::op)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each op. If an op is added it must be
// added to this list.
//
// In order to use this list you want to define a macro named exactly NGRAPH_OP
// When you are done you should undef the macro
// As an example if you wanted to make a list of all op names as strings you could do this:
//
// #define NGRAPH_OP(a,b) #a,
// std::vector<std::string> op_names{
// #include "this include file name"
// };
// #undef NGRAPH_OP
//
// This sample expands to a list like this:
// "Abs",
// "Acos",
// ...
//
// #define NGRAPH_OP(a,b) b::a,
// std::vector<std::string> op_names{
// #include "this include file name"
// };
// #undef NGRAPH_OP
//
// This sample expands to a list like this:
// ngraph::op::Abs,
// ngraph::op::Acos,
// ...
//
// It's that easy. You can use this for fun and profit.
#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(Abs, ngraph::op)
NGRAPH_OP(Acos, ngraph::op)
// NGRAPH_OP(Acosh, ngraph::op)
NGRAPH_OP(Add, ngraph::op::v1)
NGRAPH_OP(Asin, ngraph::op)
// NGRAPH_OP(Asinh, ngraph::op)
NGRAPH_OP(Atan, ngraph::op)
// NGRAPH_OP(Atanh, ngraph::op)
NGRAPH_OP(AvgPool, ngraph::op::v1)
NGRAPH_OP(BatchNormInference, ngraph::op)
NGRAPH_OP(BinaryConvolution, ngraph::op::v1)
NGRAPH_OP(Broadcast, ngraph::op::v1)
// NGRAPH_OP(CTCGreedyDecoder, ngraph::op)
NGRAPH_OP(Ceiling, ngraph::op)
NGRAPH_OP(Clamp, ngraph::op)
NGRAPH_OP(Concat, ngraph::op)
NGRAPH_OP(Constant, ngraph::op)
NGRAPH_OP(Convert, ngraph::op)
// NGRAPH_OP(ConvertLike, ngraph::op)
NGRAPH_OP(Convolution, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op)
// NGRAPH_OP(DeformableConvolution, ngraph::op)
// NGRAPH_OP(DeformablePSROIPooling, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op)
// NGRAPH_OP(DetectionOutput, ngraph::op)
NGRAPH_OP(Divide, ngraph::op::v1)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(Erf, ngraph::op)
NGRAPH_OP(Equal, ngraph::op::v1)
NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(FloorMod, ngraph::op::v1)
NGRAPH_OP(Gather, ngraph::op::v1)
NGRAPH_OP(Greater, ngraph::op::v1)
NGRAPH_OP(GreaterEq, ngraph::op::v1)
NGRAPH_OP(GroupConvolution, ngraph::op)
// NGRAPH_OP(GroupConvolutionBackpropData, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(Interpolate, ngraph::op)
// NGRAPH_OP(LeakyRelu, ngraph::op)
NGRAPH_OP(Less, ngraph::op::v1)
NGRAPH_OP(LessEqual, ngraph::op::v1)
NGRAPH_OP(Log, ngraph::op)
NGRAPH_OP(LogicalAnd, ngraph::op::v1)
NGRAPH_OP(LogicalNot, ngraph::op::v1)
NGRAPH_OP(LogicalOr, ngraph::op::v1)
NGRAPH_OP(LogicalXor, ngraph::op::v1)
NGRAPH_OP(LRN, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MaxPool, ngraph::op::v1)
NGRAPH_OP(Maximum, ngraph::op::v1)
NGRAPH_OP(Minimum, ngraph::op::v1)
// NGRAPH_OP(Mod, ngraph::op)
NGRAPH_OP(Multiply, ngraph::op::v1)
NGRAPH_OP(Negative, ngraph::op)
// NGRAPH_OP(NonMaxSuppression, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(NotEqual, ngraph::op::v1)
NGRAPH_OP(OneHot, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
// NGRAPH_OP(PSROIPooling, ngraph::op)
NGRAPH_OP(Pad, ngraph::op::v1)
NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(Power, ngraph::op::v1)
// NGRAPH_OP(PriorBox, ngraph::op)
// NGRAPH_OP(PriorBoxClustered, ngraph::op)
// NGRAPH_OP(Proposal, ngraph::op)
NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
// NGRAPH_OP(ReduceLogicalAnd, ngraph::op)
// NGRAPH_OP(ReduceLogicalOr, ngraph::op)
NGRAPH_OP(ReduceMax, ngraph::op::v1)
// NGRAPH_OP(ReduceMean, ngraph::op)
NGRAPH_OP(ReduceMin, ngraph::op::v1)
NGRAPH_OP(ReduceProd, ngraph::op::v1)
NGRAPH_OP(ReduceSum, ngraph::op::v1)
// NGRAPH_OP(RegionYolo, ngraph::op)
NGRAPH_OP(Reshape, ngraph::op::v1)
NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op::v1)
NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
// NGRAPH_OP(ROIPooling, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(Sign, ngraph::op)
NGRAPH_OP(Sigmoid, ngraph::op)
NGRAPH_OP(Sin, ngraph::op)
NGRAPH_OP(Sinh, ngraph::op)
NGRAPH_OP(Softmax, ngraph::op::v1)
NGRAPH_OP(Sqrt, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(StridedSlice, ngraph::op::v1)
NGRAPH_OP(Subtract, ngraph::op)
NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(TensorIterator, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op::v1)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
NGRAPH_OP(VariadicSplit, ngraph::op::v1)
// Related to v1
NGRAPH_OP(AvgPoolBackprop, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropFilters, ngraph::op::v1)
NGRAPH_OP(MaxPoolBackprop, ngraph::op::v1)
// Other
NGRAPH_OP(GenerateMask, ngraph::op::v1)
This diff is collapsed.
...@@ -15,40 +15,147 @@ ...@@ -15,40 +15,147 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp" #include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/binary_convolution.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp" #include "ngraph/op/equal.hpp"
#include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/layers/interpolate.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/floor_mod.hpp"
#include "ngraph/op/fused/batch_mat_mul_transpose.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
#include "ngraph/op/fused/gelu.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/partial_slice.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/selu.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp" #include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp" #include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp" #include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp" #include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp" #include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not.hpp" #include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp" #include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp" #include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/power.hpp" #include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/reduce_prod.hpp" #include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp" #include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/strided_slice.hpp" #include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/xor.hpp" #include "ngraph/op/xor.hpp"
#include <limits> #include <limits>
...@@ -57,33 +164,30 @@ ...@@ -57,33 +164,30 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
#define NGRAPH_OP(a, b) a, namespace
enum class OP_TYPEID
{ {
#include "ngraph/op/fused_op_tbl.hpp" enum class OP_TYPEID
#include "ngraph/op/op_tbl.hpp" {
}; #define NGRAPH_OP(a, b) a,
#undef NGRAPH_OP #include "ngraph/op/op_v0_tbl.hpp"
#define NGRAPH_OP(a, b) {#a, OP_TYPEID::a},
static unordered_map<string, OP_TYPEID> typeid_map{
#include "ngraph/op/fused_op_tbl.hpp"
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP #undef NGRAPH_OP
OTHER
};
}
static OP_TYPEID get_typeid(shared_ptr<Node> node) static OP_TYPEID get_typeid(shared_ptr<Node> node)
{ {
OP_TYPEID type_id; static map<NodeTypeInfo, OP_TYPEID> typeid_map{
auto it = typeid_map.find(node->description()); #define NGRAPH_OP(a, b) {b::a::type_info, OP_TYPEID::a},
#include "ngraph/op/op_v0_tbl.hpp"
#undef NGRAPH_OP
};
OP_TYPEID type_id = OP_TYPEID::OTHER;
auto it = typeid_map.find(node->get_type_info());
if (it != typeid_map.end()) if (it != typeid_map.end())
{ {
type_id = it->second; type_id = it->second;
} }
else
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
return type_id; return type_id;
} }
// END mapping to OP_TYPEID // END mapping to OP_TYPEID
...@@ -102,20 +206,6 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) ...@@ -102,20 +206,6 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{ {
bool modified = false; bool modified = false;
size_t op_version = node->get_version();
if (op_version == 1)
{
return modified;
}
NGRAPH_CHECK(op_version == 0,
"Op version 1 transformation pass failed for ",
*node,
", only op version 0 operations expected. Op version ",
op_version,
" found.");
// Not all enumeration values explicitly handled in switch // Not all enumeration values explicitly handled in switch
#if defined(__clang__) #if defined(__clang__)
#pragma clang diagnostic push #pragma clang diagnostic push
......
...@@ -21,7 +21,7 @@ else() ...@@ -21,7 +21,7 @@ else()
endif() endif()
if (NGRAPH_INTERPRETER_ENABLE) if (NGRAPH_INTERPRETER_ENABLE)
add_library(interpreter_backend ${LIBRARY_TYPE} int_backend.cpp node_wrapper.cpp int_executable.cpp) add_library(interpreter_backend ${LIBRARY_TYPE} int_backend.cpp int_executable.cpp)
target_compile_definitions(interpreter_backend PRIVATE INTERPRETER_BACKEND_EXPORTS) target_compile_definitions(interpreter_backend PRIVATE INTERPRETER_BACKEND_EXPORTS)
if(NGRAPH_LIB_VERSIONING_ENABLE) if(NGRAPH_LIB_VERSIONING_ENABLE)
set_target_properties(interpreter_backend PROPERTIES set_target_properties(interpreter_backend PROPERTIES
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
using namespace ngraph;
using namespace std;
runtime::interpreter::NodeWrapper::NodeWrapper(const shared_ptr<const Node>& node)
: m_node{node}
{
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// {"Abs", runtime::interpreter::OP_TYPEID::Abs},
// {"Acos", runtime::interpreter::OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a, b) {#a, runtime::interpreter::OP_TYPEID::a},
static unordered_map<string, runtime::interpreter::OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
#ifdef INTERPRETER_USE_HYBRID
#include "ngraph/runtime/hybrid/op/op_tbl.hpp"
#endif
};
#undef NGRAPH_OP
auto it = typeid_map.find(m_node->description());
if (it != typeid_map.end())
{
m_typeid = it->second;
}
else
{
throw unsupported_op("Unsupported op '" + m_node->description() + "'");
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
namespace interpreter
{
enum class OP_TYPEID;
class NodeWrapper;
}
}
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a, b) a,
enum class ngraph::runtime::interpreter::OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
#ifdef INTERPRETER_USE_HYBRID
#include "ngraph/runtime/hybrid/op/op_tbl.hpp"
#endif
};
#undef NGRAPH_OP
/// \brief This class allows adding an enum typeid to each Node. This makes dealing with
/// collections of Nodes a little easier and faster as we can use switch() instead of
/// if/else statements
class ngraph::runtime::interpreter::NodeWrapper
{
public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
std::shared_ptr<const Node> get_node() const { return m_node; }
ngraph::runtime::interpreter::OP_TYPEID get_typeid() const { return m_typeid; }
private:
std::shared_ptr<const ngraph::Node> m_node;
OP_TYPEID m_typeid;
};
This diff is collapsed.
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <cstring>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -38,12 +39,31 @@ namespace ngraph ...@@ -38,12 +39,31 @@ namespace ngraph
const char* name; const char* name;
uint64_t version; uint64_t version;
bool is_castable(const DiscreteTypeInfo& target_type) const { return this == &target_type; } bool is_castable(const DiscreteTypeInfo& target_type) const { return *this == target_type; }
// For use as a key // For use as a key
bool operator<(const DiscreteTypeInfo& b) const bool operator<(const DiscreteTypeInfo& b) const
{ {
return version < b.version || return version < b.version || (version == b.version && strcmp(name, b.name) < 0);
(version == b.version && std::string(name) < std::string(b.name)); }
bool operator<=(const DiscreteTypeInfo& b) const
{
return version < b.version || (version == b.version && strcmp(name, b.name) <= 0);
}
bool operator>(const DiscreteTypeInfo& b) const
{
return version < b.version || (version == b.version && strcmp(name, b.name) > 0);
}
bool operator>=(const DiscreteTypeInfo& b) const
{
return version < b.version || (version == b.version && strcmp(name, b.name) >= 0);
}
bool operator==(const DiscreteTypeInfo& b) const
{
return version == b.version && strcmp(name, b.name) == 0;
}
bool operator!=(const DiscreteTypeInfo& b) const
{
return version != b.version || strcmp(name, b.name) != 0;
} }
}; };
......
...@@ -35,11 +35,10 @@ void test_type_prop_opset0_downgrade_pass(const element::Type& output_type, ...@@ -35,11 +35,10 @@ void test_type_prop_opset0_downgrade_pass(const element::Type& output_type,
pass_manager.run_passes(f); pass_manager.run_passes(f);
auto v0_result = f->get_results().at(0); auto v0_result = f->get_results().at(0);
auto node = v0_result->input(0).get_source_output().get_node_shared_ptr(); auto node = v0_result->input_value(0).get_node_shared_ptr();
auto v0_node = static_pointer_cast<OpV0>(node); auto v0_node = as_type_ptr<OpV0>(node);
EXPECT_EQ(v0_node->description(), (node_name.empty() ? v1_node->description() : node_name)); EXPECT_TRUE(v0_node);
EXPECT_EQ(v0_node->get_version(), 0);
EXPECT_EQ(v0_node->get_autob(), np_auto_b); EXPECT_EQ(v0_node->get_autob(), np_auto_b);
EXPECT_EQ(v0_node->output(0).get_element_type(), output_type); EXPECT_EQ(v0_node->output(0).get_element_type(), output_type);
EXPECT_EQ(v0_node->output(0).get_shape(), (Shape{1, 3, 2})); EXPECT_EQ(v0_node->output(0).get_shape(), (Shape{1, 3, 2}));
......
...@@ -22,20 +22,18 @@ TEST(opset_transform, opset1_broadcast_upgrade_pass) ...@@ -22,20 +22,18 @@ TEST(opset_transform, opset1_broadcast_upgrade_pass)
pass_manager.register_pass<pass::Opset1Upgrade>(); pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
auto bcast_v1 = static_pointer_cast<op::v1::Broadcast>( auto bcast_v1 = as_type_ptr<op::v1::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr()); f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(bcast_v1->description(), "Broadcast"); EXPECT_TRUE(bcast_v1);
EXPECT_EQ(bcast_v1->get_version(), 1);
EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec()); EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec());
EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair<bool, AxisSet>(true, AxisSet{0, 2}))); EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
EXPECT_TRUE(bcast_v1->input_value(1).get_node()->is_constant());
EXPECT_EQ(bcast_v1->input_value(1).get_node()->description(), "Constant"); EXPECT_TRUE(bcast_v1->input_value(2).get_node()->is_constant());
EXPECT_EQ(bcast_v1->input_value(2).get_node()->description(), "Constant");
EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr()) EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr())
->get_shape_val(), ->get_shape_val(),
(Shape{3, 5, 4, 6})); (Shape{3, 5, 4, 6}));
EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(2).get_node_shared_ptr()) EXPECT_EQ(as_type_ptr<op::Constant>(bcast_v1->input_value(2).get_node_shared_ptr())
->get_axis_set_val(), ->get_axis_set_val(),
(AxisSet{1, 3})); (AxisSet{1, 3}));
} }
...@@ -53,11 +51,10 @@ TEST(opset_transform, opset1_broadcast_downgrade_pass) ...@@ -53,11 +51,10 @@ TEST(opset_transform, opset1_broadcast_downgrade_pass)
pass_manager.register_pass<pass::Opset0Downgrade>(); pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
auto bcast_v0 = static_pointer_cast<op::v0::Broadcast>( auto bcast_v0 = as_type_ptr<op::v0::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr()); f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(bcast_v0->description(), "Broadcast"); EXPECT_TRUE(bcast_v0);
EXPECT_EQ(bcast_v0->get_version(), 0);
EXPECT_EQ(bcast_v0->get_broadcast_shape(), (Shape{3, 1, 4, 2, 3})); EXPECT_EQ(bcast_v0->get_broadcast_shape(), (Shape{3, 1, 4, 2, 3}));
EXPECT_EQ(bcast_v0->get_broadcast_axes(), (AxisSet{0, 2})); EXPECT_EQ(bcast_v0->get_broadcast_axes(), (AxisSet{0, 2}));
} }
...@@ -33,10 +33,9 @@ TEST(opset_transform, opset1_convolution_upgrade_pass) ...@@ -33,10 +33,9 @@ TEST(opset_transform, opset1_convolution_upgrade_pass)
auto convolution_s1_result = f->get_results().at(0); auto convolution_s1_result = f->get_results().at(0);
auto node = convolution_s1_result->input(0).get_source_output().get_node_shared_ptr(); auto node = convolution_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto convolution_v1_node = static_pointer_cast<op::v1::Convolution>(node); auto convolution_v1_node = as_type_ptr<op::v1::Convolution>(node);
EXPECT_EQ(convolution_v1_node->description(), "Convolution"); EXPECT_TRUE(convolution_v1_node);
EXPECT_EQ(convolution_v1_node->get_version(), 1);
EXPECT_EQ(convolution_v1_node->get_pads_begin(), pads_begin); EXPECT_EQ(convolution_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(convolution_v1_node->get_pads_end(), pads_end); EXPECT_EQ(convolution_v1_node->get_pads_end(), pads_end);
...@@ -66,10 +65,9 @@ TEST(opset_transform, opset1_convolution_downgrade_pass) ...@@ -66,10 +65,9 @@ TEST(opset_transform, opset1_convolution_downgrade_pass)
auto conv_s0_result = f->get_results().at(0); auto conv_s0_result = f->get_results().at(0);
auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr(); auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto conv_v0_node = static_pointer_cast<op::v0::Convolution>(node); auto conv_v0_node = as_type_ptr<op::v0::Convolution>(node);
EXPECT_EQ(conv_v0_node->description(), "Convolution"); EXPECT_TRUE(conv_v0_node);
EXPECT_EQ(conv_v0_node->get_version(), 0);
EXPECT_EQ(conv_v0_node->get_window_movement_strides(), strides); EXPECT_EQ(conv_v0_node->get_window_movement_strides(), strides);
EXPECT_EQ(conv_v0_node->get_window_dilation_strides(), dilations); EXPECT_EQ(conv_v0_node->get_window_dilation_strides(), dilations);
EXPECT_EQ(conv_v0_node->get_padding_below(), pads_begin); EXPECT_EQ(conv_v0_node->get_padding_below(), pads_begin);
...@@ -99,10 +97,9 @@ TEST(opset_transform, opset1_convolution_backprop_data_downgrade_pass) ...@@ -99,10 +97,9 @@ TEST(opset_transform, opset1_convolution_backprop_data_downgrade_pass)
auto conv_s0_result = f->get_results().at(0); auto conv_s0_result = f->get_results().at(0);
auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr(); auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto conv_v0_node = static_pointer_cast<op::v0::ConvolutionBackpropData>(node); auto conv_v0_node = as_type_ptr<op::v0::ConvolutionBackpropData>(node);
EXPECT_EQ(conv_v0_node->description(), "ConvolutionBackpropData"); EXPECT_TRUE(conv_v0_node);
EXPECT_EQ(conv_v0_node->get_version(), 0);
EXPECT_EQ(conv_v0_node->get_data_batch_shape(), (Shape{64, 3, 100})); EXPECT_EQ(conv_v0_node->get_data_batch_shape(), (Shape{64, 3, 100}));
EXPECT_EQ(conv_v0_node->get_window_movement_strides_forward(), strides); EXPECT_EQ(conv_v0_node->get_window_movement_strides_forward(), strides);
EXPECT_EQ(conv_v0_node->get_window_dilation_strides_forward(), dilations); EXPECT_EQ(conv_v0_node->get_window_dilation_strides_forward(), dilations);
...@@ -131,10 +128,9 @@ TEST(opset_transform, opset1_convolution_backprop_filters_downgrade_pass) ...@@ -131,10 +128,9 @@ TEST(opset_transform, opset1_convolution_backprop_filters_downgrade_pass)
auto conv_s0_result = f->get_results().at(0); auto conv_s0_result = f->get_results().at(0);
auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr(); auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto conv_v0_node = static_pointer_cast<op::v0::ConvolutionBackpropFilters>(node); auto conv_v0_node = as_type_ptr<op::v0::ConvolutionBackpropFilters>(node);
EXPECT_EQ(conv_v0_node->description(), "ConvolutionBackpropFilters"); EXPECT_TRUE(conv_v0_node);
EXPECT_EQ(conv_v0_node->get_version(), 0);
EXPECT_EQ(conv_v0_node->get_filters_shape(), (Shape{128, 3, 10})); EXPECT_EQ(conv_v0_node->get_filters_shape(), (Shape{128, 3, 10}));
EXPECT_EQ(conv_v0_node->get_window_movement_strides_forward(), strides); EXPECT_EQ(conv_v0_node->get_window_movement_strides_forward(), strides);
EXPECT_EQ(conv_v0_node->get_window_dilation_strides_forward(), dilations); EXPECT_EQ(conv_v0_node->get_window_dilation_strides_forward(), dilations);
......
...@@ -39,12 +39,8 @@ TEST(opset_transform, opset1_dyn_reshape_upgrade_pass) ...@@ -39,12 +39,8 @@ TEST(opset_transform, opset1_dyn_reshape_upgrade_pass)
pass_manager.register_pass<pass::Opset1Upgrade>(); pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
const auto pass_replacement_node = const auto pass_replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); EXPECT_TRUE(is_type<op::v1::Reshape>(pass_replacement_node));
const auto reshape_v1 = as_type_ptr<op::v1::Reshape>(pass_replacement_node);
EXPECT_EQ(reshape_v1->description(), "DynReshape");
EXPECT_EQ(reshape_v1->get_version(), 1);
} }
TEST(opset_transform, opset1_reshape_downgrade_pass) TEST(opset_transform, opset1_reshape_downgrade_pass)
...@@ -60,11 +56,8 @@ TEST(opset_transform, opset1_reshape_downgrade_pass) ...@@ -60,11 +56,8 @@ TEST(opset_transform, opset1_reshape_downgrade_pass)
pass_manager.register_pass<pass::Opset0Downgrade>(); pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
const auto pass_replacement_node = const auto pass_replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape_v1 = as_type_ptr<op::v0::DynReshape>(pass_replacement_node); const auto reshape_v1 = as_type_ptr<op::v0::DynReshape>(pass_replacement_node);
EXPECT_TRUE(reshape_v1);
EXPECT_EQ(reshape_v1->description(), "DynReshape");
EXPECT_EQ(reshape_v1->get_version(), 0);
EXPECT_EQ(reshape_v1->get_zero_flag(), true); EXPECT_EQ(reshape_v1->get_zero_flag(), true);
} }
...@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_gather_upgrade_pass) ...@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_gather_upgrade_pass)
pass_manager.run_passes(f); pass_manager.run_passes(f);
auto gather_s1_result = f->get_results().at(0); auto gather_s1_result = f->get_results().at(0);
auto node = gather_s1_result->input(0).get_source_output().get_node_shared_ptr(); auto gather_v1_node = as_type_ptr<op::v1::Gather>(
auto gather_v1_node = static_pointer_cast<op::v1::Gather>(node); gather_s1_result->input(0).get_source_output().get_node_shared_ptr());
EXPECT_TRUE(gather_v1_node);
EXPECT_EQ(gather_v1_node->description(), "Gather");
EXPECT_EQ(gather_v1_node->get_version(), 1);
EXPECT_EQ(gather_v1_node->get_axis(), axis); EXPECT_EQ(gather_v1_node->get_axis(), axis);
} }
...@@ -26,10 +26,8 @@ TEST(opset_transform, opset1_generate_mask_downgrade_pass) ...@@ -26,10 +26,8 @@ TEST(opset_transform, opset1_generate_mask_downgrade_pass)
pass_manager.register_pass<pass::Opset0Downgrade>(); pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
auto generate_mask_v0 = static_pointer_cast<op::v0::GenerateMask>( auto generate_mask_v0 = as_type_ptr<op::v0::GenerateMask>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr()); f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_TRUE(generate_mask_v0);
EXPECT_EQ(generate_mask_v0->description(), "GenerateMask");
EXPECT_EQ(generate_mask_v0->get_version(), 0);
EXPECT_EQ(generate_mask_v0->get_mask_shape(), (Shape{1, 128})); EXPECT_EQ(generate_mask_v0->get_mask_shape(), (Shape{1, 128}));
} }
...@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_and_upgrade_pass) ...@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_and_upgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto and_v1 = static_pointer_cast<op::v1::LogicalAnd>(pass_replacement_node); const auto and_v1 = as_type_ptr<op::v1::LogicalAnd>(pass_replacement_node);
EXPECT_TRUE(and_v1);
EXPECT_EQ(and_v1->description(), "LogicalAnd");
EXPECT_EQ(and_v1->get_version(), 1);
const auto values_out_element_type = and_v1->output(0).get_element_type(); const auto values_out_element_type = and_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean); EXPECT_EQ(values_out_element_type, element::boolean);
...@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_and_downgrade_pass) ...@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_and_downgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto and_v0 = static_pointer_cast<op::v0::And>(pass_replacement_node); const auto and_v0 = as_type_ptr<op::v0::And>(pass_replacement_node);
EXPECT_TRUE(and_v0);
EXPECT_EQ(and_v0->description(), "And");
EXPECT_EQ(and_v0->get_version(), 0);
const auto values_out_element_type = and_v0->output(0).get_element_type(); const auto values_out_element_type = and_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean); EXPECT_EQ(values_out_element_type, element::boolean);
......
...@@ -39,10 +39,8 @@ TEST(opset_transform, opset1_logical_not_upgrade_pass) ...@@ -39,10 +39,8 @@ TEST(opset_transform, opset1_logical_not_upgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto not_v1 = static_pointer_cast<op::v1::LogicalNot>(pass_replacement_node); const auto not_v1 = as_type_ptr<op::v1::LogicalNot>(pass_replacement_node);
EXPECT_TRUE(not_v1);
EXPECT_EQ(not_v1->description(), "LogicalNot");
EXPECT_EQ(not_v1->get_version(), 1);
const auto values_out_element_type = not_v1->output(0).get_element_type(); const auto values_out_element_type = not_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean); EXPECT_EQ(values_out_element_type, element::boolean);
...@@ -61,10 +59,8 @@ TEST(opset_transform, opset1_logical_not_downgrade_pass) ...@@ -61,10 +59,8 @@ TEST(opset_transform, opset1_logical_not_downgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto not_v0 = static_pointer_cast<op::v0::Not>(pass_replacement_node); const auto not_v0 = as_type_ptr<op::v0::Not>(pass_replacement_node);
EXPECT_TRUE(not_v0);
EXPECT_EQ(not_v0->description(), "Not");
EXPECT_EQ(not_v0->get_version(), 0);
const auto values_out_element_type = not_v0->output(0).get_element_type(); const auto values_out_element_type = not_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean); EXPECT_EQ(values_out_element_type, element::boolean);
......
...@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_or_upgrade_pass) ...@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_or_upgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto or_v1 = static_pointer_cast<op::v1::LogicalOr>(pass_replacement_node); const auto or_v1 = as_type_ptr<op::v1::LogicalOr>(pass_replacement_node);
EXPECT_TRUE(or_v1);
EXPECT_EQ(or_v1->description(), "LogicalOr");
EXPECT_EQ(or_v1->get_version(), 1);
const auto values_out_element_type = or_v1->output(0).get_element_type(); const auto values_out_element_type = or_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean); EXPECT_EQ(values_out_element_type, element::boolean);
...@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_or_downgrade_pass) ...@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_or_downgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto or_v0 = static_pointer_cast<op::v0::Or>(pass_replacement_node); const auto or_v0 = as_type_ptr<op::v0::Or>(pass_replacement_node);
EXPECT_TRUE(or_v0);
EXPECT_EQ(or_v0->description(), "Or");
EXPECT_EQ(or_v0->get_version(), 0);
const auto values_out_element_type = or_v0->output(0).get_element_type(); const auto values_out_element_type = or_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean); EXPECT_EQ(values_out_element_type, element::boolean);
......
...@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_xor_upgrade_pass) ...@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_xor_upgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto xor_v1 = static_pointer_cast<op::v1::LogicalXor>(pass_replacement_node); const auto xor_v1 = as_type_ptr<op::v1::LogicalXor>(pass_replacement_node);
EXPECT_TRUE(xor_v1);
EXPECT_EQ(xor_v1->description(), "LogicalXor");
EXPECT_EQ(xor_v1->get_version(), 1);
const auto values_out_element_type = xor_v1->output(0).get_element_type(); const auto values_out_element_type = xor_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean); EXPECT_EQ(values_out_element_type, element::boolean);
...@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_xor_downgrade_pass) ...@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_xor_downgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto xor_v0 = static_pointer_cast<op::v0::Xor>(pass_replacement_node); const auto xor_v0 = as_type_ptr<op::v0::Xor>(pass_replacement_node);
EXPECT_TRUE(xor_v0);
EXPECT_EQ(xor_v0->description(), "Xor");
EXPECT_EQ(xor_v0->get_version(), 0);
const auto values_out_element_type = xor_v0->output(0).get_element_type(); const auto values_out_element_type = xor_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean); EXPECT_EQ(values_out_element_type, element::boolean);
......
...@@ -29,10 +29,8 @@ TEST(opset_transform, opset1_pad_upgrade_pass) ...@@ -29,10 +29,8 @@ TEST(opset_transform, opset1_pad_upgrade_pass)
auto pad_s1_result = f->get_results().at(0); auto pad_s1_result = f->get_results().at(0);
auto node = pad_s1_result->input(0).get_source_output().get_node_shared_ptr(); auto node = pad_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto pad_v1_node = static_pointer_cast<op::v1::Pad>(node); auto pad_v1_node = as_type_ptr<op::v1::Pad>(node);
EXPECT_TRUE(pad_v1_node);
EXPECT_EQ(pad_v1_node->description(), "Pad");
EXPECT_EQ(pad_v1_node->get_version(), 1);
EXPECT_EQ(pad_v1_node->get_pad_mode(), pad_mode); EXPECT_EQ(pad_v1_node->get_pad_mode(), pad_mode);
EXPECT_EQ(pad_v1_node->get_pads_begin(), padding_below); EXPECT_EQ(pad_v1_node->get_pads_begin(), padding_below);
...@@ -58,10 +56,8 @@ TEST(opset_transform, opset1_pad_downgrade_pass) ...@@ -58,10 +56,8 @@ TEST(opset_transform, opset1_pad_downgrade_pass)
auto pad_s0_result = f->get_results().at(0); auto pad_s0_result = f->get_results().at(0);
auto node = pad_s0_result->input(0).get_source_output().get_node_shared_ptr(); auto node = pad_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto pad_v0_node = static_pointer_cast<op::v0::Pad>(node); auto pad_v0_node = as_type_ptr<op::v0::Pad>(node);
EXPECT_TRUE(pad_v0_node);
EXPECT_EQ(pad_v0_node->description(), "Pad");
EXPECT_EQ(pad_v0_node->get_version(), 0);
EXPECT_EQ(pad_v0_node->get_pad_mode(), pad_mode); EXPECT_EQ(pad_v0_node->get_pad_mode(), pad_mode);
EXPECT_EQ(pad_v0_node->get_padding_below(), CoordinateDiff({1, 2})); EXPECT_EQ(pad_v0_node->get_padding_below(), CoordinateDiff({1, 2}));
......
...@@ -33,10 +33,8 @@ TEST(opset_transform, opset1_avgpool_upgrade_pass) ...@@ -33,10 +33,8 @@ TEST(opset_transform, opset1_avgpool_upgrade_pass)
auto avgpool_s1_result = f->get_results().at(0); auto avgpool_s1_result = f->get_results().at(0);
auto node = avgpool_s1_result->input(0).get_source_output().get_node_shared_ptr(); auto node = avgpool_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_v1_node = static_pointer_cast<op::v1::AvgPool>(node); auto avg_pool_v1_node = as_type_ptr<op::v1::AvgPool>(node);
EXPECT_TRUE(avg_pool_v1_node);
EXPECT_EQ(avg_pool_v1_node->description(), "AvgPool");
EXPECT_EQ(avg_pool_v1_node->get_version(), 1);
EXPECT_EQ(avg_pool_v1_node->get_pads_begin(), pads_begin); EXPECT_EQ(avg_pool_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(avg_pool_v1_node->get_pads_end(), pads_end); EXPECT_EQ(avg_pool_v1_node->get_pads_end(), pads_end);
...@@ -68,10 +66,8 @@ TEST(opset_transform, opset1_maxpool_upgrade_pass) ...@@ -68,10 +66,8 @@ TEST(opset_transform, opset1_maxpool_upgrade_pass)
auto maxpool_s1_result = f->get_results().at(0); auto maxpool_s1_result = f->get_results().at(0);
auto node = maxpool_s1_result->input(0).get_source_output().get_node_shared_ptr(); auto node = maxpool_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_v1_node = static_pointer_cast<op::v1::MaxPool>(node); auto max_pool_v1_node = as_type_ptr<op::v1::MaxPool>(node);
EXPECT_TRUE(max_pool_v1_node);
EXPECT_EQ(max_pool_v1_node->description(), "MaxPool");
EXPECT_EQ(max_pool_v1_node->get_version(), 1);
EXPECT_EQ(max_pool_v1_node->get_pads_begin(), pads_begin); EXPECT_EQ(max_pool_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(max_pool_v1_node->get_pads_end(), pads_end); EXPECT_EQ(max_pool_v1_node->get_pads_end(), pads_end);
...@@ -109,10 +105,8 @@ TEST(opset_transform, opset1_avgpool_downgrade_pass) ...@@ -109,10 +105,8 @@ TEST(opset_transform, opset1_avgpool_downgrade_pass)
auto avgpool_s0_result = f->get_results().at(0); auto avgpool_s0_result = f->get_results().at(0);
auto node = avgpool_s0_result->input(0).get_source_output().get_node_shared_ptr(); auto node = avgpool_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_v0_node = static_pointer_cast<op::v0::AvgPool>(node); auto avg_pool_v0_node = as_type_ptr<op::v0::AvgPool>(node);
EXPECT_TRUE(avg_pool_v0_node);
EXPECT_EQ(avg_pool_v0_node->description(), "AvgPool");
EXPECT_EQ(avg_pool_v0_node->get_version(), 0);
EXPECT_EQ(avg_pool_v0_node->get_padding_below(), padding_below); EXPECT_EQ(avg_pool_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(avg_pool_v0_node->get_padding_above(), padding_above); EXPECT_EQ(avg_pool_v0_node->get_padding_above(), padding_above);
...@@ -149,10 +143,8 @@ TEST(opset_transform, opset1_maxpool_downgrade_pass) ...@@ -149,10 +143,8 @@ TEST(opset_transform, opset1_maxpool_downgrade_pass)
auto maxpool_s0_result = f->get_results().at(0); auto maxpool_s0_result = f->get_results().at(0);
auto node = maxpool_s0_result->input(0).get_source_output().get_node_shared_ptr(); auto node = maxpool_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_v0_node = static_pointer_cast<op::v0::MaxPool>(node); auto max_pool_v0_node = as_type_ptr<op::v0::MaxPool>(node);
EXPECT_TRUE(max_pool_v0_node);
EXPECT_EQ(max_pool_v0_node->description(), "MaxPool");
EXPECT_EQ(max_pool_v0_node->get_version(), 0);
EXPECT_EQ(max_pool_v0_node->get_padding_below(), padding_below); EXPECT_EQ(max_pool_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(max_pool_v0_node->get_padding_above(), padding_above); EXPECT_EQ(max_pool_v0_node->get_padding_above(), padding_above);
...@@ -189,10 +181,8 @@ TEST(opset_transform, opset1_avgpool_backprop_downgrade_pass) ...@@ -189,10 +181,8 @@ TEST(opset_transform, opset1_avgpool_backprop_downgrade_pass)
auto avgpool_backprop_s0_result = f->get_results().at(0); auto avgpool_backprop_s0_result = f->get_results().at(0);
auto node = avgpool_backprop_s0_result->input(0).get_source_output().get_node_shared_ptr(); auto node = avgpool_backprop_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_backprop_v0_node = static_pointer_cast<op::v0::AvgPoolBackprop>(node); auto avg_pool_backprop_v0_node = as_type_ptr<op::v0::AvgPoolBackprop>(node);
EXPECT_TRUE(avg_pool_backprop_v0_node);
EXPECT_EQ(avg_pool_backprop_v0_node->description(), "AvgPoolBackprop");
EXPECT_EQ(avg_pool_backprop_v0_node->get_version(), 0);
EXPECT_EQ(avg_pool_backprop_v0_node->get_padding_below(), padding_below); EXPECT_EQ(avg_pool_backprop_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(avg_pool_backprop_v0_node->get_padding_above(), padding_above); EXPECT_EQ(avg_pool_backprop_v0_node->get_padding_above(), padding_above);
...@@ -229,10 +219,8 @@ TEST(opset_transform, opset1_maxpool_backprop_downgrade_pass) ...@@ -229,10 +219,8 @@ TEST(opset_transform, opset1_maxpool_backprop_downgrade_pass)
auto max_pool_backprop_s0_result = f->get_results().at(0); auto max_pool_backprop_s0_result = f->get_results().at(0);
auto node = max_pool_backprop_s0_result->input(0).get_source_output().get_node_shared_ptr(); auto node = max_pool_backprop_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_backprop_v0_node = static_pointer_cast<op::v0::MaxPoolBackprop>(node); auto max_pool_backprop_v0_node = as_type_ptr<op::v0::MaxPoolBackprop>(node);
EXPECT_TRUE(max_pool_backprop_v0_node);
EXPECT_EQ(max_pool_backprop_v0_node->description(), "MaxPoolBackprop");
EXPECT_EQ(max_pool_backprop_v0_node->get_version(), 0);
EXPECT_EQ(max_pool_backprop_v0_node->get_padding_below(), padding_below); EXPECT_EQ(max_pool_backprop_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(max_pool_backprop_v0_node->get_padding_above(), padding_above); EXPECT_EQ(max_pool_backprop_v0_node->get_padding_above(), padding_above);
EXPECT_EQ(max_pool_backprop_v0_node->get_window_movement_strides(), window_movement_strides); EXPECT_EQ(max_pool_backprop_v0_node->get_window_movement_strides(), window_movement_strides);
......
...@@ -41,10 +41,8 @@ TEST(opset_transform, opset1_product_upgrade_pass) ...@@ -41,10 +41,8 @@ TEST(opset_transform, opset1_product_upgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reduce_prod_v1 = static_pointer_cast<op::v1::ReduceProd>(pass_replacement_node); const auto reduce_prod_v1 = as_type_ptr<op::v1::ReduceProd>(pass_replacement_node);
EXPECT_TRUE(reduce_prod_v1);
EXPECT_EQ(reduce_prod_v1->description(), "Product");
EXPECT_EQ(reduce_prod_v1->get_version(), 1);
EXPECT_EQ(reduce_prod_v1->get_keep_dims(), false); EXPECT_EQ(reduce_prod_v1->get_keep_dims(), false);
} }
...@@ -63,15 +61,12 @@ TEST(opset_transform, opset0_reduce_prod_downgrade_pass) ...@@ -63,15 +61,12 @@ TEST(opset_transform, opset0_reduce_prod_downgrade_pass)
const auto reshape_replacement_node = const auto reshape_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape = static_pointer_cast<op::Reshape>(reshape_replacement_node); const auto reshape = as_type_ptr<op::Reshape>(reshape_replacement_node);
EXPECT_TRUE(reshape);
const auto product_replace_node = const auto product_replace_node =
reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr(); reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr();
const auto product_v0 = static_pointer_cast<op::v0::Product>(product_replace_node); const auto product_v0 = as_type_ptr<op::v0::Product>(product_replace_node);
EXPECT_TRUE(product_v0);
EXPECT_EQ(reshape->description(), "Reshape");
EXPECT_EQ(reshape->get_version(), 0);
EXPECT_EQ(product_v0->description(), "Product");
EXPECT_EQ(product_v0->get_version(), 0);
} }
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_axes_not_constant) TEST(opset_transform, opset0_reduce_prod_downgrade_pass_axes_not_constant)
......
...@@ -41,11 +41,9 @@ TEST(opset_transform, opset1_reverse_upgrade_pass) ...@@ -41,11 +41,9 @@ TEST(opset_transform, opset1_reverse_upgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v1 = static_pointer_cast<op::v1::Reverse>(pass_replacement_node); const auto reverse_v1 = as_type_ptr<op::v1::Reverse>(pass_replacement_node);
EXPECT_TRUE(reverse_v1);
EXPECT_EQ(reverse_v1->get_mode(), op::v1::Reverse::Mode::INDEX); EXPECT_EQ(reverse_v1->get_mode(), op::v1::Reverse::Mode::INDEX);
EXPECT_EQ(reverse_v1->description(), "Reverse");
EXPECT_EQ(reverse_v1->get_version(), 1);
const auto& rev_axes_input_shape = reverse_v1->get_input_shape(1); const auto& rev_axes_input_shape = reverse_v1->get_input_shape(1);
// should match the number of elements of v0::Reverse reverse_axes attribute // should match the number of elements of v0::Reverse reverse_axes attribute
...@@ -69,10 +67,8 @@ TEST(opset_transform, opset0_reverse_downgrade_pass_index_mode) ...@@ -69,10 +67,8 @@ TEST(opset_transform, opset0_reverse_downgrade_pass_index_mode)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = static_pointer_cast<op::v0::Reverse>(pass_replacement_node); const auto reverse_v0 = as_type_ptr<op::v0::Reverse>(pass_replacement_node);
EXPECT_TRUE(reverse_v0);
EXPECT_EQ(reverse_v0->description(), "Reverse");
EXPECT_EQ(reverse_v0->get_version(), 0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({1, 2})); EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({1, 2}));
} }
...@@ -93,10 +89,8 @@ TEST(opset_transform, opset0_reverse_downgrade_pass_mask_mode) ...@@ -93,10 +89,8 @@ TEST(opset_transform, opset0_reverse_downgrade_pass_mask_mode)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = static_pointer_cast<op::v0::Reverse>(pass_replacement_node); const auto reverse_v0 = as_type_ptr<op::v0::Reverse>(pass_replacement_node);
EXPECT_TRUE(reverse_v0);
EXPECT_EQ(reverse_v0->description(), "Reverse");
EXPECT_EQ(reverse_v0->get_version(), 0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({0, 2})); EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({0, 2}));
} }
......
...@@ -45,16 +45,17 @@ TEST(opset_transform, opset1_dyn_slice_upgrade_pass) ...@@ -45,16 +45,17 @@ TEST(opset_transform, opset1_dyn_slice_upgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto strided_slice_v1 = as_type_ptr<op::v1::StridedSlice>(pass_replacement_node); const auto strided_slice_v1 = as_type_ptr<op::v1::StridedSlice>(pass_replacement_node);
EXPECT_TRUE(strided_slice_v1);
auto begin_const = auto begin_const =
as_type_ptr<op::Constant>(strided_slice_v1->input_value(1).get_node_shared_ptr()); as_type_ptr<op::Constant>(strided_slice_v1->input_value(1).get_node_shared_ptr());
EXPECT_TRUE(begin_const);
auto end_const = auto end_const =
as_type_ptr<op::Constant>(strided_slice_v1->input_value(2).get_node_shared_ptr()); as_type_ptr<op::Constant>(strided_slice_v1->input_value(2).get_node_shared_ptr());
EXPECT_TRUE(end_const);
auto strides_const = auto strides_const =
as_type_ptr<op::Constant>(strided_slice_v1->input_value(3).get_node_shared_ptr()); as_type_ptr<op::Constant>(strided_slice_v1->input_value(3).get_node_shared_ptr());
EXPECT_TRUE(strides_const);
EXPECT_EQ(strided_slice_v1->description(), "Slice");
EXPECT_EQ(strided_slice_v1->get_version(), 1);
EXPECT_EQ(strided_slice_v1->get_begin_mask(), vector<int64_t>(4, 0)); EXPECT_EQ(strided_slice_v1->get_begin_mask(), vector<int64_t>(4, 0));
EXPECT_EQ(strided_slice_v1->get_end_mask(), vector<int64_t>(4, 0)); EXPECT_EQ(strided_slice_v1->get_end_mask(), vector<int64_t>(4, 0));
EXPECT_EQ(begin_const->get_vector<int64_t>(), EXPECT_EQ(begin_const->get_vector<int64_t>(),
...@@ -84,9 +85,7 @@ TEST(opset_transform, opset1_strided_slice_downgrade_pass) ...@@ -84,9 +85,7 @@ TEST(opset_transform, opset1_strided_slice_downgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto slice_v0 = as_type_ptr<op::v0::Slice>(pass_replacement_node); const auto slice_v0 = as_type_ptr<op::v0::Slice>(pass_replacement_node);
EXPECT_TRUE(slice_v0);
EXPECT_EQ(slice_v0->description(), "Slice");
EXPECT_EQ(slice_v0->get_version(), 0);
EXPECT_EQ(slice_v0->get_lower_bounds(), Coordinate({1, 2, 0, 2})); EXPECT_EQ(slice_v0->get_lower_bounds(), Coordinate({1, 2, 0, 2}));
EXPECT_EQ(slice_v0->get_upper_bounds(), Coordinate({5, 4, 5, 6})); EXPECT_EQ(slice_v0->get_upper_bounds(), Coordinate({5, 4, 5, 6}));
EXPECT_EQ(slice_v0->get_strides(), Strides({1, 1, 1, 1})); EXPECT_EQ(slice_v0->get_strides(), Strides({1, 1, 1, 1}));
......
...@@ -40,11 +40,9 @@ TEST(opset_transform, opset1_softmax_upgrade_pass_axis) ...@@ -40,11 +40,9 @@ TEST(opset_transform, opset1_softmax_upgrade_pass_axis)
auto softmax_s1_result = f->get_results().at(0); auto softmax_s1_result = f->get_results().at(0);
auto node = softmax_s1_result->input(0).get_source_output().get_node_shared_ptr(); auto node = softmax_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto softmax_s1_node = static_pointer_cast<op::v1::Softmax>(node); auto softmax_s1_node = as_type_ptr<op::v1::Softmax>(node);
EXPECT_TRUE(softmax_s1_node);
EXPECT_EQ(softmax_s1_node->get_axis(), axis); EXPECT_EQ(softmax_s1_node->get_axis(), axis);
EXPECT_EQ(softmax_s1_node->description(), "Softmax");
EXPECT_EQ(softmax_s1_node->get_version(), 1);
} }
TEST(opset_transform, opset1_softmax_upgrade_pass_axis_exception) TEST(opset_transform, opset1_softmax_upgrade_pass_axis_exception)
...@@ -75,43 +73,3 @@ TEST(opset_transform, opset1_softmax_upgrade_pass_axis_exception) ...@@ -75,43 +73,3 @@ TEST(opset_transform, opset1_softmax_upgrade_pass_axis_exception)
FAIL() << "Softmax pass failed for unexpected reason"; FAIL() << "Softmax pass failed for unexpected reason";
} }
} }
namespace fake_v2
{
class FakeSoftmax : public op::v0::Softmax
{
public:
FakeSoftmax(const Output<Node>& arg, const AxisSet& axes)
: Softmax{arg, axes}
{
}
size_t get_version() const override { return 2; }
};
}
TEST(opset_transform, opset1_softmax_upgrade_pass_incorrect_op_version)
{
const AxisSet axes{2};
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto softmax_s2 = make_shared<fake_v2::FakeSoftmax>(arg, axes);
auto result = make_shared<op::Result>(softmax_s2);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Opset 1 transformation pass failed for";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Op version 1 transformation pass failed for"));
}
catch (...)
{
FAIL() << "Softmax pass failed for unexpected reason";
}
}
...@@ -41,10 +41,8 @@ TEST(opset_transform, opset1_reduce_sum_upgrade_pass) ...@@ -41,10 +41,8 @@ TEST(opset_transform, opset1_reduce_sum_upgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reduce_sum_v1 = static_pointer_cast<op::v1::ReduceSum>(pass_replacement_node); const auto reduce_sum_v1 = as_type_ptr<op::v1::ReduceSum>(pass_replacement_node);
EXPECT_TRUE(reduce_sum_v1);
EXPECT_EQ(reduce_sum_v1->description(), "Sum");
EXPECT_EQ(reduce_sum_v1->get_version(), 1);
EXPECT_EQ(reduce_sum_v1->get_keep_dims(), false); EXPECT_EQ(reduce_sum_v1->get_keep_dims(), false);
} }
...@@ -63,15 +61,12 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass) ...@@ -63,15 +61,12 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass)
const auto reshape_replacement_node = const auto reshape_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape = static_pointer_cast<op::Reshape>(reshape_replacement_node); const auto reshape = as_type_ptr<op::Reshape>(reshape_replacement_node);
EXPECT_TRUE(reshape);
const auto sum_replace_node = const auto sum_replace_node =
reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr(); reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr();
const auto sum_v0 = static_pointer_cast<op::v0::Sum>(sum_replace_node); const auto sum_v0 = as_type_ptr<op::v0::Sum>(sum_replace_node);
EXPECT_TRUE(sum_v0);
EXPECT_EQ(reshape->description(), "Reshape");
EXPECT_EQ(reshape->get_version(), 0);
EXPECT_EQ(sum_v0->description(), "Sum");
EXPECT_EQ(sum_v0->get_version(), 0);
} }
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_not_constant_axes) TEST(opset_transform, opset0_reduce_sum_downgrade_pass_not_constant_axes)
......
...@@ -41,11 +41,9 @@ TEST(opset_transform, opset1_topk_upgrade_pass) ...@@ -41,11 +41,9 @@ TEST(opset_transform, opset1_topk_upgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v1 = static_pointer_cast<op::v1::TopK>(pass_replacement_node); const auto topk_v1 = as_type_ptr<op::v1::TopK>(pass_replacement_node);
EXPECT_TRUE(topk_v1);
EXPECT_EQ(topk_v1->get_axis(), axis); EXPECT_EQ(topk_v1->get_axis(), axis);
EXPECT_EQ(topk_v1->description(), "TopK");
EXPECT_EQ(topk_v1->get_version(), 1);
EXPECT_EQ(topk_v1->get_mode(), op::v1::TopK::Mode::MAX); EXPECT_EQ(topk_v1->get_mode(), op::v1::TopK::Mode::MAX);
EXPECT_EQ(topk_v1->get_sort_type(), op::v1::TopK::SortType::SORT_VALUES); EXPECT_EQ(topk_v1->get_sort_type(), op::v1::TopK::SortType::SORT_VALUES);
...@@ -74,9 +72,7 @@ TEST(opset_transform, opset1_topk_downgrade_pass) ...@@ -74,9 +72,7 @@ TEST(opset_transform, opset1_topk_downgrade_pass)
const auto pass_replacement_node = const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v0 = as_type_ptr<op::v0::TopK>(pass_replacement_node); const auto topk_v0 = as_type_ptr<op::v0::TopK>(pass_replacement_node);
EXPECT_TRUE(topk_v0);
EXPECT_EQ(topk_v0->description(), "TopK");
EXPECT_EQ(topk_v0->get_version(), 0);
EXPECT_EQ(topk_v0->get_k(), k); EXPECT_EQ(topk_v0->get_k(), k);
EXPECT_EQ(topk_v0->get_top_k_axis(), axis); EXPECT_EQ(topk_v0->get_top_k_axis(), axis);
EXPECT_EQ(topk_v0->get_compute_max(), true); EXPECT_EQ(topk_v0->get_compute_max(), true);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment