Commit f749c9d0 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Michał Karzyński

[SPEC] ConvertLike op (#3944)

parent 8f999289
...@@ -138,6 +138,8 @@ set (SRC ...@@ -138,6 +138,8 @@ set (SRC
op/constant.hpp op/constant.hpp
op/convert.cpp op/convert.cpp
op/convert.hpp op/convert.hpp
op/convert_like.cpp
op/convert_like.hpp
op/convolution.cpp op/convolution.cpp
op/convolution.hpp op/convolution.hpp
op/cos.cpp op/cos.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/op/convert_like.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::ConvertLike::type_info;
op::v1::ConvertLike::ConvertLike(const Output<Node>& data, const Output<Node>& like)
: Op({data, like})
{
constructor_validate_and_infer_types();
}
void op::v1::ConvertLike::validate_and_infer_types()
{
set_output_type(0, get_input_element_type(1), get_input_partial_shape(0));
}
shared_ptr<Node> op::v1::ConvertLike::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<ConvertLike>(new_args.at(0), new_args.at(1));
}
void op::v1::ConvertLike::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
const auto delta = deltas.at(0);
adjoints.add_delta(input_value(0), make_shared<op::v1::ConvertLike>(delta, input_value(1)));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Elementwise type conversion operation.
class ConvertLike : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"ConvertLike", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a conversion operation.
ConvertLike() = default;
/// \brief Constructs a conversion operation.
/// \param data Node that produces the input tensor.
/// \param like Node which provides the target type information for the conversion.
ConvertLike(const Output<Node>& data, const Output<Node>& like);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
}
}
...@@ -60,6 +60,7 @@ NGRAPH_OP(CompiledKernel, ngraph::op, 0) ...@@ -60,6 +60,7 @@ NGRAPH_OP(CompiledKernel, ngraph::op, 0)
NGRAPH_OP(Concat, ngraph::op::v0, 0) NGRAPH_OP(Concat, ngraph::op::v0, 0)
NGRAPH_OP(Constant, ngraph::op, 0) NGRAPH_OP(Constant, ngraph::op, 0)
NGRAPH_OP(Convert, ngraph::op, 0) NGRAPH_OP(Convert, ngraph::op, 0)
NGRAPH_OP(ConvertLike, ngraph::op::v1, 1)
NGRAPH_OP(Convolution, ngraph::op::v0, 0) NGRAPH_OP(Convolution, ngraph::op::v0, 0)
NGRAPH_OP(Convolution, ngraph::op::v1, 1) NGRAPH_OP(Convolution, ngraph::op::v1, 1)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v0, 0) NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v0, 0)
......
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
#include "ngraph/op/convert_like.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp" #include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp" #include "ngraph/op/cosh.hpp"
......
...@@ -65,6 +65,7 @@ NGRAPH_OP(Clamp, ngraph::op::v0) ...@@ -65,6 +65,7 @@ NGRAPH_OP(Clamp, ngraph::op::v0)
NGRAPH_OP(Concat, ngraph::op::v0) NGRAPH_OP(Concat, ngraph::op::v0)
NGRAPH_OP(Constant, ngraph::op) NGRAPH_OP(Constant, ngraph::op)
NGRAPH_OP(Convert, ngraph::op::v0) NGRAPH_OP(Convert, ngraph::op::v0)
NGRAPH_OP(ConvertLike, ngraph::op::v1)
NGRAPH_OP(Convolution, ngraph::op::v1) NGRAPH_OP(Convolution, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1) NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(Cos, ngraph::op::v0) NGRAPH_OP(Cos, ngraph::op::v0)
......
...@@ -1057,6 +1057,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1057,6 +1057,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Convert>(args[0], target_type); node = make_shared<op::Convert>(args[0], target_type);
break; break;
} }
case OP_TYPEID::ConvertLike_v1:
{
node = make_shared<op::v1::ConvertLike>(args[0], args[1]);
break;
}
case OP_TYPEID::Convolution: case OP_TYPEID::Convolution:
{ {
auto window_movement_strides = auto window_movement_strides =
...@@ -3135,6 +3140,8 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -3135,6 +3140,8 @@ json JSONSerializer::serialize_node(const Node& n)
node["target_type"] = write_element_type(tmp->get_convert_element_type()); node["target_type"] = write_element_type(tmp->get_convert_element_type());
break; break;
} }
case OP_TYPEID::ConvertLike_v1: { break;
}
case OP_TYPEID::Convolution: case OP_TYPEID::Convolution:
{ {
auto tmp = static_cast<const op::v0::Convolution*>(&n); auto tmp = static_cast<const op::v0::Convolution*>(&n);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment