Commit acdf4d30 authored by Adam Rogowiec's avatar Adam Rogowiec

Move activation functions to nGraph core.

parent 8cc4ccef
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <functional>
#include <iterator>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace rnn
{
namespace detail
{
// Builds a Sigmoid node for `arg`. The alpha and beta coefficients are
// accepted only so all activation builders share one signature; they are
// not used here.
std::shared_ptr<ngraph::Node>
    sigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
    auto activation = std::make_shared<ngraph::op::Sigmoid>(arg);
    return activation;
}
// Builds a Tanh node for `arg`; alpha/beta are ignored (signature uniformity).
std::shared_ptr<ngraph::Node>
    tanh(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
    auto activation = std::make_shared<ngraph::op::Tanh>(arg);
    return activation;
}
// Builds a Relu node for `arg`; alpha/beta are ignored (signature uniformity).
std::shared_ptr<ngraph::Node>
    relu(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
    auto activation = std::make_shared<ngraph::op::Relu>(arg);
    return activation;
}
// Builds a HardSigmoid node parameterized by the alpha and beta coefficients.
std::shared_ptr<ngraph::Node>
    hardsigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
    auto activation = std::make_shared<ngraph::op::HardSigmoid>(arg, alpha, beta);
    return activation;
}
} // namespace detail
// Stores the activation builder together with both of its coefficients.
ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha, float beta)
    : m_function(f)
    , m_alpha(alpha)
    , m_beta(beta)
{
}
// Single-coefficient form: beta is initialized to NaN, exactly the value the
// delegating constructor previously forwarded.
ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha)
    : m_function(f)
    , m_alpha(alpha)
    , m_beta(std::nanf(""))
{
}
// Coefficient-free form: both alpha and beta are initialized to NaN, exactly
// the values the delegating constructor previously forwarded.
ActivationFunction::ActivationFunction(ActivationFunctionType f)
    : m_function(f)
    , m_alpha(std::nanf(""))
    , m_beta(std::nanf(""))
{
}
// Invokes the stored activation builder on `arg`, forwarding the stored
// alpha/beta coefficients.
std::shared_ptr<ngraph::Node> ActivationFunction::
    operator()(const std::shared_ptr<ngraph::Node>& arg) const
{
    auto result = m_function(arg, m_alpha, m_beta);
    return result;
}
/// \brief Gets the activation function by name.
///
/// \param[in] func_name The activation function name (e.g. "sigmoid").
///
/// \throws error::UnknownActivationFunction when func_name is not recognized.
///
/// \return A copy of the matching ActivationFunction object.
ActivationFunction get_activation_func_by_name(const std::string& func_name)
{
    using ActivationFunctionMap = std::unordered_map<std::string, ActivationFunction>;
    // NOTE: the former `using namespace std::placeholders;` was removed — it
    // was a leftover from a std::bind-based implementation and nothing in this
    // function uses placeholders.

    // Built once, on first call. hardsigmoid's coefficients (0.2, 0.5) match
    // the ONNX HardSigmoid operator's default alpha/beta.
    static ActivationFunctionMap func_map{
        {"sigmoid", ActivationFunction{detail::sigmoid}},
        {"tanh", ActivationFunction{detail::tanh}},
        {"relu", ActivationFunction{detail::relu}},
        {"hardsigmoid", ActivationFunction{detail::hardsigmoid, 0.2f, 0.5f}},
    };

    auto func_it = func_map.find(func_name);
    if (func_it == std::end(func_map))
    {
        throw error::UnknownActivationFunction(func_name);
    }
    return func_it->second;
}
} //namespace rnn
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <functional>
#include <iterator>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
using namespace std;
using namespace ngraph;
// Builds a Sigmoid node for `arg`. alpha/beta are accepted only so every
// activation builder shares one signature; they are not used here.
static shared_ptr<Node> sigmoid(const shared_ptr<Node>& arg, float alpha, float beta)
{
    auto activation = make_shared<op::Sigmoid>(arg);
    return activation;
}
// Builds a Tanh node for `arg`; alpha/beta are ignored (signature uniformity).
static shared_ptr<Node> tanh(const shared_ptr<Node>& arg, float alpha, float beta)
{
    auto activation = make_shared<op::Tanh>(arg);
    return activation;
}
// Builds a Relu node for `arg`; alpha/beta are ignored (signature uniformity).
static shared_ptr<Node> relu(const shared_ptr<Node>& arg, float alpha, float beta)
{
    auto activation = make_shared<op::Relu>(arg);
    return activation;
}
// Builds a HardSigmoid node parameterized by the alpha and beta coefficients.
static shared_ptr<Node> hardsigmoid(const shared_ptr<Node>& arg, float alpha, float beta)
{
    auto activation = make_shared<op::HardSigmoid>(arg, alpha, beta);
    return activation;
}
// Stores the activation builder together with both of its coefficients.
op::ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha, float beta)
    : m_function(f)
    , m_alpha(alpha)
    , m_beta(beta)
{
}
// Single-coefficient form: beta is initialized to NaN, exactly the value the
// delegating constructor previously forwarded.
op::ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha)
    : m_function(f)
    , m_alpha(alpha)
    , m_beta(nanf(""))
{
}
// Coefficient-free form: both alpha and beta are initialized to NaN, exactly
// the values the delegating constructor previously forwarded.
op::ActivationFunction::ActivationFunction(ActivationFunctionType f)
    : m_function(f)
    , m_alpha(nanf(""))
    , m_beta(nanf(""))
{
}
// Invokes the stored activation builder on `arg`, forwarding the stored
// alpha/beta coefficients.
shared_ptr<Node> op::ActivationFunction::operator()(const shared_ptr<Node>& arg) const
{
    auto result = m_function(arg, m_alpha, m_beta);
    return result;
}
/// \brief Gets the activation function by name.
///
/// \param[in] func_name The activation function name (e.g. "sigmoid").
///
/// \throws error::UnknownActivationFunction when func_name is not recognized.
///
/// \return A copy of the matching ActivationFunction object.
op::ActivationFunction op::get_activation_func_by_name(const string& func_name)
{
    using ActivationFunctionMap = unordered_map<string, op::ActivationFunction>;
    // NOTE: the former `using namespace placeholders;` was removed — it was a
    // leftover from a std::bind-based implementation and nothing in this
    // function uses placeholders.

    // Built once, on first call. hardsigmoid's coefficients (0.2, 0.5) match
    // the ONNX HardSigmoid operator's default alpha/beta.
    static ActivationFunctionMap func_map{
        {"sigmoid", op::ActivationFunction{sigmoid}},
        {"tanh", op::ActivationFunction{tanh}},
        {"relu", op::ActivationFunction{relu}},
        {"hardsigmoid", op::ActivationFunction{hardsigmoid, 0.2f, 0.5f}},
    };

    auto func_it = func_map.find(func_name);
    if (func_it == end(func_map))
    {
        throw error::UnknownActivationFunction(func_name);
    }
    return func_it->second;
}
......@@ -31,80 +31,76 @@
// Prevents the compiler from complaining about or optimizing away variables
// that appear unused on Linux
#if (defined(__GNUC__) && !defined(__clang__))
#undef ONNX_ATTRIBUTE_UNUSED
#define ONNX_ATTRIBUTE_UNUSED __attribute__((__unused__))
#undef NG_ATTRIBUTE_UNUSED
#define NG_ATTRIBUTE_UNUSED __attribute__((__unused__))
#else
#define ONNX_ATTRIBUTE_UNUSED
#define NG_ATTRIBUTE_UNUSED
#endif
#define UNUSED_PARAMETER ONNX_ATTRIBUTE_UNUSED = 0
#define UNUSED_PARAMETER NG_ATTRIBUTE_UNUSED = 0
namespace ngraph
{
namespace onnx_import
namespace op
{
namespace rnn
namespace error
{
namespace error
struct UnknownActivationFunction : ngraph_error
{
struct UnknownActivationFunction : ngraph_error
UnknownActivationFunction(const std::string& func_name)
: ngraph_error{"Unknown activation function: " + func_name}
{
UnknownActivationFunction(const std::string& func_name)
: ngraph_error{"Unknown activation function: " + func_name}
{
}
};
}
namespace detail
{
// Every activation builder shares the signature (arg, alpha, beta) so a
// single function-pointer type can hold any of them. Builders that take no
// coefficients mark alpha/beta with UNUSED_PARAMETER (expands to `= 0` plus
// an unused attribute on GCC) to silence unused-parameter diagnostics.
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
// hardsigmoid actually consumes both coefficients, so they have no defaults.
std::shared_ptr<ngraph::Node>
hardsigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta);
}
using ActivationFunctionType = std::shared_ptr<ngraph::Node> (*)(
const std::shared_ptr<ngraph::Node>&, float, float);
class ActivationFunction
{
public:
ActivationFunction(ActivationFunctionType f, float alpha, float beta);
ActivationFunction(ActivationFunctionType f, float alpha);
ActivationFunction(ActivationFunctionType f);
std::shared_ptr<ngraph::Node>
operator()(const std::shared_ptr<ngraph::Node>& arg) const;
void set_alpha(float alpha) { m_alpha = alpha; }
void set_beta(float beta) { m_beta = beta; }
private:
ActivationFunctionType m_function;
float m_alpha;
float m_beta;
}
};
}
/// \brief Gets the activation function by name.
///
/// \param[in] func_name The function name
///
/// \throws UnknownActivationFunction When provided func_name is unknown.
///
/// \return The activation function object.
///
ActivationFunction get_activation_func_by_name(const std::string& func_name);
} //namespace rnn
} // namespace onnx_import
namespace detail
{
// Every activation builder shares the signature (arg, alpha, beta) so a
// single function-pointer type can hold any of them. Builders that take no
// coefficients mark alpha/beta with UNUSED_PARAMETER (expands to `= 0` plus
// an unused attribute on GCC) to silence unused-parameter diagnostics.
std::shared_ptr<Node> sigmoid(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<Node> tanh(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<Node> relu(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
// hardsigmoid actually consumes both coefficients, so they have no defaults.
std::shared_ptr<Node>
hardsigmoid(const std::shared_ptr<Node>& arg, float alpha, float beta);
}
// Pointer to an activation-node builder: (input node, alpha, beta) -> node.
using ActivationFunctionType = std::shared_ptr<Node> (*)(const std::shared_ptr<Node>&,
float,
float);
/// \brief Wraps an activation builder together with its alpha/beta
///        coefficients so it can be invoked with just the input node.
class ActivationFunction
{
public:
ActivationFunction(ActivationFunctionType f, float alpha, float beta);
ActivationFunction(ActivationFunctionType f, float alpha);
ActivationFunction(ActivationFunctionType f);
/// \brief Calls the wrapped builder on \p arg with the stored coefficients.
std::shared_ptr<Node> operator()(const std::shared_ptr<Node>& arg) const;
// Mutators let callers override the coefficients after construction.
void set_alpha(float alpha) { m_alpha = alpha; }
void set_beta(float beta) { m_beta = beta; }
private:
ActivationFunctionType m_function;
float m_alpha;
float m_beta;
};
/// \brief Gets the activation function by name.
///
/// \param[in] func_name The function name
///
/// \throws UnknownActivationFunction When provided func_name is unknown.
///
/// \return The activation function object.
///
ActivationFunction get_activation_func_by_name(const std::string& func_name);
} // namespace op
} // namespace ngraph
......@@ -115,6 +111,6 @@ namespace ngraph
#ifdef UNUSED_PARAMETER
#undef UNUSED_PARAMETER
#endif
#ifdef ONNX_ATTRIBUTE_UNUSED
#undef ONNX_ATTRIBUTE_UNUSED
#ifdef NG_ATTRIBUTE_UNUSED
#undef NG_ATTRIBUTE_UNUSED
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment