Unverified commit f70e2174 authored by Scott Cyphers, committed by GitHub

GeluBackpropFactor (#3447)

parent 5bbd199b
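Editorial note (not part of the patch): GeluBackpropFactor is the derivative of GELU with respect to its input, and op::Gelu::generate_adjoints multiplies it by the incoming delta. A sketch of the derivation behind the factor used in decompose_op and in the test helper, using erf'(u) = (2/sqrt(pi)) e^{-u^2}:

```latex
\begin{align*}
\mathrm{GELU}(x) &= \tfrac{1}{2}\,x\left(1 + \operatorname{erf}\!\big(x/\sqrt{2}\big)\right) \\
\frac{d}{dx}\,\mathrm{GELU}(x)
  &= \tfrac{1}{2}\left(1 + \operatorname{erf}\!\big(x/\sqrt{2}\big)\right)
     + \tfrac{1}{2}\,x \cdot \frac{2}{\sqrt{\pi}}\,e^{-x^2/2} \cdot \frac{1}{\sqrt{2}} \\
  &= \tfrac{1}{2}\left(1 + \operatorname{erf}\!\big(x/\sqrt{2}\big)\right)
     + \frac{x\,e^{-x^2/2}}{\sqrt{2\pi}}
\end{align*}
```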
......@@ -13,14 +13,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/gelu.hpp"
#include <cmath>
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/fused/gelu.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/subtract.hpp"
using namespace std;
using namespace ngraph;
......@@ -64,11 +68,60 @@ void op::Gelu::pre_validate_and_infer_types()
element::Type input_element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type == element::f32 ||
input_element_type == element::f64 ||
input_element_type == element::f16 ||
input_element_type == element::bf16,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
}
void op::Gelu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
auto x = input_value(0);
adjoints.add_delta(x, delta * (make_shared<op::GeluBackpropFactor>(x)));
}
const string op::GeluBackpropFactor::type_name{"GeluBackpropFactor"};
op::GeluBackpropFactor::GeluBackpropFactor(const Output<Node>& x)
: FusedOp({x})
{
constructor_validate_and_infer_types();
}
void op::GeluBackpropFactor::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
}
shared_ptr<Node> op::GeluBackpropFactor::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<GeluBackpropFactor>(new_args.at(0));
}
NodeVector op::GeluBackpropFactor::decompose_op() const
{
auto x = input_value(0);
// 0.5 * (1 + erf( x * sqrt(1/2)))
// + [x * exp (-x^2/2)] / sqrt(2 * pi)
auto half = builder::make_constant(x.get_element_type(), x.get_shape(), 0.5);
auto one = builder::make_constant(x.get_element_type(), x.get_shape(), 1.0);
auto pi = 4.0 * std::atan(1.0);
auto inv_sqrt_two_pi =
builder::make_constant(x.get_element_type(), x.get_shape(), 1.0 / std::sqrt(2.0 * pi));
auto sqrt_half = builder::make_constant(x.get_element_type(), x.get_shape(), std::sqrt(0.5));
auto e1 = half * (one + make_shared<op::Erf>(x * sqrt_half));
auto e2 = x * make_shared<op::Exp>(x * x * (-half)) * inv_sqrt_two_pi;
return {e1 + e2};
}
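A minimal scalar sketch (editorial illustration, not part of the patch; the names are hypothetical) of what the decomposition above computes per element, and how the adjoint added in op::Gelu::generate_adjoints uses it:

```cpp
#include <cmath>

// Same expression as GeluBackpropFactor::decompose_op, in scalar form:
// 0.5 * (1 + erf(x * sqrt(1/2))) + x * exp(-x^2/2) / sqrt(2*pi)
double gelu_backprop_factor_ref(double x)
{
    const double pi = 4.0 * std::atan(1.0);
    return 0.5 * (1.0 + std::erf(x * std::sqrt(0.5))) +
           x * std::exp(-x * x / 2.0) / std::sqrt(2.0 * pi);
}

// Mirrors adjoints.add_delta(x, delta * GeluBackpropFactor(x)):
// the gradient reaching x is the incoming delta scaled by the factor.
double gelu_backward_ref(double x, double delta)
{
    return delta * gelu_backprop_factor_ref(x);
}
```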
......@@ -26,9 +26,6 @@ namespace ngraph
{
/// \brief Gaussian Error Linear Unit
/// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) ))
/// erf'(x) = 2 / sqrt(pi) * exp (-x^2)
/// f'(x) = 0.5 * (1 + erf( x / sqrt(2)) + x * sqrt(2 / pi) * exp (-(x / sqrt(2))^2))
///
class Gelu : public ngraph::op::util::FusedOp
{
public:
......@@ -45,6 +42,29 @@ namespace ngraph
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
/// \brief Backprop for Gelu(x) is GeluBackpropFactor(x) * delta
class GeluBackpropFactor : public util::FusedOp
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
GeluBackpropFactor() = default;
GeluBackpropFactor(const Output<Node>& x);
virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
......
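Usage sketch (editorial, hedged): the new op is constructed from a single input, so a graph that evaluates the factor directly can be built the same way the unit tests below do. This mirrors the gelu_backprop_factor_f32 test and assumes the usual nGraph headers:

```cpp
#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

// Build f(A) = GeluBackpropFactor(A) for an f32 tensor of shape {8},
// as in the gelu_backprop_factor_f32 test added in this change.
std::shared_ptr<Function> make_gelu_backprop_factor_graph()
{
    Shape shape{8};
    auto A = std::make_shared<op::Parameter>(element::f32, shape);
    auto factor = std::make_shared<op::GeluBackpropFactor>(A);
    return std::make_shared<Function>(factor, ParameterVector{A});
}
```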
......@@ -30,6 +30,7 @@ NGRAPH_OP(DepthToSpace, ngraph::op)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(Gelu, ngraph::op)
NGRAPH_OP(GeluBackpropFactor, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
......
......@@ -216,6 +216,12 @@ fake_quantize_with_clip
fake_quantize_with_clip_across_channels
send_recv
send_recv_ring
logical_xor
# Needs erf
gelu_f32
gelu_f64
logical_xor
gelu_backprop_factor_f32
gelu_backprop_factor_f64
backwards_gelu_f32
backwards_gelu_f64
......@@ -110,9 +110,16 @@ fake_quantize_with_clip
fake_quantize_with_clip_across_channels
send_recv
send_recv_ring
logical_xor
# Needs erf
gelu_f32
gelu_f64
logical_xor
gelu_backprop_factor_f32
gelu_backprop_factor_f64
backwards_gelu_f32
backwards_gelu_f64
# Not supported quant ops
model_dequantize_linear_1d_zero_scale_int8
......
......@@ -232,8 +232,15 @@ backwards_softmax_underflow
backwards_softmax_3d
batch_mat_mul_forward
dot_matrix_2x0_0x2
# Need erf
gelu_f32
gelu_f64
gelu_backprop_factor_f32
gelu_backprop_factor_f64
backwards_gelu_f32
backwards_gelu_f64
# From onnx tests
model_quant_conv_linear_2d
model_quant_conv_linear_3d
......
......@@ -1221,6 +1221,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Gelu>(args[0]);
break;
}
case OP_TYPEID::GeluBackpropFactor:
{
node = make_shared<op::GeluBackpropFactor>(args[0]);
break;
}
case OP_TYPEID::Gemm:
{
auto alpha = node_js.at("alpha").get<double>();
......@@ -2411,6 +2416,8 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Gelu: { break;
}
case OP_TYPEID::GeluBackpropFactor: { break;
}
case OP_TYPEID::Gemm:
{
auto tmp = dynamic_cast<const op::Gemm*>(&n);
......
......@@ -35,6 +35,7 @@
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
......@@ -89,3 +90,92 @@ NGRAPH_TEST(${BACKEND_NAME}, gelu_f64)
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f(input, read_vector<double>(result)));
}
static double gelu_backprop_factor(double x)
{
auto pi = 4.0 * std::atan(1.0);
return 0.5 * (1.0 + erf(x * sqrt(1.0 / 2.0))) + (x * exp(-x * x / 2.0)) / sqrt(2.0 * pi);
}
NGRAPH_TEST(${BACKEND_NAME}, gelu_backprop_factor_f32)
{
Shape shape{8};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::GeluBackpropFactor>(A), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
vector<float> input{-4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f};
copy_data(a, input);
auto result = backend->create_tensor(element::f32, shape);
std::transform(input.begin(), input.end(), input.begin(), [](float x) -> float {
return static_cast<float>(gelu_backprop_factor(x));
});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(
test::all_close_f(input, read_vector<float>(result), DEFAULT_FLOAT_TOLERANCE_BITS + 6));
}
NGRAPH_TEST(${BACKEND_NAME}, gelu_backprop_factor_f64)
{
Shape shape{8};
auto A = make_shared<op::Parameter>(element::f64, shape);
auto f = make_shared<Function>(make_shared<op::GeluBackpropFactor>(A), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f64, shape);
vector<double> input{-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0};
copy_data(a, input);
auto result = backend->create_tensor(element::f64, shape);
std::transform(input.begin(), input.end(), input.begin(), [](double x) -> double {
return gelu_backprop_factor(x);
});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_TRUE(test::all_close_f(input, read_vector<double>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_gelu_f32)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape{8};
auto make_graph = [shape]() {
auto A = make_shared<op::Parameter>(element::f32, shape);
return make_shared<Function>(make_shared<op::Gelu>(A), ParameterVector{A});
};
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
vector<float> input{-4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f};
copy_data(a, input);
EXPECT_TRUE(autodiff_numeric_compare<float>(backend.get(), make_graph, {a}, .01f, .01f));
}
NGRAPH_TEST(${BACKEND_NAME}, backwards_gelu_f64)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}");
Shape shape{8};
auto make_graph = [shape]() {
auto A = make_shared<op::Parameter>(element::f64, shape);
return make_shared<Function>(make_shared<op::Gelu>(A), ParameterVector{A});
};
// Create some tensors for input/output
auto a = backend->create_tensor(element::f64, shape);
vector<double> input{-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0};
copy_data(a, input);
EXPECT_TRUE(autodiff_numeric_compare<double>(backend.get(), make_graph, {a}, .01f, .01f));
}
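The backwards_gelu tests above delegate to autodiff_numeric_compare, which checks the analytic gradient against a numeric one. A standalone sketch of that idea (an assumption about the general approach, not the utility's actual implementation):

```cpp
#include <cmath>
#include <cstdio>

// Scalar GELU, matching the forward definition in gelu.hpp.
static double gelu_ref(double x)
{
    return 0.5 * x * (1.0 + std::erf(x * std::sqrt(0.5)));
}

int main()
{
    const double h = 1e-6;
    const double pi = 4.0 * std::atan(1.0);
    for (double x : {-4.0, -1.0, 0.0, 1.0, 3.0})
    {
        // Central finite difference vs. the analytic backprop factor.
        double numeric = (gelu_ref(x + h) - gelu_ref(x - h)) / (2.0 * h);
        double analytic = 0.5 * (1.0 + std::erf(x * std::sqrt(0.5))) +
                          x * std::exp(-x * x / 2.0) / std::sqrt(2.0 * pi);
        std::printf("x=% .1f  numeric=% .6f  analytic=% .6f\n", x, numeric, analytic);
    }
    return 0;
}
```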