Commit 1b52a67f authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

[ONNX] Inverse hyperbolic trigonometric functions. (#2436)

parent 6cd77ff2
......@@ -39,6 +39,8 @@ add_library(onnx_import STATIC
core/value_info.hpp
exceptions.hpp
op/acos.hpp
op/acosh.cpp
op/acosh.hpp
op/add.hpp
op/and.hpp
op/argmax.cpp
......@@ -46,7 +48,11 @@ add_library(onnx_import STATIC
op/argmin.cpp
op/argmin.hpp
op/asin.hpp
op/asinh.cpp
op/asinh.hpp
op/atan.hpp
op/atanh.cpp
op/atanh.hpp
op/average_pool.cpp
op/average_pool.hpp
op/batch_norm.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "acosh.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector acosh(const Node& node)
{
std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
// Define inverse hyperbolic cosine in terms of natural logarithm:
//
// arccosh(x) = ln(x + sqrt(x^2 - 1))
//
std::shared_ptr<ngraph::Node> one_node{ngraph::op::Constant::create(
data->get_element_type(),
data->get_shape(),
std::vector<float>(ngraph::shape_size(data->get_shape()), 1.f))};
std::shared_ptr<ngraph::Node> sqrt_node{
std::make_shared<ngraph::op::Sqrt>(data * data - one_node)};
return {std::make_shared<ngraph::op::Log>(data + sqrt_node)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector acosh(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "asinh.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector asinh(const Node& node)
{
std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
// Define inverse hyperbolic sine in terms of natural logarithm:
//
// asinh(x) = ln(x + sqrt(x^2 + 1))
//
std::shared_ptr<ngraph::Node> one_node{ngraph::op::Constant::create(
data->get_element_type(),
data->get_shape(),
std::vector<float>(ngraph::shape_size(data->get_shape()), 1.f))};
std::shared_ptr<ngraph::Node> sqrt_node{
std::make_shared<ngraph::op::Sqrt>(data * data + one_node)};
return {std::make_shared<ngraph::op::Log>(data + sqrt_node)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector asinh(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "atanh.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector atanh(const Node& node)
{
std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
// Define inverse hyperbolic tangent in terms of natural logarithm:
//
// atanh(x) = 0.5 * ln((1 + x) / (1 - x))
// = 0.5 * (ln(1 + x) - ln(1 - x))
//
std::shared_ptr<ngraph::Node> one_node{ngraph::op::Constant::create(
data->get_element_type(),
data->get_shape(),
std::vector<float>(ngraph::shape_size(data->get_shape()), 1.f))};
std::shared_ptr<ngraph::Node> half_node{ngraph::op::Constant::create(
data->get_element_type(),
data->get_shape(),
std::vector<float>(ngraph::shape_size(data->get_shape()), 0.5f))};
return {half_node * (std::make_shared<ngraph::op::Log>(one_node + data) -
std::make_shared<ngraph::op::Log>(one_node - data))};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector atanh(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -20,12 +20,15 @@ opset versions starting from `1` to `6` and to the latest opset version.
|------|----------------------------|---------|
| Abs | 1-6- |
| Acos | 7- |
| Acosh | 9- |
| Add | 1-7- |
| And | 1-7- |
| ArgMax | 1- |
| ArgMin | 1- |
| Asin | 7- |
| Asinh | 9- |
| Atan | 7 - |
| Atanh | 9- |
| AveragePool | 1-7- |
| BatchNormalization | 1-6-7- |
| Ceil | 1-6- |
......@@ -104,9 +107,6 @@ opset versions starting from `1` to `6` and to the latest opset version.
### Lack of support in nGraph
| Name | Opset supported | NGCORE | NGONNX | Comment |
|------|-----------------|--------|--------|---------|
| Acosh | (9) | 283 | 444 | |
| Asinh | (9) | 283 | 444 | |
| Atanh | (9) | 283 | 444 | |
| Erf | (9) | 284 | 442 | Maybe we may implement this as a simple closed interval integral? :) |
| Pad | 1-2- | 273 | 416 | Not fully supported. |
| LSTM | 1-7- | | 430 | Not fully supported. |
......
......@@ -24,12 +24,15 @@
#include "ngraph/log.hpp"
#include "op/abs.hpp"
#include "op/acos.hpp"
#include "op/acosh.hpp"
#include "op/add.hpp"
#include "op/and.hpp"
#include "op/argmax.hpp"
#include "op/argmin.hpp"
#include "op/asin.hpp"
#include "op/asinh.hpp"
#include "op/atan.hpp"
#include "op/atanh.hpp"
#include "op/average_pool.hpp"
#include "op/batch_norm.hpp"
#include "op/cast.hpp"
......@@ -205,13 +208,16 @@ namespace ngraph
{
REGISTER_OPERATOR("Abs", 1, abs);
REGISTER_OPERATOR("Acos", 1, acos);
REGISTER_OPERATOR("Acosh", 1, acosh);
REGISTER_OPERATOR("Add", 1, add);
REGISTER_OPERATOR("Add", 7, add);
REGISTER_OPERATOR("And", 1, logical_and);
REGISTER_OPERATOR("ArgMin", 1, argmin);
REGISTER_OPERATOR("ArgMax", 1, argmax);
REGISTER_OPERATOR("Asin", 1, asin);
REGISTER_OPERATOR("Asinh", 1, asinh);
REGISTER_OPERATOR("Atan", 1, atan);
REGISTER_OPERATOR("Atanh", 1, atanh);
REGISTER_OPERATOR("AveragePool", 1, average_pool);
REGISTER_OPERATOR("BatchNormalization", 1, batch_norm);
REGISTER_OPERATOR("Cast", 1, cast);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment