Unverified commit 4c5dbf07 authored by Scott Cyphers, committed by GitHub

Cyphers/uop (#3903)

* Address op_tbl issues

* fix

* fix

* fix

* Cleanup

* cleanup

* cleanup

* More fixes

* Revert ser changes

* Compiles

* opset conversion fixed

* Fix opset conversion tests

* Deal with Reciprocal and FloorMod movement

* Cleanup

* Remove duplicate enums

* Experiment

* experiment

* Types

* Reorg around clang 3.9 bug
parent 16e256fa
......@@ -105,13 +105,45 @@ std::shared_ptr<Node>
Node::copy_with_new_inputs(const OutputVector& inputs,
const std::vector<std::shared_ptr<Node>>& control_dependencies) const
{
bool for_get_output_element = is_type<op::GetOutputElement>(this);
NodeVector args;
for (const Output<Node>& input : inputs)
shared_ptr<Node> clone;
if (is_type<op::GetOutputElement>(this))
{
args.push_back(get_output_element(input, for_get_output_element));
auto& value = inputs.at(0);
clone = make_shared<op::GetOutputElement>(value.get_node_shared_ptr(), value.get_index());
}
else
{
NodeVector args;
for (const Output<Node>& input : inputs)
{
args.push_back(get_output_element(input, false));
}
for (int i = 0; i < inputs.size(); ++i)
{
auto in_val = inputs.at(i);
if (is_type<op::GetOutputElement>(in_val.get_node()))
{
in_val = as_type_ptr<op::GetOutputElement>(in_val.get_node_shared_ptr())
->get_as_output();
}
auto in_index = in_val.get_index();
auto arg = args.at(i);
size_t out_index = 0;
if (is_type<op::GetOutputElement>(arg))
{
out_index = as_type_ptr<op::GetOutputElement>(arg)->get_n();
}
if (in_index != out_index)
{
cerr << "Mismatch in: " << in_index << " arg: " << out_index << endl;
cerr << "ARG: " << *arg << endl;
cerr << "IN: " << *inputs.at(i).get_node() << endl;
cerr << "INV: " << *in_val.get_node() << endl;
cerr << "In node " << *this << endl;
}
}
clone = copy_with_new_args(args);
}
shared_ptr<Node> clone = copy_with_new_args(args);
for (auto& cdep : control_dependencies)
{
clone->add_control_dependency(cdep);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each fused op.
//
#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(BatchMatMulTranspose, ngraph::op)
NGRAPH_OP(Clamp, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(CrossEntropy, ngraph::op)
NGRAPH_OP(CrossEntropyBackprop, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(Gelu, ngraph::op)
NGRAPH_OP(GeluBackpropFactor, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(LogSoftmax, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(PartialSlice, ngraph::op)
NGRAPH_OP(PartialSliceBackprop, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(Reciprocal, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(Selu, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(TensorIterator, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
......@@ -65,10 +65,10 @@ NGRAPH_OP(Atan2, ngraph::op)
NGRAPH_OP(AvgPool, ngraph::op)
NGRAPH_OP(AvgPoolBackprop, ngraph::op)
NGRAPH_OP(BatchMatMul, ngraph::op)
NGRAPH_OP(BatchMatMulTranspose, ngraph::op)
NGRAPH_OP(BatchNormInference, ngraph::op)
NGRAPH_OP(BatchNormTraining, ngraph::op)
NGRAPH_OP(BatchNormTrainingBackprop, ngraph::op)
NGRAPH_OP(BinaryConvolution, ngraph::op)
NGRAPH_OP(Broadcast, ngraph::op)
NGRAPH_OP(BroadcastDistributed, ngraph::op)
NGRAPH_OP(BroadcastLike, ngraph::op)
......@@ -79,8 +79,13 @@ NGRAPH_OP(Convert, ngraph::op)
NGRAPH_OP(Convolution, ngraph::op)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op)
NGRAPH_OP(ConvolutionBackpropFilters, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op)
NGRAPH_OP(CrossEntropy, ngraph::op)
NGRAPH_OP(CrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Dequantize, ngraph::op)
NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP(Dot, ngraph::op)
......@@ -94,22 +99,24 @@ NGRAPH_OP(Equal, ngraph::op)
NGRAPH_OP(Erf, ngraph::op)
NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(FloorMod, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(Gather, ngraph::op)
NGRAPH_OP(GatherND, ngraph::op)
NGRAPH_OP(Gelu, ngraph::op)
NGRAPH_OP(GeluBackpropFactor, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GenerateMask, ngraph::op)
NGRAPH_OP(GetOutputElement, ngraph::op)
NGRAPH_OP(Greater, ngraph::op)
NGRAPH_OP(GreaterEq, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(Less, ngraph::op)
NGRAPH_OP(LessEq, ngraph::op)
NGRAPH_OP(LessEqual, ngraph::op)
NGRAPH_OP(Log, ngraph::op)
NGRAPH_OP(LogicalAnd, ngraph::op)
NGRAPH_OP(LogicalNot, ngraph::op)
NGRAPH_OP(LogicalOr, ngraph::op)
NGRAPH_OP(LogicalXor, ngraph::op)
NGRAPH_OP(LRN, ngraph::op)
NGRAPH_OP(LogSoftmax, ngraph::op)
NGRAPH_OP(Max, ngraph::op)
NGRAPH_OP(Maximum, ngraph::op)
NGRAPH_OP(MaxPool, ngraph::op)
......@@ -117,6 +124,7 @@ NGRAPH_OP(MaxPoolBackprop, ngraph::op)
NGRAPH_OP(Min, ngraph::op)
NGRAPH_OP(Minimum, ngraph::op)
NGRAPH_OP(Multiply, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Negative, ngraph::op)
NGRAPH_OP(Not, ngraph::op)
NGRAPH_OP(NotEqual, ngraph::op)
......@@ -124,6 +132,8 @@ NGRAPH_OP(OneHot, ngraph::op)
NGRAPH_OP(Or, ngraph::op)
NGRAPH_OP(Pad, ngraph::op)
NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(PartialSlice, ngraph::op)
NGRAPH_OP(PartialSliceBackprop, ngraph::op)
NGRAPH_OP(Passthrough, ngraph::op)
NGRAPH_OP(Power, ngraph::op)
NGRAPH_OP(Product, ngraph::op)
......@@ -135,9 +145,10 @@ NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(RandomUniform, ngraph::op)
NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Reciprocal, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
NGRAPH_OP(ReluBackprop, ngraph::op)
NGRAPH_OP(ReplaceSlice, ngraph::op)
......@@ -146,9 +157,11 @@ NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op)
NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(ScalarConstantLike, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(ScatterAdd, ngraph::op)
NGRAPH_OP(ScatterNDAdd, ngraph::op)
NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(Selu, ngraph::op)
NGRAPH_OP(Send, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op)
NGRAPH_OP(Sigmoid, ngraph::op)
......@@ -158,6 +171,8 @@ NGRAPH_OP(Sin, ngraph::op)
NGRAPH_OP(Sinh, ngraph::op)
NGRAPH_OP(Slice, ngraph::op)
NGRAPH_OP(Softmax, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropy, ngraph::op)
NGRAPH_OP(SoftmaxCrossEntropyBackprop, ngraph::op)
NGRAPH_OP(Sqrt, ngraph::op)
NGRAPH_OP(StopGradient, ngraph::op)
NGRAPH_OP(Subtract, ngraph::op)
......@@ -165,7 +180,6 @@ NGRAPH_OP(Sum, ngraph::op)
NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op)
NGRAPH_OP(TopK, ngraph::op::v0)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(VariadicSplit, ngraph::op)
NGRAPH_OP(Xor, ngraph::op)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each op. If an op is added it must be
// added to this list.
//
// To use this list, define a macro named exactly NGRAPH_OP before including this file,
// and #undef it when you are done.
// For example, to make a list of all op names as strings you could do this:
//
// #define NGRAPH_OP(a,b) #a,
// std::vector<std::string> op_names{
// #include "this include file name"
// };
// #undef NGRAPH_OP
//
// This sample expands to a list like this:
// "Abs",
// "Acos",
// ...
//
// #define NGRAPH_OP(a,b) b::a::type_info,
// std::vector<ngraph::NodeTypeInfo> op_types{
// #include "this include file name"
// };
// #undef NGRAPH_OP
//
// This sample expands to a list like this:
// ngraph::op::Abs::type_info,
// ngraph::op::Acos::type_info,
// ...
//
// It's that easy. You can use this for fun and profit.
#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(Abs, ngraph::op)
NGRAPH_OP(Acos, ngraph::op)
// NGRAPH_OP(Acosh, ngraph::op)
NGRAPH_OP(Add, ngraph::op::v1)
NGRAPH_OP(Asin, ngraph::op)
// NGRAPH_OP(Asinh, ngraph::op)
NGRAPH_OP(Atan, ngraph::op)
// NGRAPH_OP(Atanh, ngraph::op)
NGRAPH_OP(AvgPool, ngraph::op::v1)
NGRAPH_OP(BatchNormInference, ngraph::op)
NGRAPH_OP(BinaryConvolution, ngraph::op::v1)
NGRAPH_OP(Broadcast, ngraph::op::v1)
// NGRAPH_OP(CTCGreedyDecoder, ngraph::op)
NGRAPH_OP(Ceiling, ngraph::op)
NGRAPH_OP(Clamp, ngraph::op)
NGRAPH_OP(Concat, ngraph::op)
NGRAPH_OP(Constant, ngraph::op)
NGRAPH_OP(Convert, ngraph::op)
// NGRAPH_OP(ConvertLike, ngraph::op)
NGRAPH_OP(Convolution, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op)
// NGRAPH_OP(DeformableConvolution, ngraph::op)
// NGRAPH_OP(DeformablePSROIPooling, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op)
// NGRAPH_OP(DetectionOutput, ngraph::op)
NGRAPH_OP(Divide, ngraph::op::v1)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(Erf, ngraph::op)
NGRAPH_OP(Equal, ngraph::op::v1)
NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(FloorMod, ngraph::op::v1)
NGRAPH_OP(Gather, ngraph::op::v1)
NGRAPH_OP(Greater, ngraph::op::v1)
NGRAPH_OP(GreaterEq, ngraph::op::v1)
NGRAPH_OP(GroupConvolution, ngraph::op)
// NGRAPH_OP(GroupConvolutionBackpropData, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(Interpolate, ngraph::op)
// NGRAPH_OP(LeakyRelu, ngraph::op)
NGRAPH_OP(Less, ngraph::op::v1)
NGRAPH_OP(LessEqual, ngraph::op::v1)
NGRAPH_OP(Log, ngraph::op)
NGRAPH_OP(LogicalAnd, ngraph::op::v1)
NGRAPH_OP(LogicalNot, ngraph::op::v1)
NGRAPH_OP(LogicalOr, ngraph::op::v1)
NGRAPH_OP(LogicalXor, ngraph::op::v1)
NGRAPH_OP(LRN, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MaxPool, ngraph::op::v1)
NGRAPH_OP(Maximum, ngraph::op::v1)
NGRAPH_OP(Minimum, ngraph::op::v1)
// NGRAPH_OP(Mod, ngraph::op)
NGRAPH_OP(Multiply, ngraph::op::v1)
NGRAPH_OP(Negative, ngraph::op)
// NGRAPH_OP(NonMaxSuppression, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(NotEqual, ngraph::op::v1)
NGRAPH_OP(OneHot, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
// NGRAPH_OP(PSROIPooling, ngraph::op)
NGRAPH_OP(Pad, ngraph::op::v1)
NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(Power, ngraph::op::v1)
// NGRAPH_OP(PriorBox, ngraph::op)
// NGRAPH_OP(PriorBoxClustered, ngraph::op)
// NGRAPH_OP(Proposal, ngraph::op)
NGRAPH_OP(Range, ngraph::op)
NGRAPH_OP(Relu, ngraph::op)
// NGRAPH_OP(ReduceLogicalAnd, ngraph::op)
// NGRAPH_OP(ReduceLogicalOr, ngraph::op)
NGRAPH_OP(ReduceMax, ngraph::op::v1)
// NGRAPH_OP(ReduceMean, ngraph::op)
NGRAPH_OP(ReduceMin, ngraph::op::v1)
NGRAPH_OP(ReduceProd, ngraph::op::v1)
NGRAPH_OP(ReduceSum, ngraph::op::v1)
// NGRAPH_OP(RegionYolo, ngraph::op)
NGRAPH_OP(Reshape, ngraph::op::v1)
NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op::v1)
NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
// NGRAPH_OP(ROIPooling, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(Sign, ngraph::op)
NGRAPH_OP(Sigmoid, ngraph::op)
NGRAPH_OP(Sin, ngraph::op)
NGRAPH_OP(Sinh, ngraph::op)
NGRAPH_OP(Softmax, ngraph::op::v1)
NGRAPH_OP(Sqrt, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(StridedSlice, ngraph::op::v1)
NGRAPH_OP(Subtract, ngraph::op)
NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(TensorIterator, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op::v1)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
NGRAPH_OP(VariadicSplit, ngraph::op::v1)
// Related to v1
NGRAPH_OP(AvgPoolBackprop, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropFilters, ngraph::op::v1)
NGRAPH_OP(MaxPoolBackprop, ngraph::op::v1)
// Other
NGRAPH_OP(GenerateMask, ngraph::op::v1)
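A minimal sketch of the X-macro pattern the comment at the top of this table describes, assuming only that the table is installed as "ngraph/op/op_v1_tbl.hpp" (the path the other files in this diff use to include it):

#include <string>
#include <vector>

// Collect the name of every op listed in op_v1_tbl.hpp as a string.
#define NGRAPH_OP(a, b) #a,
static const std::vector<std::string> op_v1_names{
#include "ngraph/op/op_v1_tbl.hpp"
};
#undef NGRAPH_OP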
......@@ -14,81 +14,190 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdint>
#include <numeric>
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/binary_convolution.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/layers/interpolate.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/floor_mod.hpp"
#include "ngraph/op/fused/batch_mat_mul_transpose.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
#include "ngraph/op/fused/gelu.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/partial_slice.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/selu.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/slice_plan.hpp"
#include <algorithm>
#include "ngraph/type.hpp"
using namespace std;
using namespace ngraph;
#define NGRAPH_OP(a, b) a,
enum class OP_TYPEID
namespace
{
#include "ngraph/op/fused_op_tbl.hpp"
#include "ngraph/op/op_tbl.hpp"
};
enum class OP_TYPEID
{
#define NGRAPH_OP(a, b) a,
//#include "ngraph/op/fused_op_tbl.hpp"
//#include "ngraph/op/op_v0_tbl.hpp"
#undef NGRAPH_OP
#define NGRAPH_OP(a, b) {#a, OP_TYPEID::a},
static unordered_map<string, OP_TYPEID> typeid_map{
#include "ngraph/op/fused_op_tbl.hpp"
#include "ngraph/op/op_tbl.hpp"
};
#define NGRAPH_OP(a, b) a##_v1,
#include "ngraph/op/op_v1_tbl.hpp"
OTHER
};
#undef NGRAPH_OP
}
static OP_TYPEID get_typeid(shared_ptr<Node> node)
{
OP_TYPEID type_id;
auto it = typeid_map.find(node->description());
static map<NodeTypeInfo, OP_TYPEID> typeid_map{
#define NGRAPH_OP(a, b) {b::a::type_info, OP_TYPEID::a##_v1},
#include "ngraph/op/op_v1_tbl.hpp"
#undef NGRAPH_OP
};
OP_TYPEID type_id = OP_TYPEID::OTHER;
auto it = typeid_map.find(node->get_type_info());
if (it != typeid_map.end())
{
type_id = it->second;
}
else
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
return type_id;
}
// END mapping to OP_TYPEID
......@@ -108,20 +217,6 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
size_t op_version = node->get_version();
if (op_version == 0)
{
return modified;
}
NGRAPH_CHECK(op_version == 1,
"Op version 1 transformation pass failed for ",
*node,
", only op version 1 operations expected. Op version ",
op_version,
" found.");
// Not all enumeration values explicitly handled in switch
#if defined(__clang__)
#pragma clang diagnostic push
......@@ -129,13 +224,13 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
#endif
switch (get_typeid(node))
{
case OP_TYPEID::Add:
case OP_TYPEID::Add_v1:
{
downgrade_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
modified = true;
break;
}
case OP_TYPEID::AvgPool:
case OP_TYPEID::AvgPool_v1:
{
const auto tmp = as_type_ptr<op::v1::AvgPool>(node);
......@@ -160,7 +255,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::AvgPoolBackprop:
case OP_TYPEID::AvgPoolBackprop_v1:
{
const auto tmp = as_type_ptr<op::v1::AvgPoolBackprop>(node);
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
......@@ -186,7 +281,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Broadcast:
case OP_TYPEID::Broadcast_v1:
{
auto tmp = dynamic_cast<const op::v1::Broadcast*>(node.get());
const auto arg = node->input(0).get_source_output();
......@@ -202,7 +297,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Convolution:
case OP_TYPEID::Convolution_v1:
{
auto tmp = as_type_ptr<op::v1::Convolution>(node);
const auto data_arg = node->input(0).get_source_output();
......@@ -225,7 +320,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::ConvolutionBackpropData:
case OP_TYPEID::ConvolutionBackpropData_v1:
{
auto tmp = as_type_ptr<op::v1::ConvolutionBackpropData>(node);
NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant());
......@@ -253,7 +348,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::ConvolutionBackpropFilters:
case OP_TYPEID::ConvolutionBackpropFilters_v1:
{
auto tmp = as_type_ptr<op::v1::ConvolutionBackpropFilters>(node);
NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant());
......@@ -281,7 +376,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Divide:
case OP_TYPEID::Divide_v1:
{
const auto tmp = as_type_ptr<op::v1::Divide>(node);
const auto input_arg0 = node->input(0).get_source_output();
......@@ -293,7 +388,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::DynReshape:
case OP_TYPEID::Reshape_v1:
{
auto tmp = as_type_ptr<op::v1::Reshape>(node);
auto replacement_node = make_shared<op::v0::DynReshape>(node->input(0).get_source_output(),
......@@ -303,13 +398,13 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Equal:
case OP_TYPEID::Equal_v1:
{
downgrade_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
modified = true;
break;
}
case OP_TYPEID::GenerateMask:
case OP_TYPEID::GenerateMask_v1:
{
auto tmp = dynamic_cast<const op::v1::GenerateMask*>(node.get());
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
......@@ -328,61 +423,61 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Greater:
case OP_TYPEID::Greater_v1:
{
downgrade_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
modified = true;
break;
}
case OP_TYPEID::GreaterEq:
case OP_TYPEID::GreaterEq_v1:
{
downgrade_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEq>(node);
modified = true;
break;
}
case OP_TYPEID::Less:
case OP_TYPEID::Less_v1:
{
downgrade_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
modified = true;
break;
}
case OP_TYPEID::LessEqual:
case OP_TYPEID::LessEqual_v1:
{
downgrade_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
modified = true;
break;
}
case OP_TYPEID::LogicalAnd:
case OP_TYPEID::LogicalAnd_v1:
{
downgrade_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
modified = true;
break;
}
case OP_TYPEID::LogicalNot:
case OP_TYPEID::LogicalNot_v1:
{
replace_node(node, make_shared<op::v0::Not>(node->input(0).get_source_output()));
modified = true;
break;
}
case OP_TYPEID::LogicalOr:
case OP_TYPEID::LogicalOr_v1:
{
downgrade_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
modified = true;
break;
}
case OP_TYPEID::LogicalXor:
case OP_TYPEID::LogicalXor_v1:
{
downgrade_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
modified = true;
break;
}
case OP_TYPEID::Maximum:
case OP_TYPEID::Maximum_v1:
{
downgrade_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
modified = true;
break;
}
case OP_TYPEID::MaxPool:
case OP_TYPEID::MaxPool_v1:
{
auto tmp = as_type_ptr<op::v1::MaxPool>(node);
......@@ -405,7 +500,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::MaxPoolBackprop:
case OP_TYPEID::MaxPoolBackprop_v1:
{
const auto tmp = as_type_ptr<op::v1::MaxPoolBackprop>(node);
......@@ -442,25 +537,25 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Minimum:
case OP_TYPEID::Minimum_v1:
{
downgrade_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
modified = true;
break;
}
case OP_TYPEID::Multiply:
case OP_TYPEID::Multiply_v1:
{
downgrade_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
modified = true;
break;
}
case OP_TYPEID::NotEqual:
case OP_TYPEID::NotEqual_v1:
{
downgrade_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
modified = true;
break;
}
case OP_TYPEID::Pad:
case OP_TYPEID::Pad_v1:
{
auto tmp = as_type_ptr<op::v1::Pad>(node);
const auto pad_arg = node->input(0).get_source_output();
......@@ -472,13 +567,13 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Power:
case OP_TYPEID::Power_v1:
{
downgrade_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
modified = true;
break;
}
case OP_TYPEID::Product:
case OP_TYPEID::ReduceProd_v1:
{
auto tmp = as_type_ptr<op::v1::ReduceProd>(node);
auto replacement_node = make_shared<op::v0::Product>(node->input(0).get_source_output(),
......@@ -512,7 +607,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Reverse:
case OP_TYPEID::Reverse_v1:
{
auto tmp = as_type_ptr<op::v1::Reverse>(node);
auto axes_node = tmp->input_value(1).get_node_shared_ptr();
......@@ -544,7 +639,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Slice:
case OP_TYPEID::StridedSlice_v1:
{
auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) {
AxisSet axes{};
......@@ -611,7 +706,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
replace_node(node, replacement_node);
break;
}
case OP_TYPEID::Softmax:
case OP_TYPEID::Softmax_v1:
{
auto tmp = as_type_ptr<op::v1::Softmax>(node);
auto axis = tmp->get_axis();
......@@ -625,7 +720,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Sum:
case OP_TYPEID::ReduceSum_v1:
{
auto tmp = as_type_ptr<op::v1::ReduceSum>(node);
auto replacement_node = make_shared<op::v0::Sum>(node->input(0).get_source_output(),
......@@ -659,18 +754,18 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::TopK:
case OP_TYPEID::TopK_v1:
{
const auto tmp = as_type_ptr<op::v1::TopK>(node);
const auto axis = tmp->get_axis();
const auto sort_type = tmp->get_sort_type();
const auto index_elem_type = tmp->get_index_element_type();
bool comnpute_max;
bool compute_max;
switch (tmp->get_mode())
{
case op::v1::TopK::Mode::MAX: comnpute_max = true; break;
case op::v1::TopK::Mode::MIN: comnpute_max = false; break;
case op::v1::TopK::Mode::MAX: compute_max = true; break;
case op::v1::TopK::Mode::MIN: compute_max = false; break;
default: break;
}
......@@ -678,7 +773,7 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
const auto k_node = node->input_value(1);
auto replacement_node = make_shared<op::v0::TopK>(
arg_node, k_node, axis, index_elem_type, comnpute_max, sort_type);
arg_node, k_node, axis, index_elem_type, compute_max, sort_type);
// values output will be 0, indices 1
vector<int64_t> output_order{1, 0};
......
......@@ -15,40 +15,147 @@
//*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/binary_convolution.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/layers/interpolate.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/floor_mod.hpp"
#include "ngraph/op/fused/batch_mat_mul_transpose.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
#include "ngraph/op/fused/gelu.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/partial_slice.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/selu.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/xor.hpp"
#include <limits>
......@@ -57,33 +164,30 @@
using namespace std;
using namespace ngraph;
#define NGRAPH_OP(a, b) a,
enum class OP_TYPEID
namespace
{
#include "ngraph/op/fused_op_tbl.hpp"
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
#define NGRAPH_OP(a, b) {#a, OP_TYPEID::a},
static unordered_map<string, OP_TYPEID> typeid_map{
#include "ngraph/op/fused_op_tbl.hpp"
#include "ngraph/op/op_tbl.hpp"
};
enum class OP_TYPEID
{
#define NGRAPH_OP(a, b) a,
#include "ngraph/op/op_v0_tbl.hpp"
#undef NGRAPH_OP
OTHER
};
}
static OP_TYPEID get_typeid(shared_ptr<Node> node)
{
OP_TYPEID type_id;
auto it = typeid_map.find(node->description());
static map<NodeTypeInfo, OP_TYPEID> typeid_map{
#define NGRAPH_OP(a, b) {b::a::type_info, OP_TYPEID::a},
#include "ngraph/op/op_v0_tbl.hpp"
#undef NGRAPH_OP
};
OP_TYPEID type_id = OP_TYPEID::OTHER;
auto it = typeid_map.find(node->get_type_info());
if (it != typeid_map.end())
{
type_id = it->second;
}
else
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
return type_id;
}
// END mapping to OP_TYPEID
......@@ -102,20 +206,6 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
size_t op_version = node->get_version();
if (op_version == 1)
{
return modified;
}
NGRAPH_CHECK(op_version == 0,
"Op version 1 transformation pass failed for ",
*node,
", only op version 0 operations expected. Op version ",
op_version,
" found.");
// Not all enumeration values explicitly handled in switch
#if defined(__clang__)
#pragma clang diagnostic push
......
......@@ -21,7 +21,7 @@ else()
endif()
if (NGRAPH_INTERPRETER_ENABLE)
add_library(interpreter_backend ${LIBRARY_TYPE} int_backend.cpp node_wrapper.cpp int_executable.cpp)
add_library(interpreter_backend ${LIBRARY_TYPE} int_backend.cpp int_executable.cpp)
target_compile_definitions(interpreter_backend PRIVATE INTERPRETER_BACKEND_EXPORTS)
if(NGRAPH_LIB_VERSIONING_ENABLE)
set_target_properties(interpreter_backend PROPERTIES
......
......@@ -18,9 +18,149 @@
#include "ngraph/cpio.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/binary_convolution.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/layers/interpolate.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
#include "ngraph/op/experimental/random_uniform.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/floor_mod.hpp"
#include "ngraph/op/fused/batch_mat_mul_transpose.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
#include "ngraph/op/fused/gelu.hpp"
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/log_softmax.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/partial_slice.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/selu.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/split.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
......@@ -39,6 +179,31 @@ using namespace ngraph;
using descriptor::layout::DenseTensorLayout;
runtime::interpreter::OP_TYPEID
runtime::interpreter::INTExecutable::get_typeid(const NodeTypeInfo& type_info)
{
// This expands the op lists in op_v0_tbl.hpp and op_v1_tbl.hpp into map entries that look like this:
// {Abs::type_info, OP_TYPEID::Abs},
// {Acos::type_info, OP_TYPEID::Acos},
// ...
static const map<NodeTypeInfo, OP_TYPEID> type_info_map{
#define NGRAPH_OP(a, b) {b::a::type_info, OP_TYPEID::a},
#include "ngraph/op/op_v0_tbl.hpp"
#undef NGRAPH_OP
#define NGRAPH_OP(a, b) {b::a::type_info, OP_TYPEID::a##_v1},
#include "ngraph/op/op_v1_tbl.hpp"
#undef NGRAPH_OP
};
OP_TYPEID rc = OP_TYPEID::UnknownOp;
auto it = type_info_map.find(type_info);
if (it != type_info_map.end())
{
rc = it->second;
}
return rc;
}
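A minimal sketch of how this lookup is consumed; the actual call site is op_engine in int_executable.hpp, later in this diff, and OP_TYPEID::UnknownOp is the fallback for type_info values that appear in neither table:

switch (get_typeid(node.get_type_info()))
{
case OP_TYPEID::Abs: /* run the Abs reference kernel */ break;
// ... one case per entry in op_v0_tbl.hpp and op_v1_tbl.hpp ...
case OP_TYPEID::UnknownOp:
    throw unsupported_op("Unsupported op '" + node.description() + "'");
}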
runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& function,
bool enable_performance_collection)
: m_is_compiled{true}
......@@ -52,10 +217,9 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.run_passes(m_function);
for (const shared_ptr<Node>& node : m_function->get_ordered_ops())
for (auto node : m_function->get_ordered_ops())
{
m_wrapped_nodes.emplace_back(node);
m_nodes.push_back(node);
}
set_parameters_and_results(*m_function);
}
......@@ -65,9 +229,9 @@ runtime::interpreter::INTExecutable::INTExecutable(const std::string& model_stri
, m_performance_counters_enabled{false}
{
m_function = deserialize(model_string);
for (const shared_ptr<Node>& node : m_function->get_ordered_ops())
for (auto node : m_function->get_ordered_ops())
{
m_wrapped_nodes.emplace_back(node);
m_nodes.push_back(node);
}
set_parameters_and_results(*m_function);
}
......@@ -122,12 +286,10 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
}
// for each ordered op in the graph
for (const NodeWrapper& wrapped : m_wrapped_nodes)
for (auto op : m_nodes)
{
auto op = wrapped.get_node();
runtime::event::Duration d2(op->description(), "Interpreter");
auto type_id = wrapped.get_typeid();
if (type_id == OP_TYPEID::Parameter)
if (op->is_parameter())
{
continue;
}
......@@ -164,40 +326,33 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
// get op type
element::Type type;
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wswitch-enum"
#endif
switch (type_id)
if (is_type<op::Convert>(op) || is_type<op::Quantize>(op) || is_type<op::Dequantize>(op) ||
is_type<op::ArgMin>(op) || is_type<op::ArgMax>(op))
{
type = op->get_input_element_type(0);
}
else if (is_type<op::Equal>(op) || is_type<op::Greater>(op) || is_type<op::GreaterEq>(op) ||
is_type<op::Less>(op) || is_type<op::LessEq>(op) || is_type<op::NotEqual>(op))
{
case OP_TYPEID::Convert:
case OP_TYPEID::Quantize:
case OP_TYPEID::Dequantize:
case OP_TYPEID::ArgMin:
case OP_TYPEID::ArgMax: type = op->get_input_element_type(0); break;
case OP_TYPEID::Equal:
case OP_TYPEID::Greater:
case OP_TYPEID::GreaterEq:
case OP_TYPEID::Less:
case OP_TYPEID::LessEq:
case OP_TYPEID::NotEqual:
// Get the type of the second input, not the first
// All BinaryElementwiseComparison ops have the same type for both inputs
// Select has bool for first input and the type we are interested in for the second
type = op->get_input_element_type(1);
break;
case OP_TYPEID::TopK: type = op->get_output_element_type(1); break;
default: type = op->get_output_element_type(0); break;
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
else if (is_type<op::TopK>(op))
{
type = op->get_output_element_type(1);
}
else
{
type = op->get_output_element_type(0);
}
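// Worked example of the comparison-op comment above, assuming f32 operands:
// for Equal(f32, f32) the output element type is boolean, so
// get_output_element_type(0) would instantiate the boolean kernel, while
// get_input_element_type(1) (f32) instantiates the reference kernel on the
// operand type, which is what the comparison references expect.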
if (m_performance_counters_enabled)
{
m_timer_map[op].start();
}
generate_calls(type, wrapped, op_outputs, op_inputs);
generate_calls(type, *op.get(), op_outputs, op_inputs);
if (m_performance_counters_enabled)
{
m_timer_map[op].stop();
......@@ -212,7 +367,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
}
void runtime::interpreter::INTExecutable::generate_calls(const element::Type& type,
const NodeWrapper& op,
const Node& op,
const vector<shared_ptr<HostTensor>>& out,
const vector<shared_ptr<HostTensor>>& in)
{
......@@ -235,7 +390,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
case element::Type_t::u1:
case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
ss << "unsupported element type " << type << " op " << op.get_name();
throw ngraph_error(ss.str());
}
}
......
......@@ -81,9 +81,11 @@
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/backend.hpp"
......@@ -91,7 +93,6 @@
#ifdef INTERPRETER_USE_HYBRID
#include "ngraph/runtime/hybrid/op/function_call.hpp"
#endif
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
......@@ -187,6 +188,29 @@ namespace ngraph
{
class INTBackend;
class INTExecutable;
namespace
{
// This expands the op lists in op_v0_tbl.hpp and op_v1_tbl.hpp into a list of
// enumerations that look like this:
// Abs,
// Acos,
// ...
enum class OP_TYPEID
{
#define NGRAPH_OP(a, b) a,
#include "ngraph/op/op_v0_tbl.hpp"
#ifdef INTERPRETER_USE_HYBRID
#include "ngraph/runtime/hybrid/op/op_tbl.hpp"
#endif
#undef NGRAPH_OP
#define NGRAPH_OP(a, b) a##_v1,
#include "ngraph/op/op_v1_tbl.hpp"
#undef NGRAPH_OP
UnknownOp
};
}
} // namespace interpreter
} // namespace runtime
} // namespace ngraph
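The two tables above are expanded with different suffixes so that an op present in both opsets gets two distinct enumerators; for example:

// op_v0_tbl.hpp contributes:  AvgPool,
// op_v1_tbl.hpp contributes:  AvgPool_v1,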
......@@ -229,35 +253,25 @@ private:
bool m_performance_counters_enabled = false;
std::shared_ptr<Function> m_function;
std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes;
std::vector<std::shared_ptr<Node>> m_nodes;
std::unordered_map<const Node*, std::shared_ptr<State>> m_states;
std::set<std::string> m_unsupported_op_name_list;
static OP_TYPEID get_typeid(const NodeTypeInfo& type_info);
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
const Node* op = nullptr);
void generate_calls(const element::Type& type,
const NodeWrapper& op,
const Node& op,
const std::vector<std::shared_ptr<HostTensor>>& outputs,
const std::vector<std::shared_ptr<HostTensor>>& inputs);
template <typename T>
void op_engine(const NodeWrapper& node_wrapper,
void op_engine(const Node& node,
const std::vector<std::shared_ptr<HostTensor>>& out,
const std::vector<std::shared_ptr<HostTensor>>& args)
{
const Node& node = *node_wrapper.get_node();
size_t op_version = node.get_version();
bool is_op_version_supported = op_version == 0;
NGRAPH_CHECK(is_op_version_supported,
"Unsupported operator version ",
op_version,
" in ",
node,
".\n",
"INTERPRETER backend currently only supports op in version 0.");
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
// is not in the list an error is generated.
......@@ -267,7 +281,7 @@ private:
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wcovered-switch-default"
#endif
switch (node_wrapper.get_typeid())
switch (get_typeid(node.get_type_info()))
{
case OP_TYPEID::Abs:
{
......@@ -426,7 +440,7 @@ private:
avg_pool->get_include_padding_in_avg_computation());
break;
}
case OP_TYPEID::BinaryConvolution:
case OP_TYPEID::BinaryConvolution_v1:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
......@@ -1017,7 +1031,7 @@ private:
less_eq->get_autob());
break;
}
case OP_TYPEID::LessEqual:
case OP_TYPEID::LessEqual_v1:
{
auto less_eq = static_cast<const op::v1::LessEqual*>(&node);
reference::less_eq<T>(args[0]->get_data_ptr<const T>(),
......@@ -1035,7 +1049,7 @@ private:
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::LogicalAnd:
case OP_TYPEID::LogicalAnd_v1:
{
auto logical_and = static_cast<const op::v1::LogicalAnd*>(&node);
reference::logical_and(args[0]->get_data_ptr<const T>(),
......@@ -1046,7 +1060,7 @@ private:
logical_and->get_autob());
break;
}
case OP_TYPEID::LogicalOr:
case OP_TYPEID::LogicalOr_v1:
{
auto logical_or = static_cast<const op::v1::LogicalOr*>(&node);
reference::logical_or(args[0]->get_data_ptr<const T>(),
......@@ -1057,7 +1071,7 @@ private:
logical_or->get_autob());
break;
}
case OP_TYPEID::LogicalXor:
case OP_TYPEID::LogicalXor_v1:
{
auto logical_xor = static_cast<const op::v1::LogicalXor*>(&node);
reference::logical_xor(args[0]->get_data_ptr<const T>(),
......@@ -1068,7 +1082,7 @@ private:
logical_xor->get_autob());
break;
}
case OP_TYPEID::LRN:
case OP_TYPEID::LRN_v1:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(args[0]->get_data_ptr<const T>(),
......@@ -1171,7 +1185,7 @@ private:
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::LogicalNot:
case OP_TYPEID::LogicalNot_v1:
case OP_TYPEID::Not:
{
size_t element_count = shape_size(node.get_output_shape(0));
......@@ -1551,6 +1565,11 @@ private:
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::Reciprocal:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::Relu:
{
size_t element_count = shape_size(node.get_output_shape(0));
......@@ -1868,8 +1887,115 @@ private:
case OP_TYPEID::DynPad:
case OP_TYPEID::Tile:
case OP_TYPEID::DynReplaceSlice:
case OP_TYPEID::FloorMod:
case OP_TYPEID::VariadicSplit:
case OP_TYPEID::Abs_v1:
case OP_TYPEID::Acos_v1:
case OP_TYPEID::Add_v1:
case OP_TYPEID::Asin_v1:
case OP_TYPEID::Atan_v1:
case OP_TYPEID::AvgPool_v1:
case OP_TYPEID::BatchMatMulTranspose:
case OP_TYPEID::BatchNormInference_v1:
case OP_TYPEID::Broadcast_v1:
case OP_TYPEID::Ceiling_v1:
case OP_TYPEID::Clamp_v1:
case OP_TYPEID::Concat_v1:
case OP_TYPEID::Constant_v1:
case OP_TYPEID::Convert_v1:
case OP_TYPEID::Convolution_v1:
case OP_TYPEID::ConvolutionBackpropData_v1:
case OP_TYPEID::ConvolutionBias:
case OP_TYPEID::ConvolutionBiasAdd:
case OP_TYPEID::ConvolutionBiasBackpropFiltersBias:
case OP_TYPEID::Cos_v1:
case OP_TYPEID::Cosh_v1:
case OP_TYPEID::CrossEntropy:
case OP_TYPEID::CrossEntropyBackprop:
case OP_TYPEID::DepthToSpace_v1:
case OP_TYPEID::Divide_v1:
case OP_TYPEID::Elu_v1:
case OP_TYPEID::Erf_v1:
case OP_TYPEID::Equal_v1:
case OP_TYPEID::Exp_v1:
case OP_TYPEID::FakeQuantize_v1:
case OP_TYPEID::Floor_v1:
case OP_TYPEID::FloorMod_v1:
case OP_TYPEID::Gather_v1:
case OP_TYPEID::Greater_v1:
case OP_TYPEID::GreaterEq_v1:
case OP_TYPEID::GroupConvolution_v1:
case OP_TYPEID::GRN:
case OP_TYPEID::GRUCell:
case OP_TYPEID::Gelu:
case OP_TYPEID::GeluBackpropFactor:
case OP_TYPEID::Gemm:
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::HardSigmoid_v1:
case OP_TYPEID::Interpolate_v1:
case OP_TYPEID::LayerNorm:
case OP_TYPEID::LayerNormBackprop:
case OP_TYPEID::Less_v1:
case OP_TYPEID::Log_v1:
case OP_TYPEID::LogSoftmax:
case OP_TYPEID::LSTMCell_v1:
case OP_TYPEID::LSTMSequence_v1:
case OP_TYPEID::MatMul_v1:
case OP_TYPEID::MaxPool_v1:
case OP_TYPEID::Maximum_v1:
case OP_TYPEID::Minimum_v1:
case OP_TYPEID::Multiply_v1:
case OP_TYPEID::MVN:
case OP_TYPEID::Negative_v1:
case OP_TYPEID::NormalizeL2_v1:
case OP_TYPEID::NotEqual_v1:
case OP_TYPEID::OneHot_v1:
case OP_TYPEID::PRelu_v1:
case OP_TYPEID::Pad_v1:
case OP_TYPEID::Parameter_v1:
case OP_TYPEID::PartialSlice:
case OP_TYPEID::PartialSliceBackprop:
case OP_TYPEID::Power_v1:
case OP_TYPEID::Range_v1:
case OP_TYPEID::Relu_v1:
case OP_TYPEID::ReduceMax_v1:
case OP_TYPEID::ReduceMin_v1:
case OP_TYPEID::ReduceProd_v1:
case OP_TYPEID::ReduceSum_v1:
case OP_TYPEID::Reshape_v1:
case OP_TYPEID::Result_v1:
case OP_TYPEID::Reverse_v1:
case OP_TYPEID::ReverseSequence_v1:
case OP_TYPEID::RNNCell_v1:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::Selu:
case OP_TYPEID::ShapeOf_v1:
case OP_TYPEID::ShuffleChannels_v1:
case OP_TYPEID::Sign_v1:
case OP_TYPEID::Sigmoid_v1:
case OP_TYPEID::Sin_v1:
case OP_TYPEID::Sinh_v1:
case OP_TYPEID::Softmax_v1:
case OP_TYPEID::SoftmaxCrossEntropy:
case OP_TYPEID::SoftmaxCrossEntropyBackprop:
case OP_TYPEID::Sqrt_v1:
case OP_TYPEID::SpaceToDepth_v1:
case OP_TYPEID::Split_v1:
case OP_TYPEID::SquaredDifference_v1:
case OP_TYPEID::StridedSlice_v1:
case OP_TYPEID::Subtract_v1:
case OP_TYPEID::Tan_v1:
case OP_TYPEID::Tanh_v1:
case OP_TYPEID::TensorIterator_v1:
case OP_TYPEID::Tile_v1:
case OP_TYPEID::TopK_v1:
case OP_TYPEID::Transpose_v1:
case OP_TYPEID::Unsqueeze_v1:
case OP_TYPEID::AvgPoolBackprop_v1:
case OP_TYPEID::ConvolutionBackpropFilters_v1:
case OP_TYPEID::MaxPoolBackprop_v1:
case OP_TYPEID::Squeeze_v1:
case OP_TYPEID::GenerateMask_v1:
case OP_TYPEID::UnknownOp:
case OP_TYPEID::VariadicSplit_v1:
throw unsupported_op("Unsupported op '" + node.description() + "'");
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
using namespace ngraph;
using namespace std;
runtime::interpreter::NodeWrapper::NodeWrapper(const shared_ptr<const Node>& node)
: m_node{node}
{
// This expands the op list in op_tbl.hpp into map entries that look like this:
// {"Abs", runtime::interpreter::OP_TYPEID::Abs},
// {"Acos", runtime::interpreter::OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a, b) {#a, runtime::interpreter::OP_TYPEID::a},
static unordered_map<string, runtime::interpreter::OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
#ifdef INTERPRETER_USE_HYBRID
#include "ngraph/runtime/hybrid/op/op_tbl.hpp"
#endif
};
#undef NGRAPH_OP
auto it = typeid_map.find(m_node->description());
if (it != typeid_map.end())
{
m_typeid = it->second;
}
else
{
throw unsupported_op("Unsupported op '" + m_node->description() + "'");
}
}
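The constructor above rebuilds its name-to-enum map from the same op table that generates OP_TYPEID, so the map can never drift out of sync with the enum. A self-contained sketch of that X-macro pattern, using a toy DEMO_OP_TBL list and DemoTypeId enum in place of op_tbl.hpp and OP_TYPEID (all names below are illustrative, not the real headers):
#include <iostream>
#include <string>
#include <unordered_map>
// Stand-in for op_tbl.hpp: one DEMO_OP invocation per op. Each expansion
// site decides what DEMO_OP means before expanding the list.
#define DEMO_OP_TBL DEMO_OP(Abs) DEMO_OP(Acos)
enum class DemoTypeId { Abs, Acos, UnknownOp };
static const std::unordered_map<std::string, DemoTypeId> demo_typeid_map{
#define DEMO_OP(a) {#a, DemoTypeId::a},
    DEMO_OP_TBL
#undef DEMO_OP
};
int main()
{
    // A known description maps to its enum value; an unknown one is simply
    // not found, which is where the constructor above throws unsupported_op.
    auto it = demo_typeid_map.find("Acos");
    bool found = it != demo_typeid_map.end();
    std::cout << found << " " << static_cast<int>(it->second) << "\n"; // prints "1 1"
    return 0;
}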
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
namespace interpreter
{
enum class OP_TYPEID;
class NodeWrapper;
}
}
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a, b) a,
enum class ngraph::runtime::interpreter::OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
#ifdef INTERPRETER_USE_HYBRID
#include "ngraph/runtime/hybrid/op/op_tbl.hpp"
#endif
};
#undef NGRAPH_OP
/// \brief This class allows adding an enum typeid to each Node. This makes dealing with
/// collections of Nodes a little easier and faster as we can use switch() instead of
/// if/else statements
class ngraph::runtime::interpreter::NodeWrapper
{
public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
std::shared_ptr<const Node> get_node() const { return m_node; }
ngraph::runtime::interpreter::OP_TYPEID get_typeid() const { return m_typeid; }
private:
std::shared_ptr<const ngraph::Node> m_node;
OP_TYPEID m_typeid;
};
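To make the class comment concrete: a backend wraps each node once and then dispatches on the cached enum instead of comparing description strings. A minimal usage sketch; only NodeWrapper, get_typeid() and the Abs/Acos enumerators come from the headers above, the helper function is hypothetical:
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
using namespace ngraph;
using runtime::interpreter::NodeWrapper;
using runtime::interpreter::OP_TYPEID;
// Hypothetical helper: classify a wrapped node with a switch rather than a
// chain of string comparisons on description().
static bool is_unary_elementwise(const NodeWrapper& wrapper)
{
    switch (wrapper.get_typeid())
    {
    case OP_TYPEID::Abs:
    case OP_TYPEID::Acos: return true; // extend with the other unary ops
    default: return false;
    }
}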
......@@ -61,6 +61,7 @@
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/layers/interpolate.hpp"
#include "ngraph/op/experimental/quantized_conv_bias.hpp"
#include "ngraph/op/experimental/quantized_conv_relu.hpp"
#include "ngraph/op/experimental/quantized_dot_bias.hpp"
......@@ -75,6 +76,7 @@
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/op/fused/crossentropy.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/fake_quantize.hpp"
......@@ -155,6 +157,7 @@
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
......@@ -163,6 +166,7 @@
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp"
......@@ -187,34 +191,42 @@ bool ngraph::get_serialize_output_shapes()
return s_serialize_output_shapes_enabled;
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
// ...
#define NGRAPH_OP(a, b) a,
enum class OP_TYPEID
namespace
{
#include "ngraph/op/fused_op_tbl.hpp"
#include "ngraph/op/op_tbl.hpp"
UnknownOp
};
// This expands the op lists in op_v0_tbl.hpp and op_v1_tbl.hpp into enumerators that look like this:
// Abs,
// Acos,
// ...
// Abs_v1,
// Acos_v1,
// ...
enum class OP_TYPEID
{
#define NGRAPH_OP(a, b) a,
#include "ngraph/op/op_v0_tbl.hpp"
#undef NGRAPH_OP
#define NGRAPH_OP(a, b) a##_v1,
#include "ngraph/op/op_v1_tbl.hpp"
#undef NGRAPH_OP
UnknownOp
};
}
static OP_TYPEID get_typeid(const string& s)
static OP_TYPEID get_typeid(const NodeTypeInfo& type_info)
{
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// {"Abs", OP_TYPEID::Abs},
// {"Acos", OP_TYPEID::Acos},
// ...
#define NGRAPH_OP(a, b) {#a, OP_TYPEID::a},
static const unordered_map<string, OP_TYPEID> typeid_map{
#include "ngraph/op/fused_op_tbl.hpp"
#include "ngraph/op/op_tbl.hpp"
};
// This expands the op lists in op_v0_tbl.hpp and op_v1_tbl.hpp into map entries that look like this:
// {Abs::type_info, OP_TYPEID::Abs},
// {Acos::type_info, OP_TYPEID::Acos},
// ...
static const map<NodeTypeInfo, OP_TYPEID> type_info_map{
#define NGRAPH_OP(a, b) {b::a::type_info, OP_TYPEID::a},
#include "ngraph/op/op_v0_tbl.hpp"
#undef NGRAPH_OP
#define NGRAPH_OP(a, b) {b::a::type_info, OP_TYPEID::a##_v1},
#include "ngraph/op/op_v1_tbl.hpp"
#undef NGRAPH_OP
};
OP_TYPEID rc = OP_TYPEID::UnknownOp;
auto it = typeid_map.find(s);
if (it != typeid_map.end())
auto it = type_info_map.find(type_info);
if (it != type_info_map.end())
{
rc = it->second;
}
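Two things happen in the block above: the same NGRAPH_OP list is expanded twice, with ## pasting a _v1 suffix onto the second pass, and the resulting enum is looked up by NodeTypeInfo rather than by the description string, because a v0 and a v1 op can share a name. A compressed, self-contained sketch of both ideas; the DEMO_* names and map contents are made up, and DemoTypeInfo stands in for ngraph::DiscreteTypeInfo:
#include <cassert>
#include <cstdint>
#include <cstring>
#include <map>
// Stand-ins for op_v0_tbl.hpp / op_v1_tbl.hpp.
#define DEMO_V0_TBL DEMO_OP(Add) DEMO_OP(Softmax)
#define DEMO_V1_TBL DEMO_OP(Add) DEMO_OP(Softmax)
enum class DemoOpId
{
#define DEMO_OP(a) a, // Add, Softmax,
    DEMO_V0_TBL
#undef DEMO_OP
#define DEMO_OP(a) a##_v1, // Add_v1, Softmax_v1,
    DEMO_V1_TBL
#undef DEMO_OP
    UnknownOp
};
// Minimal (name, version) key with the strict ordering std::map requires.
struct DemoTypeInfo
{
    const char* name;
    uint64_t version;
    bool operator<(const DemoTypeInfo& b) const
    {
        return version < b.version || (version == b.version && strcmp(name, b.name) < 0);
    }
};
int main()
{
    std::map<DemoTypeInfo, DemoOpId> ids{
        {{"Softmax", 0}, DemoOpId::Softmax},
        {{"Softmax", 1}, DemoOpId::Softmax_v1},
    };
    // Same name, different version: two distinct entries, which a plain
    // string key could not express.
    assert(ids.size() == 2);
    // A miss is how get_typeid falls back to OP_TYPEID::UnknownOp.
    DemoTypeInfo missing{"NoSuchOp", 0};
    assert(ids.find(missing) == ids.end());
    return 0;
}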
......@@ -902,10 +914,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
shared_ptr<Node> node;
try
{
string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>();
string friendly_name = get_value<string>(node_js, "friendly_name");
size_t op_version = get_value<size_t>(node_js, "op_version");
NodeTypeInfo type_info{node_op.c_str(), op_version};
string type_info_name;
if (has_key(node_js, "type_info"))
{
json jtype_info = node_js["type_info"];
type_info_name = jtype_info.at("name").get<string>();
type_info.name = type_info_name.c_str();
type_info.version = jtype_info.at("version").get<uint64_t>();
}
string node_name = node_js.at("name").get<string>();
string friendly_name = get_value<string>(node_js, "friendly_name");
vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps");
vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
OutputVectorHelper args(deserialize_output_vector(node_js["inputs"]));
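For orientation, a serialized node record now carries the legacy "op"/"op_version" pair and, when written by the new serializer, a "type_info" object with the same identity spelled out. The field values below are made up; only the key names come from the code in this diff, and the lookup order mirrors the branch above (prefer "type_info", fall back to "op"/"op_version"):
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>
int main()
{
    // Illustrative record only; see serialize_node for what is actually emitted.
    auto node_js = nlohmann::json::parse(R"({
        "name": "Add_5",
        "friendly_name": "sum",
        "op": "Add",
        "op_version": 0,
        "type_info": {"name": "Add", "version": 0},
        "inputs": ["Parameter_1", "Parameter_2"],
        "outputs": ["Add_5_0"]
    })");
    // Prefer "type_info" when present, otherwise fall back to the op name.
    std::string op_name = node_js.find("type_info") != node_js.end()
                              ? node_js["type_info"]["name"].get<std::string>()
                              : node_js["op"].get<std::string>();
    std::cout << op_name << "\n"; // Add
    return 0;
}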
......@@ -917,19 +938,22 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
// #pragma GCC diagnostic error "-Wimplicit-fallthrough"
#endif
switch (get_typeid(node_op))
switch (get_typeid(type_info))
{
case OP_TYPEID::Abs:
case OP_TYPEID::Abs_v1:
{
node = make_shared<op::Abs>(args[0]);
break;
}
case OP_TYPEID::Acos:
case OP_TYPEID::Acos_v1:
{
node = make_shared<op::Acos>(args[0]);
break;
}
case OP_TYPEID::Add:
case OP_TYPEID::Add_v1:
{
if (op_version == 0)
{
......@@ -984,11 +1008,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Asin:
case OP_TYPEID::Asin_v1:
{
node = make_shared<op::Asin>(args[0]);
break;
}
case OP_TYPEID::Atan:
case OP_TYPEID::Atan_v1:
{
node = make_shared<op::Atan>(args[0]);
break;
......@@ -1000,6 +1026,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::AvgPool:
case OP_TYPEID::AvgPool_v1:
{
if (op_version == 0)
{
......@@ -1042,6 +1069,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::AvgPoolBackprop:
case OP_TYPEID::AvgPoolBackprop_v1:
{
if (op_version == 0)
{
......@@ -1094,6 +1122,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::BatchNormInference:
case OP_TYPEID::BatchNormInference_v1:
{
auto epsilon = node_js.at("eps").get<double>();
// Odd order for back-compatibility
......@@ -1109,7 +1138,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[2], args[0], args[1], args[3], args[4], args[5], epsilon);
break;
}
case OP_TYPEID::BinaryConvolution:
case OP_TYPEID::BinaryConvolution_v1:
{
auto strides = node_js.at("strides").get<vector<size_t>>();
auto dilations = node_js.at("dilations").get<vector<size_t>>();
......@@ -1131,6 +1160,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Broadcast:
case OP_TYPEID::Broadcast_v1:
{
if (op_version == 0)
{
......@@ -1157,11 +1187,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Ceiling:
case OP_TYPEID::Ceiling_v1:
{
node = make_shared<op::Ceiling>(args[0]);
break;
}
case OP_TYPEID::Clamp:
case OP_TYPEID::Clamp_v1:
{
const auto clamp_min = node_js.at("min").get<float>();
const auto clamp_max = node_js.at("max").get<float>();
......@@ -1169,12 +1200,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Concat:
case OP_TYPEID::Concat_v1:
{
auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::Concat>(static_cast<OutputVector>(args), axis);
break;
}
case OP_TYPEID::Constant:
case OP_TYPEID::Constant_v1:
{
auto type_node_js =
has_key(node_js, "element_type") ? node_js : node_js.at("value_type");
......@@ -1185,12 +1218,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Convert:
case OP_TYPEID::Convert_v1:
{
auto target_type = read_element_type(node_js.at("target_type"));
node = make_shared<op::Convert>(args[0], target_type);
break;
}
case OP_TYPEID::Convolution:
case OP_TYPEID::Convolution_v1:
{
if (op_version == 0)
{
......@@ -1252,6 +1287,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::ConvolutionBackpropData:
case OP_TYPEID::ConvolutionBackpropData_v1:
{
if (op_version == 0)
{
......@@ -1287,6 +1323,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::ConvolutionBackpropFilters:
case OP_TYPEID::ConvolutionBackpropFilters_v1:
{
if (op_version == 0)
{
......@@ -1391,11 +1428,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Cos:
case OP_TYPEID::Cos_v1:
{
node = make_shared<op::Cos>(args[0]);
break;
}
case OP_TYPEID::Cosh:
case OP_TYPEID::Cosh_v1:
{
node = make_shared<op::Cosh>(args[0]);
break;
......@@ -1415,7 +1454,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[0], args[1], args[2], soft_label, ignore_index);
break;
}
case OP_TYPEID::DepthToSpace:
case OP_TYPEID::DepthToSpace_v1:
{
auto mode = node_js.at("mode").get<op::DepthToSpace::DepthToSpaceMode>();
auto block_size = node_js.at("block_size").get<size_t>();
......@@ -1430,6 +1469,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Divide:
case OP_TYPEID::Divide_v1:
{
bool pythondiv = get_or_default(node_js, "pythondiv", true);
if (op_version == 0)
......@@ -1491,6 +1531,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::DynReshape:
case OP_TYPEID::Reshape_v1:
{
const auto zero_flag = node_js.at("zero_flag").get<bool>();
if (op_version == 0)
......@@ -1521,7 +1562,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
ellipsis_mask);
break;
}
case OP_TYPEID::Elu:
case OP_TYPEID::Elu_v1:
{
auto alpha = node_js.at("alpha").get<double>();
node = make_shared<op::Elu>(args[0], alpha);
......@@ -1533,6 +1574,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Equal:
case OP_TYPEID::Equal_v1:
{
if (op_version == 0)
{
......@@ -1549,16 +1591,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Erf:
case OP_TYPEID::Erf_v1:
{
node = make_shared<op::Erf>(args[0]);
break;
}
case OP_TYPEID::Exp:
case OP_TYPEID::Exp_v1:
{
node = make_shared<op::Exp>(args[0]);
break;
}
case OP_TYPEID::FakeQuantize:
case OP_TYPEID::FakeQuantize_v1:
{
size_t levels = node_js.at("levels").get<size_t>();
node =
......@@ -1566,17 +1610,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Floor:
case OP_TYPEID::Floor_v1:
{
node = make_shared<op::Floor>(args[0]);
break;
}
case OP_TYPEID::FloorMod:
case OP_TYPEID::FloorMod_v1:
{
node = make_shared<op::FloorMod>(
node = make_shared<op::v1::FloorMod>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::Gather:
case OP_TYPEID::Gather_v1:
{
if (op_version == 0)
{
......@@ -1614,6 +1660,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::GenerateMask:
case OP_TYPEID::GenerateMask_v1:
{
auto type = read_element_type(node_js.at("type"));
auto seed = node_js.at("seed").get<unsigned int>();
......@@ -1643,6 +1690,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Greater:
case OP_TYPEID::Greater_v1:
{
if (op_version == 0)
{
......@@ -1659,6 +1707,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::GreaterEq:
case OP_TYPEID::GreaterEq_v1:
{
if (op_version == 0)
{
......@@ -1680,7 +1729,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::GRN>(args[0], bias);
break;
}
case OP_TYPEID::GroupConvolution:
case OP_TYPEID::GroupConvolution_v1:
{
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
......@@ -1747,13 +1796,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
linear_before_reset);
break;
}
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::HardSigmoid_v1:
{
auto alpha = node_js.at("alpha").get<float>();
auto beta = node_js.at("beta").get<float>();
node = make_shared<op::HardSigmoid>(args[0], alpha, beta);
break;
}
case OP_TYPEID::Interpolate_v1: { break;
}
case OP_TYPEID::LayerNorm:
{
auto keep_stats = node_js.at("keep_stats").get<bool>();
......@@ -1800,6 +1851,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Less:
case OP_TYPEID::Less_v1:
{
if (op_version == 0)
{
......@@ -1821,7 +1873,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::LessEqual:
case OP_TYPEID::LessEqual_v1:
{
node = make_shared<op::v1::LessEqual>(
args[0],
......@@ -1830,28 +1882,29 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Log:
case OP_TYPEID::Log_v1:
{
node = make_shared<op::Log>(args[0]);
break;
}
case OP_TYPEID::LogicalAnd:
case OP_TYPEID::LogicalAnd_v1:
{
node = make_shared<op::v1::LogicalAnd>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::LogicalNot:
case OP_TYPEID::LogicalNot_v1:
{
node = make_shared<op::v1::LogicalNot>(args[0]);
break;
}
case OP_TYPEID::LogicalOr:
case OP_TYPEID::LogicalOr_v1:
{
node = make_shared<op::v1::LogicalOr>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::LogicalXor:
case OP_TYPEID::LogicalXor_v1:
{
node = make_shared<op::v1::LogicalXor>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
......@@ -1863,7 +1916,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::LogSoftmax>(args[0], axis);
break;
}
case OP_TYPEID::LRN:
case OP_TYPEID::LRN_v1:
{
auto alpha = node_js.at("alpha").get<double>();
auto beta = node_js.at("beta").get<double>();
......@@ -1872,7 +1925,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::LRN>(args[0], args[1], alpha, beta, bias, nsize);
break;
}
case OP_TYPEID::LSTMCell:
case OP_TYPEID::LSTMCell_v1:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto weights_format = read_lstm_weights_format(node_js);
......@@ -1931,7 +1984,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
break;
}
case OP_TYPEID::LSTMSequence:
case OP_TYPEID::LSTMSequence_v1:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
......@@ -1980,7 +2033,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
break;
}
case OP_TYPEID::MatMul:
case OP_TYPEID::MatMul_v1:
{
bool transpose_a = node_js.at("transpose_a").get<bool>();
bool transpose_b = node_js.at("transpose_b").get<bool>();
......@@ -1993,7 +2046,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Max>(args[0], reduction_axes);
break;
}
case OP_TYPEID::ReduceMax_v1:
{
node = make_shared<op::v1::ReduceMax>(args[0], args[1]);
break;
}
case OP_TYPEID::MaxPool:
case OP_TYPEID::MaxPool_v1:
{
if (op_version == 0)
{
......@@ -2046,6 +2105,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::MaxPoolBackprop:
case OP_TYPEID::MaxPoolBackprop_v1:
{
if (op_version == 0)
{
......@@ -2094,6 +2154,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Maximum:
case OP_TYPEID::Maximum_v1:
{
if (op_version == 0)
{
......@@ -2115,7 +2176,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Min>(args[0], reduction_axes);
break;
}
case OP_TYPEID::ReduceMin_v1:
{
node = make_shared<op::v1::ReduceMin>(args[0], args[1]);
break;
}
case OP_TYPEID::Minimum:
case OP_TYPEID::Minimum_v1:
{
if (op_version == 0)
{
......@@ -2132,6 +2199,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Multiply:
case OP_TYPEID::Multiply_v1:
{
if (op_version == 0)
{
......@@ -2156,11 +2224,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Negative:
case OP_TYPEID::Negative_v1:
{
node = make_shared<op::Negative>(args[0]);
break;
}
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::NormalizeL2_v1:
{
float eps = node_js.at("eps").get<float>();
auto eps_mode = node_js.at("eps_mode").get<op::EpsMode>();
......@@ -2168,6 +2237,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::NotEqual:
case OP_TYPEID::NotEqual_v1:
{
if (op_version == 0)
{
......@@ -2189,6 +2259,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::OneHot:
case OP_TYPEID::OneHot_v1:
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
......@@ -2202,6 +2273,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Pad:
case OP_TYPEID::Pad_v1:
{
if (op_version == 0)
{
......@@ -2237,6 +2309,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Parameter:
case OP_TYPEID::Parameter_v1:
{
auto type_node_js =
has_key(node_js, "element_type") ? node_js : node_js.at("value_type");
......@@ -2282,6 +2355,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Power:
case OP_TYPEID::Power_v1:
{
if (op_version == 0)
{
......@@ -2297,12 +2371,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
break;
}
case OP_TYPEID::PRelu:
case OP_TYPEID::PRelu_v1:
{
node = make_shared<op::PRelu>(args[0], args[1]);
break;
}
case OP_TYPEID::Product:
case OP_TYPEID::ReduceProd_v1:
{
if (op_version == 0)
{
......@@ -2408,6 +2483,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Range:
case OP_TYPEID::Range_v1:
{
node = make_shared<op::Range>(args[0], args[1], args[2]);
break;
......@@ -2418,6 +2494,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Relu:
case OP_TYPEID::Relu_v1:
{
node = make_shared<op::Relu>(args[0]);
break;
......@@ -2444,6 +2521,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Result:
case OP_TYPEID::Result_v1:
{
auto needs_default_layout =
get_or_default<bool>(node_js, "needs_default_layout", false);
......@@ -2451,6 +2529,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Reverse:
case OP_TYPEID::Reverse_v1:
{
if (op_version == 0)
{
......@@ -2467,13 +2546,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::ReverseSequence:
case OP_TYPEID::ReverseSequence_v1:
{
auto batch_axis = node_js.at("batch_axis").get<size_t>();
auto sequence_axis = node_js.at("sequence_axis").get<size_t>();
node = make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break;
}
case OP_TYPEID::RNNCell:
case OP_TYPEID::RNNCell_v1:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
......@@ -2530,11 +2610,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::ShapeOf:
case OP_TYPEID::ShapeOf_v1:
{
node = make_shared<op::ShapeOf>(args[0]);
break;
}
case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::ShuffleChannels_v1:
{
const auto axis = node_js.at("axis").get<size_t>();
const auto groups = node_js.at("groups").get<size_t>();
......@@ -2542,6 +2623,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Sigmoid:
case OP_TYPEID::Sigmoid_v1:
{
node = make_shared<op::Sigmoid>(args[0]);
break;
......@@ -2552,21 +2634,25 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Sign:
case OP_TYPEID::Sign_v1:
{
node = make_shared<op::Sign>(args[0]);
break;
}
case OP_TYPEID::Sin:
case OP_TYPEID::Sin_v1:
{
node = make_shared<op::Sin>(args[0]);
break;
}
case OP_TYPEID::Sinh:
case OP_TYPEID::Sinh_v1:
{
node = make_shared<op::Sinh>(args[0]);
break;
}
case OP_TYPEID::Slice:
case OP_TYPEID::StridedSlice_v1:
{
if (op_version == 0)
{
......@@ -2595,6 +2681,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Softmax:
case OP_TYPEID::Softmax_v1:
{
if (op_version == 0)
{
......@@ -2630,14 +2717,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[0], args[1], args[2], soft_label, ignore_index);
break;
}
case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::SpaceToDepth_v1:
{
auto block_size = node_js.at("block_size").get<size_t>();
auto mode = node_js.at("mode").get<op::SpaceToDepth::SpaceToDepthMode>();
node = make_shared<op::SpaceToDepth>(args[0], mode, block_size);
break;
}
case OP_TYPEID::Split:
case OP_TYPEID::Split_v1:
{
const auto axis = node_js.at("axis").get<size_t>();
const auto splits = node_js.at("splits").get<vector<size_t>>();
......@@ -2645,27 +2732,30 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Sqrt:
case OP_TYPEID::Sqrt_v1:
{
node = make_shared<op::Sqrt>(args[0]);
break;
}
case OP_TYPEID::SquaredDifference:
case OP_TYPEID::SquaredDifference_v1:
{
node = make_shared<op::SquaredDifference>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::Squeeze:
case OP_TYPEID::Squeeze_v1:
{
node = make_shared<op::Squeeze>(args[0], args[1]);
break;
}
case OP_TYPEID::Subtract:
case OP_TYPEID::Subtract_v1:
{
node = make_shared<op::Subtract>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::ReduceSum_v1:
case OP_TYPEID::Sum:
{
if (op_version == 0)
......@@ -2684,16 +2774,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Tan:
case OP_TYPEID::Tan_v1:
{
node = make_shared<op::Tan>(args[0]);
break;
}
case OP_TYPEID::Tanh:
case OP_TYPEID::Tanh_v1:
{
node = make_shared<op::Tanh>(args[0]);
break;
}
case OP_TYPEID::TensorIterator:
case OP_TYPEID::TensorIterator_v1:
{
auto ti = make_shared<op::TensorIterator>(args);
json jbody = node_js["body"];
......@@ -2738,11 +2830,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Tile:
case OP_TYPEID::Tile_v1:
{
node = make_shared<op::Tile>(args[0], args[1]);
break;
}
case OP_TYPEID::TopK:
case OP_TYPEID::TopK_v1:
{
if (op_version == 0)
{
......@@ -2782,6 +2876,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::Transpose:
case OP_TYPEID::Transpose_v1:
{
node = make_shared<op::Transpose>(args[0], args[1]);
break;
......@@ -2791,12 +2886,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::StopGradient>(args[0]);
break;
}
case OP_TYPEID::Unsqueeze:
case OP_TYPEID::Unsqueeze_v1:
{
node = make_shared<op::Unsqueeze>(args[0], args[1]);
break;
}
case OP_TYPEID::VariadicSplit:
case OP_TYPEID::VariadicSplit_v1:
{
node = make_shared<op::v1::VariadicSplit>(args[0], args[1], args[2]);
break;
......@@ -2810,7 +2905,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::UnknownOp:
{
stringstream ss;
ss << "unsupported op " << node_op;
ss << "unsupported op " << type_info.name << ":" << type_info.version;
throw runtime_error(ss.str());
}
}
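The reason every OP_TYPEID enumerator, including the new *_v1 values whose cases merely break or fall through, has to appear in this switch (and again in serialize_node below) is the pragma block before it: -Wswitch and -Wswitch-enum are promoted to errors, so an unhandled enumerator stops the build instead of silently taking a default path. A toy reproduction of the mechanism (Color and to_int are made up):
enum class Color { Red, Green, Blue };
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch-enum"
int to_int(Color c)
{
    switch (c)
    {
    case Color::Red: return 0;
    case Color::Green: return 1;
    case Color::Blue: return 2;
    // Deleting any case above now fails the build with something like:
    //   error: enumeration value 'Blue' not handled in switch
    }
    return -1;
}
#pragma GCC diagnostic pop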
......@@ -2913,7 +3008,12 @@ json JSONSerializer::serialize_output_vector(const OutputVector& output_vector)
json JSONSerializer::serialize_node(const Node& n)
{
m_nodes_serialized.insert(&n);
const NodeTypeInfo& type_info = n.get_type_info();
json jtype_info;
jtype_info["name"] = type_info.name;
jtype_info["version"] = type_info.version;
json node;
node["type_info"] = jtype_info;
node["name"] = n.get_name();
auto op_version = n.get_version();
node["op_version"] = op_version;
......@@ -2922,7 +3022,7 @@ json JSONSerializer::serialize_node(const Node& n)
{
node["friendly_name"] = n.get_friendly_name();
}
node["op"] = n.description();
node["op"] = n.type_info.name;
// TODO Multiple outputs
json inputs = json::array();
json control_deps = json::array();
......@@ -2973,20 +3073,22 @@ json JSONSerializer::serialize_node(const Node& n)
node["provenance_tags"] = provenance_tags;
}
string node_op = n.description();
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
// #pragma GCC diagnostic error "-Wimplicit-fallthrough"
#endif
switch (get_typeid(node_op))
switch (get_typeid(type_info))
{
case OP_TYPEID::Abs: { break;
case OP_TYPEID::Abs:
case OP_TYPEID::Abs_v1: { break;
}
case OP_TYPEID::Acos: { break;
case OP_TYPEID::Acos:
case OP_TYPEID::Acos_v1: { break;
}
case OP_TYPEID::Add:
case OP_TYPEID::Add_v1:
{
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
......@@ -3040,9 +3142,11 @@ json JSONSerializer::serialize_node(const Node& n)
node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
break;
}
case OP_TYPEID::Asin: { break;
case OP_TYPEID::Asin:
case OP_TYPEID::Asin_v1: { break;
}
case OP_TYPEID::Atan: { break;
case OP_TYPEID::Atan:
case OP_TYPEID::Atan_v1: { break;
}
case OP_TYPEID::Atan2:
{
......@@ -3054,6 +3158,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::AvgPool:
case OP_TYPEID::AvgPool_v1:
{
if (op_version == 0)
{
......@@ -3084,6 +3189,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::AvgPoolBackprop:
case OP_TYPEID::AvgPoolBackprop_v1:
{
if (op_version == 0)
{
......@@ -3124,6 +3230,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::BatchNormInference:
case OP_TYPEID::BatchNormInference_v1:
{
auto tmp = static_cast<const op::BatchNormInference*>(&n);
node["eps"] = tmp->get_eps_value();
......@@ -3135,7 +3242,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["eps"] = tmp->get_eps_value();
break;
}
case OP_TYPEID::BinaryConvolution:
case OP_TYPEID::BinaryConvolution_v1:
{
auto tmp = static_cast<const op::v1::BinaryConvolution*>(&n);
node["strides"] = tmp->get_strides();
......@@ -3148,6 +3255,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Broadcast:
case OP_TYPEID::Broadcast_v1:
{
if (op_version == 0)
{
......@@ -3173,9 +3281,10 @@ json JSONSerializer::serialize_node(const Node& n)
node["initial_axes"] = serialize_axis_set(tmp->get_initial_broadcast_axes());
break;
}
case OP_TYPEID::Ceiling: { break;
case OP_TYPEID::Ceiling:
case OP_TYPEID::Ceiling_v1: { break;
}
case OP_TYPEID::Clamp:
case OP_TYPEID::Clamp_v1:
{
auto tmp = static_cast<const op::Clamp*>(&n);
node["min"] = tmp->get_min();
......@@ -3183,12 +3292,14 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Concat:
case OP_TYPEID::Concat_v1:
{
auto tmp = static_cast<const op::Concat*>(&n);
node["axis"] = tmp->get_concatenation_axis();
break;
}
case OP_TYPEID::Constant:
case OP_TYPEID::Constant_v1:
{
auto tmp = static_cast<const op::Constant*>(&n);
if (tmp->are_all_data_elements_bitwise_identical() && shape_size(tmp->get_shape()) > 0)
......@@ -3206,12 +3317,14 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Convert:
case OP_TYPEID::Convert_v1:
{
auto tmp = static_cast<const op::Convert*>(&n);
node["target_type"] = write_element_type(tmp->get_convert_element_type());
break;
}
case OP_TYPEID::Convolution:
case OP_TYPEID::Convolution_v1:
{
if (op_version == 0)
{
......@@ -3235,6 +3348,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::ConvolutionBackpropData:
case OP_TYPEID::ConvolutionBackpropData_v1:
{
if (op_version == 0)
{
......@@ -3258,6 +3372,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::ConvolutionBackpropFilters:
case OP_TYPEID::ConvolutionBackpropFilters_v1:
{
if (op_version == 0)
{
......@@ -3312,9 +3427,11 @@ json JSONSerializer::serialize_node(const Node& n)
node["data_dilation_strides_forward"] = tmp->get_data_dilation_strides_forward();
break;
}
case OP_TYPEID::Cos: { break;
case OP_TYPEID::Cos:
case OP_TYPEID::Cos_v1: { break;
}
case OP_TYPEID::Cosh: { break;
case OP_TYPEID::Cosh:
case OP_TYPEID::Cosh_v1: { break;
}
case OP_TYPEID::CrossEntropy:
{
......@@ -3337,7 +3454,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["axes"] = serialize_axis_set(tmp->get_axes());
break;
}
case OP_TYPEID::DepthToSpace:
case OP_TYPEID::DepthToSpace_v1:
{
auto tmp = static_cast<const op::DepthToSpace*>(&n);
node["type"] = write_element_type(tmp->get_element_type());
......@@ -3346,6 +3463,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Divide:
case OP_TYPEID::Divide_v1:
{
const op::util::BinaryElementwiseArithmetic* bea_node = nullptr;
if (op_version == 0)
......@@ -3387,6 +3505,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::DynReshape:
case OP_TYPEID::Reshape_v1:
{
if (op_version == 0)
{
......@@ -3410,7 +3529,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["ellipsis_mask"] = tmp->get_ellipsis_mask();
break;
}
case OP_TYPEID::Elu:
case OP_TYPEID::Elu_v1:
{
auto tmp = static_cast<const op::Elu*>(&n);
node["alpha"] = tmp->get_alpha();
......@@ -3419,6 +3538,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::EmbeddingLookup: { break;
}
case OP_TYPEID::Equal:
case OP_TYPEID::Equal_v1:
{
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
......@@ -3435,21 +3555,24 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::Erf: { break;
case OP_TYPEID::Erf:
case OP_TYPEID::Erf_v1: { break;
}
case OP_TYPEID::Exp: { break;
case OP_TYPEID::Exp:
case OP_TYPEID::Exp_v1: { break;
}
case OP_TYPEID::FakeQuantize:
case OP_TYPEID::FakeQuantize_v1:
{
auto tmp = static_cast<const op::FakeQuantize*>(&n);
node["levels"] = tmp->get_levels();
break;
}
case OP_TYPEID::Floor: { break;
case OP_TYPEID::Floor:
case OP_TYPEID::Floor_v1: { break;
}
case OP_TYPEID::FloorMod:
case OP_TYPEID::FloorMod_v1:
{
auto tmp = static_cast<const op::FloorMod*>(&n);
auto tmp = static_cast<const op::v1::FloorMod*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
......@@ -3457,6 +3580,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Gather:
case OP_TYPEID::Gather_v1:
{
if (op_version == 0)
{
......@@ -3487,6 +3611,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::GenerateMask:
case OP_TYPEID::GenerateMask_v1:
{
auto tmp = static_cast<const op::GenerateMask*>(&n);
node["type"] = write_element_type(tmp->get_element_type());
......@@ -3501,6 +3626,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Greater:
case OP_TYPEID::Greater_v1:
{
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
......@@ -3518,6 +3644,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::GreaterEq:
case OP_TYPEID::GreaterEq_v1:
{
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
......@@ -3551,7 +3678,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["linear_before_reset"] = tmp->get_linear_before_reset();
break;
}
case OP_TYPEID::GroupConvolution:
case OP_TYPEID::GroupConvolution_v1:
{
auto tmp = static_cast<const op::GroupConvolution*>(&n);
node["window_movement_strides"] = tmp->get_window_movement_strides();
......@@ -3576,13 +3703,15 @@ json JSONSerializer::serialize_node(const Node& n)
node["output_shape"] = tmp->get_output_shape();
break;
}
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::HardSigmoid_v1:
{
auto tmp = static_cast<const op::HardSigmoid*>(&n);
node["alpha"] = tmp->get_alpha();
node["beta"] = tmp->get_beta();
break;
}
case OP_TYPEID::Interpolate_v1: { break;
}
case OP_TYPEID::LayerNorm:
{
auto tmp = static_cast<const op::LayerNorm*>(&n);
......@@ -3602,6 +3731,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Less:
case OP_TYPEID::Less_v1:
{
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
......@@ -3627,7 +3757,7 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::LessEqual:
case OP_TYPEID::LessEqual_v1:
{
auto tmp = static_cast<const op::v1::LessEqual*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
......@@ -3636,9 +3766,10 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::Log: { break;
case OP_TYPEID::Log:
case OP_TYPEID::Log_v1: { break;
}
case OP_TYPEID::LogicalAnd:
case OP_TYPEID::LogicalAnd_v1:
{
auto tmp = static_cast<const op::v1::LogicalAnd*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
......@@ -3647,9 +3778,9 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::LogicalNot: { break;
case OP_TYPEID::LogicalNot_v1: { break;
}
case OP_TYPEID::LogicalOr:
case OP_TYPEID::LogicalOr_v1:
{
auto tmp = static_cast<const op::v1::LogicalOr*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
......@@ -3658,7 +3789,7 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::LogicalXor:
case OP_TYPEID::LogicalXor_v1:
{
auto tmp = static_cast<const op::v1::LogicalXor*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
......@@ -3673,7 +3804,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["axis"] = tmp->get_axis();
break;
}
case OP_TYPEID::LRN:
case OP_TYPEID::LRN_v1:
{
auto tmp = static_cast<const op::LRN*>(&n);
node["alpha"] = tmp->get_alpha();
......@@ -3682,7 +3813,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["nsize"] = tmp->get_nsize();
break;
}
case OP_TYPEID::LSTMCell:
case OP_TYPEID::LSTMCell_v1:
{
auto tmp = static_cast<const op::LSTMCell*>(&n);
node["hidden_size"] = tmp->get_hidden_size();
......@@ -3694,7 +3825,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::LSTMSequence:
case OP_TYPEID::LSTMSequence_v1:
{
auto tmp = dynamic_cast<const op::LSTMSequence*>(&n);
node["direction"] = tmp->get_direction();
......@@ -3707,7 +3838,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::MatMul:
case OP_TYPEID::MatMul_v1:
{
auto tmp = static_cast<const op::MatMul*>(&n);
node["transpose_a"] = tmp->get_transpose_a();
......@@ -3721,6 +3852,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::MaxPool:
case OP_TYPEID::MaxPool_v1:
{
if (op_version == 0)
{
......@@ -3744,6 +3876,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::MaxPoolBackprop:
case OP_TYPEID::MaxPoolBackprop_v1:
{
if (op_version == 0)
{
......@@ -3764,6 +3897,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Maximum:
case OP_TYPEID::Maximum_v1:
{
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
......@@ -3786,7 +3920,11 @@ json JSONSerializer::serialize_node(const Node& n)
node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
break;
}
case OP_TYPEID::ReduceMin_v1:
case OP_TYPEID::ReduceMax_v1: { break;
}
case OP_TYPEID::Minimum:
case OP_TYPEID::Minimum_v1:
{
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
......@@ -3804,6 +3942,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Multiply:
case OP_TYPEID::Multiply_v1:
{
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
......@@ -3828,9 +3967,10 @@ json JSONSerializer::serialize_node(const Node& n)
node["eps"] = tmp->get_eps();
break;
}
case OP_TYPEID::Negative: { break;
case OP_TYPEID::Negative:
case OP_TYPEID::Negative_v1: { break;
}
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::NormalizeL2_v1:
{
auto tmp = static_cast<const op::NormalizeL2*>(&n);
node["eps"] = tmp->get_eps();
......@@ -3838,6 +3978,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::NotEqual:
case OP_TYPEID::NotEqual_v1:
{
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
......@@ -3857,6 +3998,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Not: { break;
}
case OP_TYPEID::OneHot:
case OP_TYPEID::OneHot_v1:
{
auto tmp = static_cast<const op::OneHot*>(&n);
node["shape"] = write_partial_shape(tmp->get_output_partial_shape(0));
......@@ -3873,6 +4015,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Pad:
case OP_TYPEID::Pad_v1:
{
if (op_version == 0)
{
......@@ -3889,6 +4032,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Parameter:
case OP_TYPEID::Parameter_v1:
{
auto tmp = static_cast<const op::Parameter*>(&n);
node["shape"] = write_partial_shape(tmp->get_output_partial_shape(0));
......@@ -3930,9 +4074,10 @@ json JSONSerializer::serialize_node(const Node& n)
node["output_shapes"] = std::move(outputs_js);
break;
}
case OP_TYPEID::PRelu: { break;
case OP_TYPEID::PRelu_v1: { break;
}
case OP_TYPEID::Product:
case OP_TYPEID::ReduceProd_v1:
{
if (op_version == 0)
{
......@@ -3946,6 +4091,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Power:
case OP_TYPEID::Power_v1:
{
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
......@@ -4016,11 +4162,13 @@ json JSONSerializer::serialize_node(const Node& n)
node["fixed_seed"] = tmp->get_fixed_seed();
break;
}
case OP_TYPEID::Range: { break;
case OP_TYPEID::Range:
case OP_TYPEID::Range_v1: { break;
}
case OP_TYPEID::Reciprocal: { break;
}
case OP_TYPEID::Relu: { break;
case OP_TYPEID::Relu:
case OP_TYPEID::Relu_v1: { break;
}
case OP_TYPEID::ReluBackprop: { break;
}
......@@ -4040,12 +4188,14 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Result:
case OP_TYPEID::Result_v1:
{
auto tmp = static_cast<const op::Result*>(&n);
node["needs_default_layout"] = tmp->needs_default_layout();
break;
}
case OP_TYPEID::Reverse:
case OP_TYPEID::Reverse_v1:
{
if (op_version == 0)
{
......@@ -4062,13 +4212,14 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::ReverseSequence:
case OP_TYPEID::ReverseSequence_v1:
{
auto tmp = static_cast<const op::ReverseSequence*>(&n);
node["batch_axis"] = tmp->get_batch_axis();
node["sequence_axis"] = tmp->get_sequence_axis();
break;
}
case OP_TYPEID::RNNCell:
case OP_TYPEID::RNNCell_v1:
{
auto tmp = static_cast<const op::RNNCell*>(&n);
node["hidden_size"] = tmp->get_hidden_size();
......@@ -4102,26 +4253,32 @@ json JSONSerializer::serialize_node(const Node& n)
node["dest_id"] = tmp->get_dest_id();
break;
}
case OP_TYPEID::ShapeOf: { break;
case OP_TYPEID::ShapeOf:
case OP_TYPEID::ShapeOf_v1: { break;
}
case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::ShuffleChannels_v1:
{
const auto tmp = static_cast<const op::ShuffleChannels*>(&n);
node["axis"] = tmp->get_axis();
node["groups"] = tmp->get_groups();
break;
}
case OP_TYPEID::Sigmoid: { break;
case OP_TYPEID::Sigmoid:
case OP_TYPEID::Sigmoid_v1: { break;
}
case OP_TYPEID::SigmoidBackprop: { break;
}
case OP_TYPEID::Sign: { break;
case OP_TYPEID::Sign:
case OP_TYPEID::Sign_v1: { break;
}
case OP_TYPEID::Sin: { break;
case OP_TYPEID::Sin:
case OP_TYPEID::Sin_v1: { break;
}
case OP_TYPEID::Sinh: { break;
case OP_TYPEID::Sinh:
case OP_TYPEID::Sinh_v1: { break;
}
case OP_TYPEID::Slice:
case OP_TYPEID::StridedSlice_v1:
{
if (op_version == 0)
{
......@@ -4141,7 +4298,7 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::SpaceToDepth_v1:
{
auto tmp = static_cast<const op::SpaceToDepth*>(&n);
node["type"] = write_element_type(tmp->get_element_type());
......@@ -4149,16 +4306,17 @@ json JSONSerializer::serialize_node(const Node& n)
node["block_size"] = tmp->get_block_size();
break;
}
case OP_TYPEID::Split:
case OP_TYPEID::Split_v1:
{
auto tmp = static_cast<const op::Split*>(&n);
node["axis"] = tmp->get_axis();
node["splits"] = tmp->get_splits();
break;
}
case OP_TYPEID::Sqrt: { break;
case OP_TYPEID::Sqrt:
case OP_TYPEID::Sqrt_v1: { break;
}
case OP_TYPEID::SquaredDifference:
case OP_TYPEID::SquaredDifference_v1:
{
auto tmp = static_cast<const op::SquaredDifference*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
......@@ -4167,11 +4325,12 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::Squeeze: { break;
case OP_TYPEID::Squeeze_v1: { break;
}
case OP_TYPEID::StopGradient: { break;
}
case OP_TYPEID::Subtract:
case OP_TYPEID::Subtract_v1:
{
auto tmp = static_cast<const op::Subtract*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
......@@ -4181,6 +4340,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Sum:
case OP_TYPEID::ReduceSum_v1:
{
if (op_version == 0)
{
......@@ -4194,6 +4354,7 @@ json JSONSerializer::serialize_node(const Node& n)
break;
}
case OP_TYPEID::Softmax:
case OP_TYPEID::Softmax_v1:
{
if (op_version == 0)
{
......@@ -4220,11 +4381,13 @@ json JSONSerializer::serialize_node(const Node& n)
node["ignore_index"] = tmp->get_ignore_index();
break;
}
case OP_TYPEID::Tan: { break;
case OP_TYPEID::Tan:
case OP_TYPEID::Tan_v1: { break;
}
case OP_TYPEID::Tanh: { break;
case OP_TYPEID::Tanh:
case OP_TYPEID::Tanh_v1: { break;
}
case OP_TYPEID::TensorIterator:
case OP_TYPEID::TensorIterator_v1:
{
auto tmp = static_cast<const op::TensorIterator*>(&n);
json body = json::object();
......@@ -4266,9 +4429,11 @@ json JSONSerializer::serialize_node(const Node& n)
node["output_descriptions"] = outs;
break;
}
case OP_TYPEID::Tile: { break;
case OP_TYPEID::Tile:
case OP_TYPEID::Tile_v1: { break;
}
case OP_TYPEID::TopK:
case OP_TYPEID::TopK_v1:
{
if (op_version == 0)
{
......@@ -4286,9 +4451,10 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::Transpose: { break;
case OP_TYPEID::Transpose:
case OP_TYPEID::Transpose_v1: { break;
}
case OP_TYPEID::Unsqueeze: { break;
case OP_TYPEID::Unsqueeze_v1: { break;
}
case OP_TYPEID::Xor:
{
......@@ -4299,7 +4465,7 @@ json JSONSerializer::serialize_node(const Node& n)
}
break;
}
case OP_TYPEID::VariadicSplit: { break;
case OP_TYPEID::VariadicSplit_v1: { break;
}
case OP_TYPEID::UnknownOp: { break;
}
......
......@@ -17,6 +17,7 @@
#pragma once
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
......@@ -38,12 +39,31 @@ namespace ngraph
const char* name;
uint64_t version;
bool is_castable(const DiscreteTypeInfo& target_type) const { return this == &target_type; }
bool is_castable(const DiscreteTypeInfo& target_type) const { return *this == target_type; }
// For use as a key
bool operator<(const DiscreteTypeInfo& b) const
{
return version < b.version ||
(version == b.version && std::string(name) < std::string(b.name));
return version < b.version || (version == b.version && strcmp(name, b.name) < 0);
}
bool operator<=(const DiscreteTypeInfo& b) const
{
return version < b.version || (version == b.version && strcmp(name, b.name) <= 0);
}
bool operator>(const DiscreteTypeInfo& b) const
{
return version > b.version || (version == b.version && strcmp(name, b.name) > 0);
}
bool operator>=(const DiscreteTypeInfo& b) const
{
return version > b.version || (version == b.version && strcmp(name, b.name) >= 0);
}
bool operator==(const DiscreteTypeInfo& b) const
{
return version == b.version && strcmp(name, b.name) == 0;
}
bool operator!=(const DiscreteTypeInfo& b) const
{
return version != b.version || strcmp(name, b.name) != 0;
}
};
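With the full operator set defined by value (version first, then strcmp on the name), DiscreteTypeInfo can serve as an ordered key and no longer depends on pointer identity or std::string temporaries. A small sanity sketch, assuming the struct lives in ngraph/type.hpp and still aggregate-initializes from {name, version} as shown above:
#include <cassert>
#include <map>
#include "ngraph/type.hpp"
int main()
{
    ngraph::DiscreteTypeInfo add_v0{"Add", 0};
    ngraph::DiscreteTypeInfo add_v1{"Add", 1};
    ngraph::DiscreteTypeInfo add_v0_copy{"Add", 0}; // different string literal, same value
    assert(add_v0 < add_v1);       // version is compared before the name
    assert(add_v1 > add_v0);       // consistent with operator<
    assert(add_v0 == add_v0_copy); // strcmp, not pointer comparison
    assert(add_v0 != add_v1);
    std::map<ngraph::DiscreteTypeInfo, int> ids{{add_v0, 0}, {add_v1, 1}};
    assert(ids.size() == 2);       // usable as a map key, as in get_typeid
    return 0;
}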
......
......@@ -35,11 +35,10 @@ void test_type_prop_opset0_downgrade_pass(const element::Type& output_type,
pass_manager.run_passes(f);
auto v0_result = f->get_results().at(0);
auto node = v0_result->input(0).get_source_output().get_node_shared_ptr();
auto v0_node = static_pointer_cast<OpV0>(node);
auto node = v0_result->input_value(0).get_node_shared_ptr();
auto v0_node = as_type_ptr<OpV0>(node);
EXPECT_EQ(v0_node->description(), (node_name.empty() ? v1_node->description() : node_name));
EXPECT_EQ(v0_node->get_version(), 0);
EXPECT_TRUE(v0_node);
EXPECT_EQ(v0_node->get_autob(), np_auto_b);
EXPECT_EQ(v0_node->output(0).get_element_type(), output_type);
EXPECT_EQ(v0_node->output(0).get_shape(), (Shape{1, 3, 2}));
......
......@@ -22,20 +22,18 @@ TEST(opset_transform, opset1_broadcast_upgrade_pass)
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto bcast_v1 = static_pointer_cast<op::v1::Broadcast>(
auto bcast_v1 = as_type_ptr<op::v1::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(bcast_v1->description(), "Broadcast");
EXPECT_EQ(bcast_v1->get_version(), 1);
EXPECT_TRUE(bcast_v1);
EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec());
EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
EXPECT_EQ(bcast_v1->input_value(1).get_node()->description(), "Constant");
EXPECT_EQ(bcast_v1->input_value(2).get_node()->description(), "Constant");
EXPECT_TRUE(bcast_v1->input_value(1).get_node()->is_constant());
EXPECT_TRUE(bcast_v1->input_value(2).get_node()->is_constant());
EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr())
->get_shape_val(),
(Shape{3, 5, 4, 6}));
EXPECT_EQ(static_pointer_cast<op::Constant>(bcast_v1->input_value(2).get_node_shared_ptr())
EXPECT_EQ(as_type_ptr<op::Constant>(bcast_v1->input_value(2).get_node_shared_ptr())
->get_axis_set_val(),
(AxisSet{1, 3}));
}
......@@ -53,11 +51,10 @@ TEST(opset_transform, opset1_broadcast_downgrade_pass)
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto bcast_v0 = static_pointer_cast<op::v0::Broadcast>(
auto bcast_v0 = as_type_ptr<op::v0::Broadcast>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(bcast_v0->description(), "Broadcast");
EXPECT_EQ(bcast_v0->get_version(), 0);
EXPECT_TRUE(bcast_v0);
EXPECT_EQ(bcast_v0->get_broadcast_shape(), (Shape{3, 1, 4, 2, 3}));
EXPECT_EQ(bcast_v0->get_broadcast_axes(), (AxisSet{0, 2}));
}
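The pattern change running through these tests: static_pointer_cast plus description()/get_version() checks is replaced by as_type_ptr plus EXPECT_TRUE, because as_type_ptr yields nullptr when the node is not actually the expected type, whereas static_pointer_cast would happily return a mistyped pointer. A sketch of the idiom; check_pass_output is a hypothetical helper, not part of the test suite:
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace ngraph;
// Hypothetical helper: verify that a pass really produced a v1 Broadcast.
bool check_pass_output(const std::shared_ptr<Node>& node)
{
    // nullptr unless 'node' is an op::v1::Broadcast, so a plain truthiness
    // check replaces the old description()/get_version() assertions.
    auto bcast_v1 = as_type_ptr<op::v1::Broadcast>(node);
    return bcast_v1 != nullptr;
}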
......@@ -33,10 +33,9 @@ TEST(opset_transform, opset1_convolution_upgrade_pass)
auto convolution_s1_result = f->get_results().at(0);
auto node = convolution_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto convolution_v1_node = static_pointer_cast<op::v1::Convolution>(node);
auto convolution_v1_node = as_type_ptr<op::v1::Convolution>(node);
EXPECT_EQ(convolution_v1_node->description(), "Convolution");
EXPECT_EQ(convolution_v1_node->get_version(), 1);
EXPECT_TRUE(convolution_v1_node);
EXPECT_EQ(convolution_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(convolution_v1_node->get_pads_end(), pads_end);
......@@ -66,10 +65,9 @@ TEST(opset_transform, opset1_convolution_downgrade_pass)
auto conv_s0_result = f->get_results().at(0);
auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto conv_v0_node = static_pointer_cast<op::v0::Convolution>(node);
auto conv_v0_node = as_type_ptr<op::v0::Convolution>(node);
EXPECT_EQ(conv_v0_node->description(), "Convolution");
EXPECT_EQ(conv_v0_node->get_version(), 0);
EXPECT_TRUE(conv_v0_node);
EXPECT_EQ(conv_v0_node->get_window_movement_strides(), strides);
EXPECT_EQ(conv_v0_node->get_window_dilation_strides(), dilations);
EXPECT_EQ(conv_v0_node->get_padding_below(), pads_begin);
......@@ -99,10 +97,9 @@ TEST(opset_transform, opset1_convolution_backprop_data_downgrade_pass)
auto conv_s0_result = f->get_results().at(0);
auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto conv_v0_node = static_pointer_cast<op::v0::ConvolutionBackpropData>(node);
auto conv_v0_node = as_type_ptr<op::v0::ConvolutionBackpropData>(node);
EXPECT_EQ(conv_v0_node->description(), "ConvolutionBackpropData");
EXPECT_EQ(conv_v0_node->get_version(), 0);
EXPECT_TRUE(conv_v0_node);
EXPECT_EQ(conv_v0_node->get_data_batch_shape(), (Shape{64, 3, 100}));
EXPECT_EQ(conv_v0_node->get_window_movement_strides_forward(), strides);
EXPECT_EQ(conv_v0_node->get_window_dilation_strides_forward(), dilations);
......@@ -131,10 +128,9 @@ TEST(opset_transform, opset1_convolution_backprop_filters_downgrade_pass)
auto conv_s0_result = f->get_results().at(0);
auto node = conv_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto conv_v0_node = static_pointer_cast<op::v0::ConvolutionBackpropFilters>(node);
auto conv_v0_node = as_type_ptr<op::v0::ConvolutionBackpropFilters>(node);
EXPECT_EQ(conv_v0_node->description(), "ConvolutionBackpropFilters");
EXPECT_EQ(conv_v0_node->get_version(), 0);
EXPECT_TRUE(conv_v0_node);
EXPECT_EQ(conv_v0_node->get_filters_shape(), (Shape{128, 3, 10}));
EXPECT_EQ(conv_v0_node->get_window_movement_strides_forward(), strides);
EXPECT_EQ(conv_v0_node->get_window_dilation_strides_forward(), dilations);
......
......@@ -39,12 +39,8 @@ TEST(opset_transform, opset1_dyn_reshape_upgrade_pass)
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape_v1 = as_type_ptr<op::v1::Reshape>(pass_replacement_node);
EXPECT_EQ(reshape_v1->description(), "DynReshape");
EXPECT_EQ(reshape_v1->get_version(), 1);
const auto pass_replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
EXPECT_TRUE(is_type<op::v1::Reshape>(pass_replacement_node));
}
TEST(opset_transform, opset1_reshape_downgrade_pass)
......@@ -60,11 +56,8 @@ TEST(opset_transform, opset1_reshape_downgrade_pass)
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto pass_replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
const auto reshape_v1 = as_type_ptr<op::v0::DynReshape>(pass_replacement_node);
EXPECT_EQ(reshape_v1->description(), "DynReshape");
EXPECT_EQ(reshape_v1->get_version(), 0);
EXPECT_TRUE(reshape_v1);
EXPECT_EQ(reshape_v1->get_zero_flag(), true);
}
......@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_gather_upgrade_pass)
pass_manager.run_passes(f);
auto gather_s1_result = f->get_results().at(0);
auto node = gather_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto gather_v1_node = static_pointer_cast<op::v1::Gather>(node);
EXPECT_EQ(gather_v1_node->description(), "Gather");
EXPECT_EQ(gather_v1_node->get_version(), 1);
auto gather_v1_node = as_type_ptr<op::v1::Gather>(
gather_s1_result->input(0).get_source_output().get_node_shared_ptr());
EXPECT_TRUE(gather_v1_node);
EXPECT_EQ(gather_v1_node->get_axis(), axis);
}
......@@ -26,10 +26,8 @@ TEST(opset_transform, opset1_generate_mask_downgrade_pass)
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto generate_mask_v0 = static_pointer_cast<op::v0::GenerateMask>(
auto generate_mask_v0 = as_type_ptr<op::v0::GenerateMask>(
f->get_results().at(0)->input_value(0).get_node_shared_ptr());
EXPECT_EQ(generate_mask_v0->description(), "GenerateMask");
EXPECT_EQ(generate_mask_v0->get_version(), 0);
EXPECT_TRUE(generate_mask_v0);
EXPECT_EQ(generate_mask_v0->get_mask_shape(), (Shape{1, 128}));
}
......@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_and_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto and_v1 = static_pointer_cast<op::v1::LogicalAnd>(pass_replacement_node);
EXPECT_EQ(and_v1->description(), "LogicalAnd");
EXPECT_EQ(and_v1->get_version(), 1);
const auto and_v1 = as_type_ptr<op::v1::LogicalAnd>(pass_replacement_node);
EXPECT_TRUE(and_v1);
const auto values_out_element_type = and_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_and_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto and_v0 = static_pointer_cast<op::v0::And>(pass_replacement_node);
EXPECT_EQ(and_v0->description(), "And");
EXPECT_EQ(and_v0->get_version(), 0);
const auto and_v0 = as_type_ptr<op::v0::And>(pass_replacement_node);
EXPECT_TRUE(and_v0);
const auto values_out_element_type = and_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......
......@@ -39,10 +39,8 @@ TEST(opset_transform, opset1_logical_not_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto not_v1 = static_pointer_cast<op::v1::LogicalNot>(pass_replacement_node);
EXPECT_EQ(not_v1->description(), "LogicalNot");
EXPECT_EQ(not_v1->get_version(), 1);
const auto not_v1 = as_type_ptr<op::v1::LogicalNot>(pass_replacement_node);
EXPECT_TRUE(not_v1);
const auto values_out_element_type = not_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......@@ -61,10 +59,8 @@ TEST(opset_transform, opset1_logical_not_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto not_v0 = static_pointer_cast<op::v0::Not>(pass_replacement_node);
EXPECT_EQ(not_v0->description(), "Not");
EXPECT_EQ(not_v0->get_version(), 0);
const auto not_v0 = as_type_ptr<op::v0::Not>(pass_replacement_node);
EXPECT_TRUE(not_v0);
const auto values_out_element_type = not_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......
......@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_or_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto or_v1 = static_pointer_cast<op::v1::LogicalOr>(pass_replacement_node);
EXPECT_EQ(or_v1->description(), "LogicalOr");
EXPECT_EQ(or_v1->get_version(), 1);
const auto or_v1 = as_type_ptr<op::v1::LogicalOr>(pass_replacement_node);
EXPECT_TRUE(or_v1);
const auto values_out_element_type = or_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_or_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto or_v0 = static_pointer_cast<op::v0::Or>(pass_replacement_node);
EXPECT_EQ(or_v0->description(), "Or");
EXPECT_EQ(or_v0->get_version(), 0);
const auto or_v0 = as_type_ptr<op::v0::Or>(pass_replacement_node);
EXPECT_TRUE(or_v0);
const auto values_out_element_type = or_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......
......@@ -40,10 +40,8 @@ TEST(opset_transform, opset1_logical_xor_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto xor_v1 = static_pointer_cast<op::v1::LogicalXor>(pass_replacement_node);
EXPECT_EQ(xor_v1->description(), "LogicalXor");
EXPECT_EQ(xor_v1->get_version(), 1);
const auto xor_v1 = as_type_ptr<op::v1::LogicalXor>(pass_replacement_node);
EXPECT_TRUE(xor_v1);
const auto values_out_element_type = xor_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......@@ -63,10 +61,8 @@ TEST(opset_transform, opset1_logical_xor_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto xor_v0 = static_pointer_cast<op::v0::Xor>(pass_replacement_node);
EXPECT_EQ(xor_v0->description(), "Xor");
EXPECT_EQ(xor_v0->get_version(), 0);
const auto xor_v0 = as_type_ptr<op::v0::Xor>(pass_replacement_node);
EXPECT_TRUE(xor_v0);
const auto values_out_element_type = xor_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
......
......@@ -29,10 +29,8 @@ TEST(opset_transform, opset1_pad_upgrade_pass)
auto pad_s1_result = f->get_results().at(0);
auto node = pad_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto pad_v1_node = static_pointer_cast<op::v1::Pad>(node);
EXPECT_EQ(pad_v1_node->description(), "Pad");
EXPECT_EQ(pad_v1_node->get_version(), 1);
auto pad_v1_node = as_type_ptr<op::v1::Pad>(node);
EXPECT_TRUE(pad_v1_node);
EXPECT_EQ(pad_v1_node->get_pad_mode(), pad_mode);
EXPECT_EQ(pad_v1_node->get_pads_begin(), padding_below);
......@@ -58,10 +56,8 @@ TEST(opset_transform, opset1_pad_downgrade_pass)
auto pad_s0_result = f->get_results().at(0);
auto node = pad_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto pad_v0_node = static_pointer_cast<op::v0::Pad>(node);
EXPECT_EQ(pad_v0_node->description(), "Pad");
EXPECT_EQ(pad_v0_node->get_version(), 0);
auto pad_v0_node = as_type_ptr<op::v0::Pad>(node);
EXPECT_TRUE(pad_v0_node);
EXPECT_EQ(pad_v0_node->get_pad_mode(), pad_mode);
EXPECT_EQ(pad_v0_node->get_padding_below(), CoordinateDiff({1, 2}));
......
......@@ -33,10 +33,8 @@ TEST(opset_transform, opset1_avgpool_upgrade_pass)
auto avgpool_s1_result = f->get_results().at(0);
auto node = avgpool_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_v1_node = static_pointer_cast<op::v1::AvgPool>(node);
EXPECT_EQ(avg_pool_v1_node->description(), "AvgPool");
EXPECT_EQ(avg_pool_v1_node->get_version(), 1);
auto avg_pool_v1_node = as_type_ptr<op::v1::AvgPool>(node);
EXPECT_TRUE(avg_pool_v1_node);
EXPECT_EQ(avg_pool_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(avg_pool_v1_node->get_pads_end(), pads_end);
......@@ -68,10 +66,8 @@ TEST(opset_transform, opset1_maxpool_upgrade_pass)
auto maxpool_s1_result = f->get_results().at(0);
auto node = maxpool_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_v1_node = static_pointer_cast<op::v1::MaxPool>(node);
EXPECT_EQ(max_pool_v1_node->description(), "MaxPool");
EXPECT_EQ(max_pool_v1_node->get_version(), 1);
auto max_pool_v1_node = as_type_ptr<op::v1::MaxPool>(node);
EXPECT_TRUE(max_pool_v1_node);
EXPECT_EQ(max_pool_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(max_pool_v1_node->get_pads_end(), pads_end);
......@@ -109,10 +105,8 @@ TEST(opset_transform, opset1_avgpool_downgrade_pass)
auto avgpool_s0_result = f->get_results().at(0);
auto node = avgpool_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_v0_node = static_pointer_cast<op::v0::AvgPool>(node);
EXPECT_EQ(avg_pool_v0_node->description(), "AvgPool");
EXPECT_EQ(avg_pool_v0_node->get_version(), 0);
auto avg_pool_v0_node = as_type_ptr<op::v0::AvgPool>(node);
EXPECT_TRUE(avg_pool_v0_node);
EXPECT_EQ(avg_pool_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(avg_pool_v0_node->get_padding_above(), padding_above);
......@@ -149,10 +143,8 @@ TEST(opset_transform, opset1_maxpool_downgrade_pass)
auto maxpool_s0_result = f->get_results().at(0);
auto node = maxpool_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_v0_node = static_pointer_cast<op::v0::MaxPool>(node);
EXPECT_EQ(max_pool_v0_node->description(), "MaxPool");
EXPECT_EQ(max_pool_v0_node->get_version(), 0);
auto max_pool_v0_node = as_type_ptr<op::v0::MaxPool>(node);
EXPECT_TRUE(max_pool_v0_node);
EXPECT_EQ(max_pool_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(max_pool_v0_node->get_padding_above(), padding_above);
......@@ -189,10 +181,8 @@ TEST(opset_transform, opset1_avgpool_backprop_downgrade_pass)
auto avgpool_backprop_s0_result = f->get_results().at(0);
auto node = avgpool_backprop_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_backprop_v0_node = static_pointer_cast<op::v0::AvgPoolBackprop>(node);
EXPECT_EQ(avg_pool_backprop_v0_node->description(), "AvgPoolBackprop");
EXPECT_EQ(avg_pool_backprop_v0_node->get_version(), 0);
auto avg_pool_backprop_v0_node = as_type_ptr<op::v0::AvgPoolBackprop>(node);
EXPECT_TRUE(avg_pool_backprop_v0_node);
EXPECT_EQ(avg_pool_backprop_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(avg_pool_backprop_v0_node->get_padding_above(), padding_above);
......@@ -229,10 +219,8 @@ TEST(opset_transform, opset1_maxpool_backprop_downgrade_pass)
auto max_pool_backprop_s0_result = f->get_results().at(0);
auto node = max_pool_backprop_s0_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_backprop_v0_node = static_pointer_cast<op::v0::MaxPoolBackprop>(node);
EXPECT_EQ(max_pool_backprop_v0_node->description(), "MaxPoolBackprop");
EXPECT_EQ(max_pool_backprop_v0_node->get_version(), 0);
auto max_pool_backprop_v0_node = as_type_ptr<op::v0::MaxPoolBackprop>(node);
EXPECT_TRUE(max_pool_backprop_v0_node);
EXPECT_EQ(max_pool_backprop_v0_node->get_padding_below(), padding_below);
EXPECT_EQ(max_pool_backprop_v0_node->get_padding_above(), padding_above);
EXPECT_EQ(max_pool_backprop_v0_node->get_window_movement_strides(), window_movement_strides);
......
......@@ -41,10 +41,8 @@ TEST(opset_transform, opset1_product_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reduce_prod_v1 = static_pointer_cast<op::v1::ReduceProd>(pass_replacement_node);
EXPECT_EQ(reduce_prod_v1->description(), "Product");
EXPECT_EQ(reduce_prod_v1->get_version(), 1);
const auto reduce_prod_v1 = as_type_ptr<op::v1::ReduceProd>(pass_replacement_node);
EXPECT_TRUE(reduce_prod_v1);
EXPECT_EQ(reduce_prod_v1->get_keep_dims(), false);
}
......@@ -63,15 +61,12 @@ TEST(opset_transform, opset0_reduce_prod_downgrade_pass)
const auto reshape_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape = static_pointer_cast<op::Reshape>(reshape_replacement_node);
const auto reshape = as_type_ptr<op::Reshape>(reshape_replacement_node);
EXPECT_TRUE(reshape);
const auto product_replace_node =
reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr();
const auto product_v0 = static_pointer_cast<op::v0::Product>(product_replace_node);
EXPECT_EQ(reshape->description(), "Reshape");
EXPECT_EQ(reshape->get_version(), 0);
EXPECT_EQ(product_v0->description(), "Product");
EXPECT_EQ(product_v0->get_version(), 0);
const auto product_v0 = as_type_ptr<op::v0::Product>(product_replace_node);
EXPECT_TRUE(product_v0);
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_axes_not_constant)
......
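In the reduce-op downgrade hunks (Product above, Sum further down), the result feeds from a Reshape that in turn feeds from the v0 reduction, so the tests now check both nodes with as_type_ptr. A hedged illustration of the shape bookkeeping behind such a pairing, assuming only the usual reduction semantics (reduced axes dropped versus kept as size 1), not the actual pass implementation:

#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

std::vector<size_t> reduced_shape(const std::vector<size_t>& in,
                                  const std::set<size_t>& axes,
                                  bool keep_dims)
{
    std::vector<size_t> out;
    for (size_t i = 0; i < in.size(); ++i)
    {
        if (axes.count(i) == 0)
            out.push_back(in[i]); // untouched dimension
        else if (keep_dims)
            out.push_back(1);     // reduced dimension kept as size 1
        // reduced dimension dropped otherwise
    }
    return out;
}

int main()
{
    // Reducing {2, 3, 4} over axes {0, 2}: prints "3" then "1 3 1", which is
    // why a downgrade that loses keep_dims may need a trailing Reshape.
    for (bool keep : {false, true})
    {
        for (size_t d : reduced_shape({2, 3, 4}, {0, 2}, keep))
            std::cout << d << ' ';
        std::cout << '\n';
    }
    return 0;
}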
......@@ -41,11 +41,9 @@ TEST(opset_transform, opset1_reverse_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v1 = static_pointer_cast<op::v1::Reverse>(pass_replacement_node);
const auto reverse_v1 = as_type_ptr<op::v1::Reverse>(pass_replacement_node);
EXPECT_TRUE(reverse_v1);
EXPECT_EQ(reverse_v1->get_mode(), op::v1::Reverse::Mode::INDEX);
EXPECT_EQ(reverse_v1->description(), "Reverse");
EXPECT_EQ(reverse_v1->get_version(), 1);
const auto& rev_axes_input_shape = reverse_v1->get_input_shape(1);
// should match the number of elements of v0::Reverse reverse_axes attribute
......@@ -69,10 +67,8 @@ TEST(opset_transform, opset0_reverse_downgrade_pass_index_mode)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = static_pointer_cast<op::v0::Reverse>(pass_replacement_node);
EXPECT_EQ(reverse_v0->description(), "Reverse");
EXPECT_EQ(reverse_v0->get_version(), 0);
const auto reverse_v0 = as_type_ptr<op::v0::Reverse>(pass_replacement_node);
EXPECT_TRUE(reverse_v0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({1, 2}));
}
......@@ -93,10 +89,8 @@ TEST(opset_transform, opset0_reverse_downgrade_pass_mask_mode)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = static_pointer_cast<op::v0::Reverse>(pass_replacement_node);
EXPECT_EQ(reverse_v0->description(), "Reverse");
EXPECT_EQ(reverse_v0->get_version(), 0);
const auto reverse_v0 = as_type_ptr<op::v0::Reverse>(pass_replacement_node);
EXPECT_TRUE(reverse_v0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({0, 2}));
}
......
......@@ -45,16 +45,17 @@ TEST(opset_transform, opset1_dyn_slice_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto strided_slice_v1 = as_type_ptr<op::v1::StridedSlice>(pass_replacement_node);
EXPECT_TRUE(strided_slice_v1);
auto begin_const =
as_type_ptr<op::Constant>(strided_slice_v1->input_value(1).get_node_shared_ptr());
EXPECT_TRUE(begin_const);
auto end_const =
as_type_ptr<op::Constant>(strided_slice_v1->input_value(2).get_node_shared_ptr());
EXPECT_TRUE(end_const);
auto strides_const =
as_type_ptr<op::Constant>(strided_slice_v1->input_value(3).get_node_shared_ptr());
EXPECT_TRUE(strides_const);
EXPECT_EQ(strided_slice_v1->description(), "Slice");
EXPECT_EQ(strided_slice_v1->get_version(), 1);
EXPECT_EQ(strided_slice_v1->get_begin_mask(), vector<int64_t>(4, 0));
EXPECT_EQ(strided_slice_v1->get_end_mask(), vector<int64_t>(4, 0));
EXPECT_EQ(begin_const->get_vector<int64_t>(),
......@@ -84,9 +85,7 @@ TEST(opset_transform, opset1_strided_slice_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto slice_v0 = as_type_ptr<op::v0::Slice>(pass_replacement_node);
EXPECT_EQ(slice_v0->description(), "Slice");
EXPECT_EQ(slice_v0->get_version(), 0);
EXPECT_TRUE(slice_v0);
EXPECT_EQ(slice_v0->get_lower_bounds(), Coordinate({1, 2, 0, 2}));
EXPECT_EQ(slice_v0->get_upper_bounds(), Coordinate({5, 4, 5, 6}));
EXPECT_EQ(slice_v0->get_strides(), Strides({1, 1, 1, 1}));
......
......@@ -40,11 +40,9 @@ TEST(opset_transform, opset1_softmax_upgrade_pass_axis)
auto softmax_s1_result = f->get_results().at(0);
auto node = softmax_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto softmax_s1_node = static_pointer_cast<op::v1::Softmax>(node);
auto softmax_s1_node = as_type_ptr<op::v1::Softmax>(node);
EXPECT_TRUE(softmax_s1_node);
EXPECT_EQ(softmax_s1_node->get_axis(), axis);
EXPECT_EQ(softmax_s1_node->description(), "Softmax");
EXPECT_EQ(softmax_s1_node->get_version(), 1);
}
TEST(opset_transform, opset1_softmax_upgrade_pass_axis_exception)
......@@ -75,43 +73,3 @@ TEST(opset_transform, opset1_softmax_upgrade_pass_axis_exception)
FAIL() << "Softmax pass failed for unexpected reason";
}
}
namespace fake_v2
{
class FakeSoftmax : public op::v0::Softmax
{
public:
FakeSoftmax(const Output<Node>& arg, const AxisSet& axes)
: Softmax{arg, axes}
{
}
size_t get_version() const override { return 2; }
};
}
TEST(opset_transform, opset1_softmax_upgrade_pass_incorrect_op_version)
{
const AxisSet axes{2};
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto softmax_s2 = make_shared<fake_v2::FakeSoftmax>(arg, axes);
auto result = make_shared<op::Result>(softmax_s2);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Opset 1 transformation pass failed for";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Op version 1 transformation pass failed for"));
}
catch (...)
{
FAIL() << "Softmax pass failed for unexpected reason";
}
}
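The test deleted above forced the error path by overriding get_version() on a Softmax subclass; with version numbers no longer driving the conversion, a pass has to key on concrete node types instead. A hypothetical, self-contained sketch of type-keyed dispatch (NodeBase/SoftmaxV0/SoftmaxV1 and the upgrade helper are illustrative stand-ins, not the Opset1Upgrade implementation):

#include <iostream>
#include <memory>

struct NodeBase { virtual ~NodeBase() = default; };
struct SoftmaxV0 : NodeBase { int axis() const { return 1; } };
struct SoftmaxV1 : NodeBase {};

std::shared_ptr<NodeBase> upgrade(const std::shared_ptr<NodeBase>& n)
{
    // Type-keyed dispatch: only nodes that really are SoftmaxV0 get rewritten;
    // everything else passes through unchanged.
    if (auto s0 = std::dynamic_pointer_cast<SoftmaxV0>(n))
    {
        std::cout << "upgrading Softmax on axis " << s0->axis() << '\n';
        return std::make_shared<SoftmaxV1>();
    }
    return n;
}

int main()
{
    auto upgraded = upgrade(std::make_shared<SoftmaxV0>());
    std::cout << std::boolalpha
              << (std::dynamic_pointer_cast<SoftmaxV1>(upgraded) != nullptr) << '\n';
    return 0;
}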
......@@ -41,10 +41,8 @@ TEST(opset_transform, opset1_reduce_sum_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reduce_sum_v1 = static_pointer_cast<op::v1::ReduceSum>(pass_replacement_node);
EXPECT_EQ(reduce_sum_v1->description(), "Sum");
EXPECT_EQ(reduce_sum_v1->get_version(), 1);
const auto reduce_sum_v1 = as_type_ptr<op::v1::ReduceSum>(pass_replacement_node);
EXPECT_TRUE(reduce_sum_v1);
EXPECT_EQ(reduce_sum_v1->get_keep_dims(), false);
}
......@@ -63,15 +61,12 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass)
const auto reshape_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape = static_pointer_cast<op::Reshape>(reshape_replacement_node);
const auto reshape = as_type_ptr<op::Reshape>(reshape_replacement_node);
EXPECT_TRUE(reshape);
const auto sum_replace_node =
reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr();
const auto sum_v0 = static_pointer_cast<op::v0::Sum>(sum_replace_node);
EXPECT_EQ(reshape->description(), "Reshape");
EXPECT_EQ(reshape->get_version(), 0);
EXPECT_EQ(sum_v0->description(), "Sum");
EXPECT_EQ(sum_v0->get_version(), 0);
const auto sum_v0 = as_type_ptr<op::v0::Sum>(sum_replace_node);
EXPECT_TRUE(sum_v0);
}
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_not_constant_axes)
......
......@@ -41,11 +41,9 @@ TEST(opset_transform, opset1_topk_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v1 = static_pointer_cast<op::v1::TopK>(pass_replacement_node);
const auto topk_v1 = as_type_ptr<op::v1::TopK>(pass_replacement_node);
EXPECT_TRUE(topk_v1);
EXPECT_EQ(topk_v1->get_axis(), axis);
EXPECT_EQ(topk_v1->description(), "TopK");
EXPECT_EQ(topk_v1->get_version(), 1);
EXPECT_EQ(topk_v1->get_mode(), op::v1::TopK::Mode::MAX);
EXPECT_EQ(topk_v1->get_sort_type(), op::v1::TopK::SortType::SORT_VALUES);
......@@ -74,9 +72,7 @@ TEST(opset_transform, opset1_topk_downgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v0 = as_type_ptr<op::v0::TopK>(pass_replacement_node);
EXPECT_EQ(topk_v0->description(), "TopK");
EXPECT_EQ(topk_v0->get_version(), 0);
EXPECT_TRUE(topk_v0);
EXPECT_EQ(topk_v0->get_k(), k);
EXPECT_EQ(topk_v0->get_top_k_axis(), axis);
EXPECT_EQ(topk_v0->get_compute_max(), true);
......