Commit 61926759 authored by Tomasz Socha's avatar Tomasz Socha

[ONNX] Unary ops

parent 309bfdf0
...@@ -37,6 +37,11 @@ add_library(onnx_import STATIC ...@@ -37,6 +37,11 @@ add_library(onnx_import STATIC
op/average_pool.hpp op/average_pool.hpp
op/batch_norm.cpp op/batch_norm.cpp
op/batch_norm.hpp op/batch_norm.hpp
op/cast.cpp
op/cast.hpp
op/ceil.hpp
op/clip.cpp
op/clip.hpp
op/concat.cpp op/concat.cpp
op/concat.hpp op/concat.hpp
op/constant.cpp op/constant.cpp
...@@ -45,12 +50,19 @@ add_library(onnx_import STATIC ...@@ -45,12 +50,19 @@ add_library(onnx_import STATIC
op/conv.hpp op/conv.hpp
op/div.hpp op/div.hpp
op/equal.hpp op/equal.hpp
op/exp.hpp
op/flatten.cpp op/flatten.cpp
op/flatten.hpp op/flatten.hpp
op/floor.hpp
op/gemm.cpp op/gemm.cpp
op/gemm.hpp op/gemm.hpp
op/greater.hpp op/greater.hpp
op/hard_sigmoid.cpp
op/hard_sigmoid.hpp
op/identity.hpp
op/less.hpp op/less.hpp
op/log.hpp
op/log_softmax.hpp
op/matmul.hpp op/matmul.hpp
op/max_pool.cpp op/max_pool.cpp
op/max_pool.hpp op/max_pool.hpp
...@@ -59,9 +71,12 @@ add_library(onnx_import STATIC ...@@ -59,9 +71,12 @@ add_library(onnx_import STATIC
op/mean.hpp op/mean.hpp
op/min.hpp op/min.hpp
op/mul.hpp op/mul.hpp
op/neg.hpp
op/not.hpp op/not.hpp
op/or.hpp op/or.hpp
op/pow.hpp op/pow.hpp
op/reciprocal.cpp
op/reciprocal.hpp
op/reduce.cpp op/reduce.cpp
op/reduce.hpp op/reduce.hpp
op/relu.hpp op/relu.hpp
...@@ -71,8 +86,13 @@ add_library(onnx_import STATIC ...@@ -71,8 +86,13 @@ add_library(onnx_import STATIC
op/shape.hpp op/shape.hpp
op/softmax.cpp op/softmax.cpp
op/softmax.hpp op/softmax.hpp
op/softplus.cpp
op/softplus.hpp
op/softsign.cpp
op/softsign.hpp
op/split.cpp op/split.cpp
op/split.hpp op/split.hpp
op/sqrt.hpp
op/sub.hpp op/sub.hpp
op/sum.hpp op/sum.hpp
op/unsqueeze.cpp op/unsqueeze.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/abs.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector abs(const Node& node)
{
return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <onnx.pb.h>
#include "ngraph/op/convert.hpp"
#include "ngraph/type/element_type.hpp"
#include "exceptions.hpp"
#include "cast.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector cast(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
int64_t target_type = node.get_attribute_value<int64_t>("to");
element::Type elem_type;
switch (target_type)
{
case onnx::TensorProto_DataType_BOOL: elem_type = element::boolean; break;
case onnx::TensorProto_DataType_DOUBLE: elem_type = element::f64; break;
case onnx::TensorProto_DataType_FLOAT16:
case onnx::TensorProto_DataType_FLOAT: elem_type = element::f32; break;
case onnx::TensorProto_DataType_INT8: elem_type = element::i8; break;
case onnx::TensorProto_DataType_INT16: elem_type = element::i16; break;
case onnx::TensorProto_DataType_INT32: elem_type = element::i32; break;
case onnx::TensorProto_DataType_INT64: elem_type = element::i64; break;
case onnx::TensorProto_DataType_UINT8: elem_type = element::u8; break;
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::unspecified; break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
}
return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector cast(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/ceiling.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector ceil(const Node& node)
{
return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <limits>
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "clip.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector clip(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double max_value =
node.get_attribute_value<double>("max", std::numeric_limits<double>::max());
double min_value =
node.get_attribute_value<double>("min", std::numeric_limits<double>::lowest());
std::shared_ptr<ngraph::Node> max_value_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{max_value});
max_value_node = make_broadcast_node(max_value_node, data->get_shape());
std::shared_ptr<ngraph::Node> min_value_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{min_value});
min_value_node = make_broadcast_node(min_value_node, data->get_shape());
return {std::make_shared<ngraph::op::Minimum>(
max_value_node, std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector clip(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/exp.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector exp(const Node& node)
{
return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/floor.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector floor(const Node& node)
{
return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
#include "hard_sigmoid.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector hard_sigmoid(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 0.2);
double beta = node.get_attribute_value<double>("beta", 0.5);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, data->get_shape());
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(
zero_node,
std::make_shared<ngraph::op::Minimum>(one_node,
alpha_node * data + beta_node))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector hard_sigmoid(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector identity(const Node& node) { return {node.get_ng_inputs().at(0)}; }
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/log.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector log(const Node& node)
{
return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/log.hpp"
#include "core/node.hpp"
#include "ngraph/frontend/onnx_import/op/softmax.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector log_softmax(const Node& node)
{
return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/negative.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector neg(const Node& node) { return {-node.get_ng_inputs().at(0)}; }
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
#include "reciprocal.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector reciprocal(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
return {one_node / data};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector reciprocal(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/log.hpp"
#include "utils/broadcasting.hpp"
#include "softplus.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector softplus(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
return {std::make_shared<ngraph::op::Log>(std::make_shared<ngraph::op::Exp>(data) +
one_node)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector softplus(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <vector>
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
#include "softsign.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector softsign(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector softsign(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node_vector.hpp"
#include "ngraph/op/sqrt.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector sqrt(const Node& node)
{
return {std::make_shared<ngraph::op::Sqrt>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -18,34 +18,49 @@ ...@@ -18,34 +18,49 @@
#include <functional> #include <functional>
#include "core/attribute.hpp" #include "core/attribute.hpp"
#include "op/abs.hpp"
#include "op/add.hpp" #include "op/add.hpp"
#include "op/and.hpp" #include "op/and.hpp"
#include "op/average_pool.hpp" #include "op/average_pool.hpp"
#include "op/batch_norm.hpp" #include "op/batch_norm.hpp"
#include "op/cast.hpp"
#include "op/ceil.hpp"
#include "op/clip.hpp"
#include "op/concat.hpp" #include "op/concat.hpp"
#include "op/constant.hpp" #include "op/constant.hpp"
#include "op/conv.hpp" #include "op/conv.hpp"
#include "op/div.hpp" #include "op/div.hpp"
#include "op/equal.hpp" #include "op/equal.hpp"
#include "op/exp.hpp"
#include "op/flatten.hpp" #include "op/flatten.hpp"
#include "op/floor.hpp"
#include "op/gemm.hpp" #include "op/gemm.hpp"
#include "op/greater.hpp" #include "op/greater.hpp"
#include "op/hard_sigmoid.hpp"
#include "op/identity.hpp"
#include "op/less.hpp" #include "op/less.hpp"
#include "op/log.hpp"
#include "op/log_softmax.hpp"
#include "op/matmul.hpp" #include "op/matmul.hpp"
#include "op/max.hpp" #include "op/max.hpp"
#include "op/max_pool.hpp" #include "op/max_pool.hpp"
#include "op/mean.hpp" #include "op/mean.hpp"
#include "op/min.hpp" #include "op/min.hpp"
#include "op/mul.hpp" #include "op/mul.hpp"
#include "op/neg.hpp"
#include "op/not.hpp" #include "op/not.hpp"
#include "op/or.hpp" #include "op/or.hpp"
#include "op/pow.hpp" #include "op/pow.hpp"
#include "op/reciprocal.hpp"
#include "op/reduce.hpp" #include "op/reduce.hpp"
#include "op/relu.hpp" #include "op/relu.hpp"
#include "op/reshape.hpp" #include "op/reshape.hpp"
#include "op/shape.hpp" #include "op/shape.hpp"
#include "op/softmax.hpp" #include "op/softmax.hpp"
#include "op/softplus.hpp"
#include "op/softsign.hpp"
#include "op/split.hpp" #include "op/split.hpp"
#include "op/sqrt.hpp"
#include "op/sub.hpp" #include "op/sub.hpp"
#include "op/sum.hpp" #include "op/sum.hpp"
#include "op/unsqueeze.hpp" #include "op/unsqueeze.hpp"
...@@ -94,30 +109,43 @@ namespace ngraph ...@@ -94,30 +109,43 @@ namespace ngraph
ops_bridge() ops_bridge()
{ {
m_map.emplace("Abs", std::bind(op::abs, std::placeholders::_1));
m_map.emplace("Add", std::bind(op::add, std::placeholders::_1)); m_map.emplace("Add", std::bind(op::add, std::placeholders::_1));
m_map.emplace("And", std::bind(op::logical_and, std::placeholders::_1)); m_map.emplace("And", std::bind(op::logical_and, std::placeholders::_1));
m_map.emplace("AveragePool", m_map.emplace("AveragePool",
std::bind(op::average_pool, std::placeholders::_1)); std::bind(op::average_pool, std::placeholders::_1));
m_map.emplace("BatchNormalization", m_map.emplace("BatchNormalization",
std::bind(op::batch_norm, std::placeholders::_1)); std::bind(op::batch_norm, std::placeholders::_1));
m_map.emplace("Cast", std::bind(op::cast, std::placeholders::_1));
m_map.emplace("Ceil", std::bind(op::ceil, std::placeholders::_1));
m_map.emplace("Clip", std::bind(op::clip, std::placeholders::_1));
m_map.emplace("Concat", std::bind(op::concat, std::placeholders::_1)); m_map.emplace("Concat", std::bind(op::concat, std::placeholders::_1));
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1)); m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1)); m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Div", std::bind(op::div, std::placeholders::_1)); m_map.emplace("Div", std::bind(op::div, std::placeholders::_1));
m_map.emplace("Equal", std::bind(op::equal, std::placeholders::_1)); m_map.emplace("Equal", std::bind(op::equal, std::placeholders::_1));
m_map.emplace("Exp", std::bind(op::exp, std::placeholders::_1));
m_map.emplace("Flatten", std::bind(op::flatten, std::placeholders::_1)); m_map.emplace("Flatten", std::bind(op::flatten, std::placeholders::_1));
m_map.emplace("Floor", std::bind(op::floor, std::placeholders::_1));
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1)); m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("Greater", std::bind(op::greater, std::placeholders::_1)); m_map.emplace("Greater", std::bind(op::greater, std::placeholders::_1));
m_map.emplace("HardSigmoid",
std::bind(op::hard_sigmoid, std::placeholders::_1));
m_map.emplace("Identity", std::bind(op::identity, std::placeholders::_1));
m_map.emplace("Less", std::bind(op::less, std::placeholders::_1)); m_map.emplace("Less", std::bind(op::less, std::placeholders::_1));
m_map.emplace("Log", std::bind(op::log, std::placeholders::_1));
m_map.emplace("LogSoftmax", std::bind(op::log_softmax, std::placeholders::_1));
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1)); m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1));
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1)); m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1));
m_map.emplace("Max", std::bind(op::max, std::placeholders::_1)); m_map.emplace("Max", std::bind(op::max, std::placeholders::_1));
m_map.emplace("Mean", std::bind(op::mean, std::placeholders::_1)); m_map.emplace("Mean", std::bind(op::mean, std::placeholders::_1));
m_map.emplace("Min", std::bind(op::min, std::placeholders::_1)); m_map.emplace("Min", std::bind(op::min, std::placeholders::_1));
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1)); m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1));
m_map.emplace("Neg", std::bind(op::neg, std::placeholders::_1));
m_map.emplace("Not", std::bind(op::logical_not, std::placeholders::_1)); m_map.emplace("Not", std::bind(op::logical_not, std::placeholders::_1));
m_map.emplace("Or", std::bind(op::logical_or, std::placeholders::_1)); m_map.emplace("Or", std::bind(op::logical_or, std::placeholders::_1));
m_map.emplace("Pow", std::bind(op::pow, std::placeholders::_1)); m_map.emplace("Pow", std::bind(op::pow, std::placeholders::_1));
m_map.emplace("Reciprocal", std::bind(op::reciprocal, std::placeholders::_1));
m_map.emplace("ReduceLogSum", m_map.emplace("ReduceLogSum",
std::bind(op::reduce_log_sum, std::placeholders::_1)); std::bind(op::reduce_log_sum, std::placeholders::_1));
m_map.emplace("ReduceLogSumExp", m_map.emplace("ReduceLogSumExp",
...@@ -135,7 +163,10 @@ namespace ngraph ...@@ -135,7 +163,10 @@ namespace ngraph
m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1)); m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1));
m_map.emplace("Shape", std::bind(op::shape, std::placeholders::_1)); m_map.emplace("Shape", std::bind(op::shape, std::placeholders::_1));
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1)); m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1));
m_map.emplace("Softplus", std::bind(op::softplus, std::placeholders::_1));
m_map.emplace("Softsign", std::bind(op::softsign, std::placeholders::_1));
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1)); m_map.emplace("Split", std::bind(op::split, std::placeholders::_1));
m_map.emplace("Sqrt", std::bind(op::sqrt, std::placeholders::_1));
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1)); m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1));
m_map.emplace("Sum", std::bind(op::sum, std::placeholders::_1)); m_map.emplace("Sum", std::bind(op::sum, std::placeholders::_1));
m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1)); m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment