Unverified commit 4c5dbf07 authored by Scott Cyphers, committed by GitHub

Cyphers/uop (#3903)

* Address op_tbl issues

* fix

* fix

* fix

* Cleanup

* cleanup

* cleanup

* More fixes

* Revert ser changes

* Compiles

* opset conversion fixed

* Fix opset conversion tests

* Deal with Reciprocal and FloorMod movement

* Cleanup

* Remove duplicate enums

* Experiment

* experiment

* Types

* Reorg around clang 3.9 bug
parent 16e256fa
@@ -105,13 +105,45 @@ std::shared_ptr<Node>
    Node::copy_with_new_inputs(const OutputVector& inputs,
                               const std::vector<std::shared_ptr<Node>>& control_dependencies) const
{
    shared_ptr<Node> clone;
    if (is_type<op::GetOutputElement>(this))
    {
        auto& value = inputs.at(0);
        clone = make_shared<op::GetOutputElement>(value.get_node_shared_ptr(), value.get_index());
    }
    else
    {
        NodeVector args;
        for (const Output<Node>& input : inputs)
        {
            args.push_back(get_output_element(input, false));
        }
        for (size_t i = 0; i < inputs.size(); ++i)
        {
            auto in_val = inputs.at(i);
            if (is_type<op::GetOutputElement>(in_val.get_node()))
            {
                in_val = as_type_ptr<op::GetOutputElement>(in_val.get_node_shared_ptr())
                             ->get_as_output();
            }
            auto in_index = in_val.get_index();
            auto arg = args.at(i);
            size_t out_index = 0;
            if (is_type<op::GetOutputElement>(arg))
            {
                out_index = as_type_ptr<op::GetOutputElement>(arg)->get_n();
            }
            if (in_index != out_index)
            {
                cerr << "Mismatch in: " << in_index << " arg: " << out_index << endl;
                cerr << "ARG: " << *arg << endl;
                cerr << "IN: " << *inputs.at(i).get_node() << endl;
                cerr << "INV: " << *in_val.get_node() << endl;
                cerr << "In node " << *this << endl;
            }
        }
        clone = copy_with_new_args(args);
    }
    for (auto& cdep : control_dependencies)
    {
        clone->add_control_dependency(cdep);
......
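A hedged usage sketch of the API reworked above; add1, x, y, and dep are hypothetical nodes, not names from this commit:

// Illustrative only: clone add1 onto new inputs, carrying one control
// dependency over to the clone. With this change, a GetOutputElement
// node is cloned directly from its input's output index.
OutputVector new_inputs{x->output(0), y->output(0)};
auto clone = add1->copy_with_new_inputs(new_inputs, {dep});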
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each fused op.
//
#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(BatchMatMulTranspose, ngraph::op)
NGRAPH_OP(Clamp, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(CrossEntropy, ngraph::op)
NGRAPH_OP(CrossEntropyBackprop, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(Gelu, ngraph::op)
NGRAPH_OP(GeluBackpropFactor, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(LogSoftmax, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(PartialSlice, ngraph::op)
NGRAPH_OP(PartialSliceBackprop, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(Reciprocal, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(Selu, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(TensorIterator, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
@@ -65,10 +65,10 @@ NGRAPH_OP(Atan2, ngraph::op)
NGRAPH_OP(AvgPool, ngraph::op)
NGRAPH_OP(AvgPoolBackprop, ngraph::op)
NGRAPH_OP(BatchMatMul, ngraph::op)
NGRAPH_OP(BatchMatMulTranspose, ngraph::op)
NGRAPH_OP(BatchNormInference, ngraph::op)
NGRAPH_OP(BatchNormTraining, ngraph::op)
NGRAPH_OP(BatchNormTrainingBackprop, ngraph::op)
NGRAPH_OP(BinaryConvolution, ngraph::op)
NGRAPH_OP(Broadcast, ngraph::op)
NGRAPH_OP(BroadcastDistributed, ngraph::op)
NGRAPH_OP(BroadcastLike, ngraph::op)
@@ -79,8 +79,13 @@ NGRAPH_OP(Convert, ngraph::op)
NGRAPH_OP(Convolution, ngraph::op)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op)
NGRAPH_OP(ConvolutionBackpropFilters, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op)
NGRAPH_OP(CrossEntropy, ngraph::op)
NGRAPH_OP(CrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Dequantize, ngraph::op)
NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP(Dot, ngraph::op)
@@ -94,22 +99,24 @@ NGRAPH_OP(Equal, ngraph::op)
NGRAPH_OP(Erf, ngraph::op)
NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(FloorMod, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(Gather, ngraph::op)
NGRAPH_OP(GatherND, ngraph::op)
NGRAPH_OP(Gelu, ngraph::op)
NGRAPH_OP(GeluBackpropFactor, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GenerateMask, ngraph::op)
NGRAPH_OP(GetOutputElement, ngraph::op)
NGRAPH_OP(Greater, ngraph::op)
NGRAPH_OP(GreaterEq, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(Less, ngraph::op)
NGRAPH_OP(LessEq, ngraph::op)
NGRAPH_OP(LessEqual, ngraph::op)
NGRAPH_OP(Log, ngraph::op)
NGRAPH_OP(LogicalAnd, ngraph::op)
NGRAPH_OP(LogicalNot, ngraph::op)
NGRAPH_OP(LogicalOr, ngraph::op)
NGRAPH_OP(LogicalXor, ngraph::op)
NGRAPH_OP(LRN, ngraph::op)
NGRAPH_OP(LogSoftmax, ngraph::op)
NGRAPH_OP(Max, ngraph::op)
NGRAPH_OP(Maximum, ngraph::op)
NGRAPH_OP(MaxPool, ngraph::op)
@@ -117,6 +124,7 @@ NGRAPH_OP(MaxPoolBackprop, ngraph::op)
NGRAPH_OP(Min, ngraph::op)
NGRAPH_OP(Minimum, ngraph::op)
NGRAPH_OP(Multiply, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Negative, ngraph::op)
NGRAPH_OP(Not, ngraph::op)
NGRAPH_OP(NotEqual, ngraph::op)
@@ -124,6 +132,8 @@ NGRAPH_OP(OneHot, ngraph::op)
NGRAPH_OP(Or, ngraph::op)
NGRAPH_OP(Pad, ngraph::op)
NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(PartialSlice, ngraph::op)
NGRAPH_OP(PartialSliceBackprop, ngraph::op)
NGRAPH_OP(Passthrough, ngraph::op)
NGRAPH_OP(Power, ngraph::op)
NGRAPH_OP(Product, ngraph::op)
@@ -135,9 +145,10 @@ NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(RandomUniform, ngraph::op)
NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Reciprocal, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op)
NGRAPH_OP(ReplaceSlice, ngraph::op)
@@ -146,9 +157,11 @@ NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op)
NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(ScalarConstantLike, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(ScatterAdd, ngraph::op)
NGRAPH_OP(ScatterNDAdd, ngraph::op)
NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(Selu, ngraph::op)
NGRAPH_OP(Send, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op)
NGRAPH_OP(Sigmoid, ngraph::op)
@@ -158,6 +171,8 @@ NGRAPH_OP(Sin, ngraph::op)
NGRAPH_OP(Sinh, ngraph::op)
NGRAPH_OP(Slice, ngraph::op)
NGRAPH_OP(Softmax, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Sqrt, ngraph::op)
NGRAPH_OP(StopGradient, ngraph::op)
NGRAPH_OP(Subtract, ngraph::op)
@@ -165,7 +180,6 @@ NGRAPH_OP(Sum, ngraph::op)
NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op::v0)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(VariadicSplit, ngraph::op)
NGRAPH_OP(Xor, ngraph::op)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each op. If an op is added it must be
// added to this list.
//
// To use this list, define a macro named exactly NGRAPH_OP before including this
// file, and #undef it when you are done.
// As an example, if you wanted to make a list of all op names as strings, you could do this:
//
// #define NGRAPH_OP(a,b) #a,
// std::vector<std::string> op_names{
// #include "this include file name"
// };
// #undef NGRAPH_OP
//
// This sample expands to a list like this:
// "Abs",
// "Acos",
// ...
//
// #define NGRAPH_OP(a, b) {b::a::type_info, OP_TYPEID::a},
// std::map<ngraph::NodeTypeInfo, OP_TYPEID> typeid_map{
// #include "this include file name"
// };
// #undef NGRAPH_OP
//
// This sample expands to a map from each op's type_info to an enum value:
// {ngraph::op::Abs::type_info, OP_TYPEID::Abs},
// {ngraph::op::Acos::type_info, OP_TYPEID::Acos},
// ...
//
// It's that easy. You can use this for fun and profit. (A compilable sketch of
// this expansion pattern appears after the table below.)
#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(Abs, ngraph::op)
NGRAPH_OP(Acos, ngraph::op)
// NGRAPH_OP(Acosh, ngraph::op)
NGRAPH_OP(Add, ngraph::op::v1)
NGRAPH_OP(Asin, ngraph::op)
// NGRAPH_OP(Asinh, ngraph::op)
NGRAPH_OP(Atan, ngraph::op)
// NGRAPH_OP(Atanh, ngraph::op)
NGRAPH_OP(AvgPool, ngraph::op::v1)
NGRAPH_OP(BatchNormInference, ngraph::op)
NGRAPH_OP(BinaryConvolution, ngraph::op::v1)
NGRAPH_OP(Broadcast, ngraph::op::v1)
// NGRAPH_OP(CTCGreedyDecoder, ngraph::op)
NGRAPH_OP(Ceiling, ngraph::op)
NGRAPH_OP(Clamp, ngraph::op)
NGRAPH_OP(Concat, ngraph::op)
NGRAPH_OP(Constant, ngraph::op)
NGRAPH_OP(Convert, ngraph::op)
// NGRAPH_OP(ConvertLike, ngraph::op)
NGRAPH_OP(Convolution, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op)
// NGRAPH_OP(DeformableConvolution, ngraph::op)
// NGRAPH_OP(DeformablePSROIPooling, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op)
// NGRAPH_OP(DetectionOutput, ngraph::op)
NGRAPH_OP(Divide, ngraph::op::v1)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(Erf, ngraph::op)
NGRAPH_OP(Equal, ngraph::op::v1)
NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(FloorMod, ngraph::op::v1)
NGRAPH_OP(Gather, ngraph::op::v1)
NGRAPH_OP(Greater, ngraph::op::v1)
NGRAPH_OP(GreaterEq, ngraph::op::v1)
NGRAPH_OP(GroupConvolution, ngraph::op)
// NGRAPH_OP(GroupConvolutionBackpropData, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(Interpolate, ngraph::op)
// NGRAPH_OP(LeakyRelu, ngraph::op)
NGRAPH_OP(Less, ngraph::op::v1)
NGRAPH_OP(LessEqual, ngraph::op::v1)
NGRAPH_OP(Log, ngraph::op)
NGRAPH_OP(LogicalAnd, ngraph::op::v1)
NGRAPH_OP(LogicalNot, ngraph::op::v1)
NGRAPH_OP(LogicalOr, ngraph::op::v1)
NGRAPH_OP(LogicalXor, ngraph::op::v1)
NGRAPH_OP(LRN, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MaxPool, ngraph::op::v1)
NGRAPH_OP(Maximum, ngraph::op::v1)
NGRAPH_OP(Minimum, ngraph::op::v1)
// NGRAPH_OP(Mod, ngraph::op)
NGRAPH_OP(Multiply, ngraph::op::v1)
NGRAPH_OP(Negative, ngraph::op)
// NGRAPH_OP(NonMaxSuppression, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(NotEqual, ngraph::op::v1)
NGRAPH_OP(OneHot, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
// NGRAPH_OP(PSROIPooling, ngraph::op)
NGRAPH_OP(Pad, ngraph::op::v1)
NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(Power, ngraph::op::v1)
// NGRAPH_OP(PriorBox, ngraph::op)
// NGRAPH_OP(PriorBoxClustered, ngraph::op)
// NGRAPH_OP(Proposal, ngraph::op)
NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
// NGRAPH_OP(ReduceLogicalAnd, ngraph::op)
// NGRAPH_OP(ReduceLogicalOr, ngraph::op)
NGRAPH_OP(ReduceMax, ngraph::op::v1)
// NGRAPH_OP(ReduceMean, ngraph::op)
NGRAPH_OP(ReduceMin, ngraph::op::v1)
NGRAPH_OP(ReduceProd, ngraph::op::v1)
NGRAPH_OP(ReduceSum, ngraph::op::v1)
// NGRAPH_OP(RegionYolo, ngraph::op)
NGRAPH_OP(Reshape, ngraph::op::v1)
NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op::v1)
NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
// NGRAPH_OP(ROIPooling, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(Sign, ngraph::op)
NGRAPH_OP(Sigmoid, ngraph::op)
NGRAPH_OP(Sin, ngraph::op)
NGRAPH_OP(Sinh, ngraph::op)
NGRAPH_OP(Softmax, ngraph::op::v1)
NGRAPH_OP(Sqrt, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(StridedSlice, ngraph::op::v1)
NGRAPH_OP(Subtract, ngraph::op)
NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(TensorIterator, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op::v1)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
NGRAPH_OP(VariadicSplit, ngraph::op::v1)
// Related to v1
NGRAPH_OP(AvgPoolBackprop, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropFilters, ngraph::op::v1)
NGRAPH_OP(MaxPoolBackprop, ngraph::op::v1)
// Other
NGRAPH_OP(GenerateMask, ngraph::op::v1)
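The following is a minimal, compilable sketch of the NGRAPH_OP expansion pattern described in the header comment above. DEMO_OP_TBL is a hypothetical two-entry stand-in for this header; the real code pulls the table in with #include instead of a helper macro:

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for "ngraph/op/op_v0_tbl.hpp".
#define DEMO_OP_TBL \
    NGRAPH_OP(Abs, demo::op) \
    NGRAPH_OP(Acos, demo::op)

// Expansion 1: op names as strings.
#define NGRAPH_OP(a, b) #a,
static const std::vector<std::string> op_names{DEMO_OP_TBL};
#undef NGRAPH_OP

// Expansion 2: one enum value per op, the OP_TYPEID idiom used by
// opset1_upgrade.cpp in this commit.
#define NGRAPH_OP(a, b) a,
enum class OP_TYPEID
{
    DEMO_OP_TBL
    OTHER
};
#undef NGRAPH_OP

int main()
{
    for (const auto& name : op_names)
        std::cout << name << '\n'; // prints Abs, then Acos
}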
@@ -15,40 +15,147 @@
//*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/binary_convolution.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/layers/interpolate.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/floor_mod.hpp"
#include "ngraph/op/fused/batch_mat_mul_transpose.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
#include "ngraph/op/fused/gelu.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/partial_slice.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/selu.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/xor.hpp"
#include <limits>
@@ -57,33 +164,30 @@
using namespace std;
using namespace ngraph;
namespace
{
    enum class OP_TYPEID
    {
#define NGRAPH_OP(a, b) a,
#include "ngraph/op/op_v0_tbl.hpp"
#undef NGRAPH_OP
        OTHER
    };
}

static OP_TYPEID get_typeid(shared_ptr<Node> node)
{
    static map<NodeTypeInfo, OP_TYPEID> typeid_map{
#define NGRAPH_OP(a, b) {b::a::type_info, OP_TYPEID::a},
#include "ngraph/op/op_v0_tbl.hpp"
#undef NGRAPH_OP
    };
    OP_TYPEID type_id = OP_TYPEID::OTHER;
    auto it = typeid_map.find(node->get_type_info());
    if (it != typeid_map.end())
    {
        type_id = it->second;
    }
    return type_id;
}
// END mapping to OP_TYPEID
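A hedged sketch of how a pass might consume get_typeid(); the two cases shown are illustrative, not the commit's full switch:

// Illustrative only: dispatch on the enum instead of comparing type strings.
switch (get_typeid(node))
{
case OP_TYPEID::Softmax: /* replace node with an op::v1::Softmax */ break;
case OP_TYPEID::TopK: /* replace node with an op::v1::TopK */ break;
default: break; // OP_TYPEID::OTHER and unhandled ops are left untouched
}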
@@ -102,20 +206,6 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
// Not all enumeration values explicitly handled in switch
#if defined(__clang__)
#pragma clang diagnostic push
......
@@ -21,7 +21,7 @@ else()
endif()
if (NGRAPH_INTERPRETER_ENABLE)
add_library(interpreter_backend ${LIBRARY_TYPE} int_backend.cpp int_executable.cpp)
target_compile_definitions(interpreter_backend PRIVATE INTERPRETER_BACKEND_EXPORTS)
if(NGRAPH_LIB_VERSIONING_ENABLE)
set_target_properties(interpreter_backend PROPERTIES
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
using namespace ngraph;
using namespace std;
runtime::interpreter::NodeWrapper::NodeWrapper(const shared_ptr<const Node>& node)
: m_node{node}
{
// This expands the op list in op_tbl.hpp into a list of map entries that look like this:
// {"Abs", runtime::interpreter::OP_TYPEID::Abs},
// {"Acos", runtime::interpreter::OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a, b) {#a, runtime::interpreter::OP_TYPEID::a},
static unordered_map<string, runtime::interpreter::OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
#ifdef INTERPRETER_USE_HYBRID
#include "ngraph/runtime/hybrid/op/op_tbl.hpp"
#endif
};
#undef NGRAPH_OP
auto it = typeid_map.find(m_node->description());
if (it != typeid_map.end())
{
m_typeid = it->second;
}
else
{
throw unsupported_op("Unsupported op '" + m_node->description() + "'");
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
namespace interpreter
{
enum class OP_TYPEID;
class NodeWrapper;
}
}
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a, b) a,
enum class ngraph::runtime::interpreter::OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
#ifdef INTERPRETER_USE_HYBRID
#include "ngraph/runtime/hybrid/op/op_tbl.hpp"
#endif
};
#undef NGRAPH_OP
/// \brief This class allows adding an enum typeid to each Node. This makes dealing with
/// collections of Nodes a little easier and faster as we can use switch() instead of
/// if/else statements
class ngraph::runtime::interpreter::NodeWrapper
{
public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
std::shared_ptr<const Node> get_node() const { return m_node; }
ngraph::runtime::interpreter::OP_TYPEID get_typeid() const { return m_typeid; }
private:
std::shared_ptr<const ngraph::Node> m_node;
OP_TYPEID m_typeid;
};
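A hedged usage sketch of the class above, assuming a shared_ptr<const Node> named node is in scope; the wrapper computes the enum once so the interpreter's execution loop can switch on it:

// Illustrative only: wrap once, then dispatch cheaply per execution.
runtime::interpreter::NodeWrapper wrapper{node};
switch (wrapper.get_typeid())
{
case runtime::interpreter::OP_TYPEID::Abs:
    // run the Abs kernel on wrapper.get_node()
    break;
default:
    throw unsupported_op("Unsupported op '" + wrapper.get_node()->description() + "'");
}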
@@ -17,6 +17,7 @@
#pragma once
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
@@ -38,12 +39,31 @@ namespace ngraph
const char* name;
uint64_t version;
bool is_castable(const DiscreteTypeInfo& target_type) const { return *this == target_type; }
// For use as a key
bool operator<(const DiscreteTypeInfo& b) const
{
return version < b.version || (version == b.version && strcmp(name, b.name) < 0);
}
bool operator<=(const DiscreteTypeInfo& b) const
{
return version < b.version || (version == b.version && strcmp(name, b.name) <= 0);
}
bool operator>(const DiscreteTypeInfo& b) const
{
return version > b.version || (version == b.version && strcmp(name, b.name) > 0);
}
bool operator>=(const DiscreteTypeInfo& b) const
{
return version > b.version || (version == b.version && strcmp(name, b.name) >= 0);
}
bool operator==(const DiscreteTypeInfo& b) const
{
return version == b.version && strcmp(name, b.name) == 0;
}
bool operator!=(const DiscreteTypeInfo& b) const
{
return version != b.version || strcmp(name, b.name) != 0;
}
};
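A minimal sketch of the new ordering semantics, assuming the struct above; the three infos below are hypothetical stand-ins for ops' static type_info members:

#include <cassert>
#include <map>

const ngraph::DiscreteTypeInfo abs_v0{"Abs", 0};
const ngraph::DiscreteTypeInfo add_v0{"Add", 0};
const ngraph::DiscreteTypeInfo add_v1{"Add", 1};

void demo()
{
    assert(abs_v0 < add_v0); // same version: ordered by strcmp on name
    assert(add_v0 < add_v1); // lower version sorts first
    // The strict ordering lets DiscreteTypeInfo key an ordered map, as
    // get_typeid() in opset1_upgrade.cpp relies on.
    std::map<ngraph::DiscreteTypeInfo, int> ids{{abs_v0, 0}, {add_v0, 1}, {add_v1, 2}};
    assert(ids.size() == 3);
}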
......
@@ -35,11 +35,10 @@ void test_type_prop_opset0_downgrade_pass(const element::Type& output_type,
pass_manager.run_passes(f);
auto v0_result = f->get_results().at(0);
auto node = v0_result->input_value(0).get_node_shared_ptr();
auto v0_node = as_type_ptr<OpV0>(node);
EXPECT_TRUE(v0_node);
EXPECT_EQ(v0_node->get_autob(), np_auto_b);
EXPECT_EQ(v0_node->output(0).get_element_type(), output_type);
EXPECT_EQ(v0_node->output(0).get_shape(), (Shape{1, 3, 2}));
......
@@ -22,20 +22,18 @@ TEST(opset_transform, opset1_broadcast_upgrade_pass)
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto bcast_v1 = as_type_ptr<op::v1::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_TRUE(bcast_v1);
EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec());
EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
EXPECT_TRUE(bcast_v1->input_value(1).get_node()->is_constant());
EXPECT_TRUE(bcast_v1->input_value(2).get_node()->is_constant());
EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr())
->get_shape_val(),
(Shape{3, 5, 4, 6}));
EXPECT_EQ(as_type_ptr<op::Constant>(bcast_v1->input_value(2).get_node_shared_ptr())
->get_axis_set_val(),
(AxisSet{1, 3}));
}
@@ -53,11 +51,10 @@ TEST(opset_transform, opset1_broadcast_downgrade_pass)
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto bcast_v0 = as_type_ptr<op::v0::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_TRUE(bcast_v0);
EXPECT_EQ(bcast_v0->get_broadcast_shape(), (Shape{3, 1, 4, 2, 3}));
EXPECT_EQ(bcast_v0->get_broadcast_axes(), (AxisSet{0, 2}));
}
@@ -33,10 +33,9 @@ TEST(opset_transform, opset1_convolution_upgrade_pass)
auto convolution_s1_result = f->get_results().at(0);
auto node = convolution_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto convolution_v1_node = as_type_ptr<op::v1::Convolution>(node);
EXPECT_TRUE(convolution_v1_node);
EXPECT_EQ(convolution_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(convolution_v1_node->get_pads_end(), pads_end);
@@ -66,10 +65,9 @@ TEST(opset_transform, opset1_convolution_downgrade_pass)
auto conv_s0_result = f->get_results().at(0);
auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto conv_v0_node = as_type_ptr<op::v0::Convolution>(node);
EXPECT_TRUE(conv_v0_node);
EXPECT_EQ(conv_v0_node->get_window_movement_strides(), strides);
EXPECT_EQ(conv_v0_node->get_window_dilation_strides(), dilations);
EXPECT_EQ(conv_v0_node->get_padding_below(), pads_begin);
@@ -99,10 +97,9 @@ TEST(opset_transform, opset1_convolution_backprop_data_downgrade_pass)
auto conv_s0_result = f->get_results().at(0);
auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto conv_v0_node = as_type_ptr<op::v0::ConvolutionBackpropData>(node);
EXPECT_TRUE(conv_v0_node);
EXPECT_EQ(conv_v0_node->get_data_batch_shape(), (Shape{64, 3, 100}));
EXPECT_EQ(conv_v0_node->get_window_movement_strides_forward(), strides);
EXPECT_EQ(conv_v0_node->get_window_dilation_strides_forward(), dilations);
@@ -131,10 +128,9 @@ TEST(opset_transform, opset1_convolution_backprop_filters_downgrade_pass)
auto conv_s0_result = f->get_results().at(0);
auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto conv_v0_node = as_type_ptr<op::v0::ConvolutionBackpropFilters>(node);
EXPECT_TRUE(conv_v0_node);
EXPECT_EQ(conv_v0_node->get_filters_shape(), (Shape{128, 3, 10}));
EXPECT_EQ(conv_v0_node->get_window_movement_strides_forward(), strides);
EXPECT_EQ(conv_v0_node->get_window_dilation_strides_forward(), dilations);
......
@@ -39,12 +39,8 @@ TEST(opset_transform, opset1_dyn_reshape_upgrade_pass)
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
EXPECT_TRUE(is_type<op::v1::Reshape>(pass_replacement_node));
}
TEST(opset_transform, opset1_reshape_downgrade_pass)
@@ -60,11 +56,8 @@ TEST(opset_transform, opset1_reshape_downgrade_pass)
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
const auto reshape_v1 = as_type_ptr<op::v0::DynReshape>(pass_replacement_node);
EXPECT_TRUE(reshape_v1);
EXPECT_EQ(reshape_v1->get_zero_flag(), true);
}
@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_gather_upgrade_pass)
pass_manager.run_passes(f);
auto gather_s1_result = f->get_results().at(0);
auto gather_v1_node = as_type_ptr<op::v1::Gather>(
gather_s1_result->input(0).get_source_output().get_node_shared_ptr());
EXPECT_TRUE(gather_v1_node);
EXPECT_EQ(gather_v1_node->get_axis(), axis);
}
@@ -26,10 +26,8 @@ TEST(opset_transform, opset1_generate_mask_downgrade_pass)
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto generate_mask_v0 = as_type_ptr<op::v0::GenerateMask>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_TRUE(generate_mask_v0);
EXPECT_EQ(generate_mask_v0->get_mask_shape(), (Shape{1, 128}));
}
@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_and_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto and_v1 = as_type_ptr<op::v1::LogicalAnd>(pass_replacement_node);
EXPECT_TRUE(and_v1);
const auto values_out_element_type = and_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_and_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto and_v0 = as_type_ptr<op::v0::And>(pass_replacement_node);
EXPECT_TRUE(and_v0);
const auto values_out_element_type = and_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......
@@ -39,10 +39,8 @@ TEST(opset_transform, opset1_logical_not_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto not_v1 = as_type_ptr<op::v1::LogicalNot>(pass_replacement_node);
EXPECT_TRUE(not_v1);
const auto values_out_element_type = not_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
@@ -61,10 +59,8 @@ TEST(opset_transform, opset1_logical_not_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto not_v0 = as_type_ptr<op::v0::Not>(pass_replacement_node);
EXPECT_TRUE(not_v0);
const auto values_out_element_type = not_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......
@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_or_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto or_v1 = as_type_ptr<op::v1::LogicalOr>(pass_replacement_node);
EXPECT_TRUE(or_v1);
const auto values_out_element_type = or_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_or_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto or_v0 = as_type_ptr<op::v0::Or>(pass_replacement_node);
EXPECT_TRUE(or_v0);
const auto values_out_element_type = or_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......
@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_xor_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto xor_v1 = as_type_ptr<op::v1::LogicalXor>(pass_replacement_node);
EXPECT_TRUE(xor_v1);
const auto values_out_element_type = xor_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_xor_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto xor_v0 = as_type_ptr<op::v0::Xor>(pass_replacement_node);
EXPECT_TRUE(xor_v0);
const auto values_out_element_type = xor_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......
@@ -29,10 +29,8 @@ TEST(opset_transform, opset1_pad_upgrade_pass)
auto pad_s1_result = f->get_results().at(0);
auto node = pad_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto pad_v1_node = as_type_ptr<op::v1::Pad>(node);
EXPECT_TRUE(pad_v1_node);
EXPECT_EQ(pad_v1_node->get_pad_mode(), pad_mode);
EXPECT_EQ(pad_v1_node->get_pads_begin(), padding_below);
@@ -58,10 +56,8 @@ TEST(opset_transform, opset1_pad_downgrade_pass)
auto pad_s0_result = f->get_results().at(0);
auto node = pad_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto pad_v0_node = as_type_ptr<op::v0::Pad>(node);
EXPECT_TRUE(pad_v0_node);
EXPECT_EQ(pad_v0_node->get_pad_mode(), pad_mode);
EXPECT_EQ(pad_v0_node->get_padding_below(), CoordinateDiff({1, 2}));
......
@@ -33,10 +33,8 @@ TEST(opset_transform, opset1_avgpool_upgrade_pass)
auto avgpool_s1_result = f->get_results().at(0);
auto node = avgpool_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_v1_node = as_type_ptr<op::v1::AvgPool>(node);
EXPECT_TRUE(avg_pool_v1_node);
EXPECT_EQ(avg_pool_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(avg_pool_v1_node->get_pads_end(), pads_end);
@@ -68,10 +66,8 @@ TEST(opset_transform, opset1_maxpool_upgrade_pass)
auto maxpool_s1_result = f->get_results().at(0);
auto node = maxpool_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_v1_node = as_type_ptr<op::v1::MaxPool>(node);
EXPECT_TRUE(max_pool_v1_node);
EXPECT_EQ(max_pool_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(max_pool_v1_node->get_pads_end(), pads_end);
@@ -109,10 +105,8 @@ TEST(opset_transform, opset1_avgpool_downgrade_pass)
auto avgpool_s0_result = f->get_results().at(0);
auto node = avgpool_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_v0_node = as_type_ptr<op::v0::AvgPool>(node);
EXPECT_TRUE(avg_pool_v0_node);
EXPECT_EQ(avg_pool_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(avg_pool_v0_node->get_padding_above(), padding_above);
@@ -149,10 +143,8 @@ TEST(opset_transform, opset1_maxpool_downgrade_pass)
auto maxpool_s0_result = f->get_results().at(0);
auto node = maxpool_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_v0_node = as_type_ptr<op::v0::MaxPool>(node);
EXPECT_TRUE(max_pool_v0_node);
EXPECT_EQ(max_pool_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(max_pool_v0_node->get_padding_above(), padding_above);
@@ -189,10 +181,8 @@ TEST(opset_transform, opset1_avgpool_backprop_downgrade_pass)
auto avgpool_backprop_s0_result = f->get_results().at(0);
auto node = avgpool_backprop_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_backprop_v0_node = as_type_ptr<op::v0::AvgPoolBackprop>(node);
EXPECT_TRUE(avg_pool_backprop_v0_node);
EXPECT_EQ(avg_pool_backprop_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(avg_pool_backprop_v0_node->get_padding_above(), padding_above);
@@ -229,10 +219,8 @@ TEST(opset_transform, opset1_maxpool_backprop_downgrade_pass)
auto max_pool_backprop_s0_result = f->get_results().at(0);
auto node = max_pool_backprop_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_backprop_v0_node = as_type_ptr<op::v0::MaxPoolBackprop>(node);
EXPECT_TRUE(max_pool_backprop_v0_node);
EXPECT_EQ(max_pool_backprop_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(max_pool_backprop_v0_node->get_padding_above(), padding_above);
EXPECT_EQ(max_pool_backprop_v0_node->get_window_movement_strides(), window_movement_strides);
......
@@ -41,10 +41,8 @@ TEST(opset_transform, opset1_product_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reduce_prod_v1 = as_type_ptr<op::v1::ReduceProd>(pass_replacement_node);
EXPECT_TRUE(reduce_prod_v1);
EXPECT_EQ(reduce_prod_v1->get_keep_dims(), false);
}
@@ -63,15 +61,12 @@ TEST(opset_transform, opset0_reduce_prod_downgrade_pass)
const auto reshape_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape = as_type_ptr<op::Reshape>(reshape_replacement_node);
EXPECT_TRUE(reshape);
const auto product_replace_node =
reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr();
const auto product_v0 = as_type_ptr<op::v0::Product>(product_replace_node);
EXPECT_TRUE(product_v0);
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_axes_not_constant)
......
@@ -41,11 +41,9 @@ TEST(opset_transform, opset1_reverse_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v1 = as_type_ptr<op::v1::Reverse>(pass_replacement_node);
EXPECT_TRUE(reverse_v1);
EXPECT_EQ(reverse_v1->get_mode(), op::v1::Reverse::Mode::INDEX);
const auto& rev_axes_input_shape = reverse_v1->get_input_shape(1);
// should match the number of elements of v0::Reverse reverse_axes attribute
@@ -69,10 +67,8 @@ TEST(opset_transform, opset0_reverse_downgrade_pass_index_mode)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = as_type_ptr<op::v0::Reverse>(pass_replacement_node);
EXPECT_TRUE(reverse_v0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({1, 2}));
}
@@ -93,10 +89,8 @@ TEST(opset_transform, opset0_reverse_downgrade_pass_mask_mode)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = as_type_ptr<op::v0::Reverse>(pass_replacement_node);
EXPECT_TRUE(reverse_v0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({0, 2}));
}
......
@@ -45,16 +45,17 @@ TEST(opset_transform, opset1_dyn_slice_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto strided_slice_v1 = as_type_ptr<op::v1::StridedSlice>(pass_replacement_node);
EXPECT_TRUE(strided_slice_v1);
auto begin_const =
as_type_ptr<op::Constant>(strided_slice_v1->input_value(1).get_node_shared_ptr());
EXPECT_TRUE(begin_const);
auto end_const =
as_type_ptr<op::Constant>(strided_slice_v1->input_value(2).get_node_shared_ptr());
EXPECT_TRUE(end_const);
auto strides_const =
as_type_ptr<op::Constant>(strided_slice_v1->input_value(3).get_node_shared_ptr());
EXPECT_TRUE(strides_const);
EXPECT_EQ(strided_slice_v1->get_begin_mask(), vector<int64_t>(4, 0));
EXPECT_EQ(strided_slice_v1->get_end_mask(), vector<int64_t>(4, 0));
EXPECT_EQ(begin_const->get_vector<int64_t>(),
@@ -84,9 +85,7 @@ TEST(opset_transform, opset1_strided_slice_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto slice_v0 = as_type_ptr<op::v0::Slice>(pass_replacement_node);
EXPECT_TRUE(slice_v0);
EXPECT_EQ(slice_v0->get_lower_bounds(), Coordinate({1, 2, 0, 2}));
EXPECT_EQ(slice_v0->get_upper_bounds(), Coordinate({5, 4, 5, 6}));
EXPECT_EQ(slice_v0->get_strides(), Strides({1, 1, 1, 1}));
......
@@ -40,11 +40,9 @@ TEST(opset_transform, opset1_softmax_upgrade_pass_axis)
auto softmax_s1_result = f->get_results().at(0);
auto node = softmax_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto softmax_s1_node = as_type_ptr<op::v1::Softmax>(node);
EXPECT_TRUE(softmax_s1_node);
EXPECT_EQ(softmax_s1_node->get_axis(), axis);
}
TEST(opset_transform, opset1_softmax_upgrade_pass_axis_exception)
@@ -75,43 +73,3 @@ TEST(opset_transform, opset1_softmax_upgrade_pass_axis_exception)
FAIL() << "Softmax pass failed for unexpected reason";
}
}
namespace fake_v2
{
class FakeSoftmax : public op::v0::Softmax
{
public:
FakeSoftmax(const Output<Node>& arg, const AxisSet& axes)
: Softmax{arg, axes}
{
}
size_t get_version() const override { return 2; }
};
}
TEST(opset_transform, opset1_softmax_upgrade_pass_incorrect_op_version)
{
const AxisSet axes{2};
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto softmax_s2 = make_shared<fake_v2::FakeSoftmax>(arg, axes);
auto result = make_shared<op::Result>(softmax_s2);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Opset 1 transformation pass failed for";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Op version 1 transformation pass failed for"));
}
catch (...)
{
FAIL() << "Softmax pass failed for unexpected reason";
}
}
@@ -41,10 +41,8 @@ TEST(opset_transform, opset1_reduce_sum_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reduce_sum_v1 = as_type_ptr<op::v1::ReduceSum>(pass_replacement_node);
EXPECT_TRUE(reduce_sum_v1);
EXPECT_EQ(reduce_sum_v1->get_keep_dims(), false);
}
@@ -63,15 +61,12 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass)
const auto reshape_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape = as_type_ptr<op::Reshape>(reshape_replacement_node);
EXPECT_TRUE(reshape);
const auto sum_replace_node =
reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr();
const auto sum_v0 = as_type_ptr<op::v0::Sum>(sum_replace_node);
EXPECT_TRUE(sum_v0);
}
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_not_constant_axes)
......
@@ -41,11 +41,9 @@ TEST(opset_transform, opset1_topk_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v1 = as_type_ptr<op::v1::TopK>(pass_replacement_node);
EXPECT_TRUE(topk_v1);
EXPECT_EQ(topk_v1->get_axis(), axis);
EXPECT_EQ(topk_v1->get_mode(), op::v1::TopK::Mode::MAX);
EXPECT_EQ(topk_v1->get_sort_type(), op::v1::TopK::SortType::SORT_VALUES);
@@ -74,9 +72,7 @@ TEST(opset_transform, opset1_topk_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v0 = as_type_ptr<op::v0::TopK>(pass_replacement_node);
EXPECT_TRUE(topk_v0);
EXPECT_EQ(topk_v0->get_k(), k);
EXPECT_EQ(topk_v0->get_top_k_axis(), axis);
EXPECT_EQ(topk_v0->get_compute_max(), true);
......