Commit 8ea33de1 authored by yimeisun123's avatar yimeisun123 Committed by Scott Cyphers

Add bfloat16 data type in ngraph (#1861)

* Add bfloat16 data type in ngraph

* Update on bfloat16 files
- remove uint16_t related functions
- adding/removing const modifier in functions
- default assignment operator
- style conformance

* Add bf16 element type handling in Constant OP

* Update bfloat16 data type implementation
- support nan and infinity cases
- add rounding option when creating bfloat16 from float

* Update the comment for copyright info
parent e92ee04c
......@@ -154,6 +154,7 @@ set (SRC
shape.cpp
shape_util.cpp
strides.cpp
type/bfloat16.cpp
type/element_type.cpp
util.cpp
validation_util.cpp
......
......@@ -71,6 +71,15 @@ vector<string> op::Constant::get_value_strings() const
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::bf16)
{
float temp = 0;
for (bfloat16 value : get_vector<bfloat16>())
{
temp = static_cast<float>(value);
rc.push_back(to_cpp_string(temp));
}
}
else if (m_element_type == element::f32)
{
for (float value : get_vector<float>())
......
......@@ -21,6 +21,7 @@
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/type/bfloat16.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
......@@ -210,6 +211,10 @@ namespace ngraph
{
write_buffer<char, T>(target, source, target_element_count);
}
else if (target_type == element::bf16)
{
write_buffer<bfloat16, T>(target, source, target_element_count);
}
else if (target_type == element::f32)
{
write_buffer<float, T>(target, source, target_element_count);
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// Contains logic derived from TensorFlow’s bfloat16 implementation
// https://github.com/tensorflow/tensorflow/blob/d354efc/tensorflow/core/lib/bfloat16/bfloat16.h
// Copyright notice from original source file is as follows.
//*******************************************************************************
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//==============================================================================
#include <cmath>
#include <iostream>
#include "ngraph/type/bfloat16.hpp"
using namespace std;
using namespace ngraph;
// A value represents NaN in bfloat16
static const uint16_t BF16_NAN_VALUE = 0x7FC0;
bool float_isnan(const float& x)
{
return std::isnan(x);
}
std::vector<float> bfloat16::to_float_vector(const std::vector<bfloat16>& v_bf16)
{
std::vector<float> v_f32(v_bf16.begin(), v_bf16.end());
return v_f32;
}
std::vector<bfloat16> bfloat16::from_float_vector(const std::vector<float>& v_f32)
{
std::vector<bfloat16> v_bf16(v_f32.size());
for (float a : v_f32)
{
v_bf16.push_back(static_cast<bfloat16>(a));
}
return v_bf16;
}
bfloat16::bfloat16(float value, bool rounding)
{
if (float_isnan(value))
{
m_value = BF16_NAN_VALUE;
}
else if (!rounding)
{
// Truncate off 16 LSB, no rounding
// Treat system as little endian (Intel x86 family)
uint16_t* u16_ptr = reinterpret_cast<uint16_t*>(&value);
m_value = u16_ptr[1];
}
else
{
// Rounding with round-nearest-to-even to create bfloat16
// from float. Refer to TF implementation explanation:
// https://github.com/tensorflow/tensorflow/blob/d354efc/tensorflow/core/lib/bfloat16/bfloat16.h#L199
uint32_t* u32_ptr = reinterpret_cast<uint32_t*>(&value);
uint32_t u32_value = *u32_ptr;
uint32_t lsb = (u32_value >> 16) & 1;
uint32_t rounding_bias = 0x7fff + lsb;
u32_value += rounding_bias;
m_value = static_cast<uint16_t>(u32_value >> 16);
}
}
std::string bfloat16::to_string() const
{
return std::to_string(static_cast<float>(*this));
}
size_t bfloat16::size() const
{
return sizeof(m_value);
}
bool bfloat16::operator==(const bfloat16& other) const
{
return (static_cast<float>(*this) == static_cast<float>(other));
}
bool bfloat16::operator<(const bfloat16& other) const
{
return (static_cast<float>(*this) < static_cast<float>(other));
}
bool bfloat16::operator<=(const bfloat16& other) const
{
return (static_cast<float>(*this) <= static_cast<float>(other));
}
bool bfloat16::operator>(const bfloat16& other) const
{
return (static_cast<float>(*this) > static_cast<float>(other));
}
bool bfloat16::operator>=(const bfloat16& other) const
{
return (static_cast<float>(*this) >= static_cast<float>(other));
}
bfloat16::operator float() const
{
float result = 0;
uint16_t* u16_ptr = reinterpret_cast<uint16_t*>(&result);
// Treat the system as little endian (Intel x86 family)
u16_ptr[1] = m_value;
return result;
}
std::ostream& operator<<(std::ostream& out, const bfloat16& obj)
{
return (out << static_cast<float>(obj));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
//================================================================================================
// bfloat16 type
//================================================================================================
#pragma once
#include <iostream>
#include <memory>
#include <string>
#include <vector>
namespace ngraph
{
class bfloat16
{
public:
bfloat16() {}
bfloat16(float value, bool rounding = false);
bfloat16& operator=(const bfloat16&) = default;
virtual ~bfloat16() {}
std::string to_string() const;
size_t size() const;
bool operator==(const bfloat16& other) const;
bool operator!=(const bfloat16& other) const { return !(*this == other); }
bool operator<(const bfloat16& other) const;
bool operator<=(const bfloat16& other) const;
bool operator>(const bfloat16& other) const;
bool operator>=(const bfloat16& other) const;
operator float() const;
static std::vector<float> to_float_vector(const std::vector<bfloat16>&);
static std::vector<bfloat16> from_float_vector(const std::vector<float>&);
friend std::ostream& operator<<(std::ostream&, const bfloat16&);
private:
uint16_t m_value{0};
};
}
......@@ -23,6 +23,7 @@ using namespace ngraph;
const element::Type element::dynamic(0, false, false, false, "dynamic");
const element::Type element::boolean(8, false, true, false, "char");
const element::Type element::bf16(16, true, true, false, "bfloat16");
const element::Type element::f32(32, true, true, false, "float");
const element::Type element::f64(64, true, true, false, "double");
const element::Type element::i8(8, false, true, true, "int8_t");
......@@ -37,6 +38,7 @@ const element::Type element::u64(64, false, false, false, "uint64_t");
std::vector<const element::Type*> element::Type::get_known_types()
{
std::vector<const element::Type*> rc = {&element::boolean,
&element::bf16,
&element::f32,
&element::f64,
&element::i8,
......@@ -175,6 +177,11 @@ namespace ngraph
{
return u64;
}
template <>
const Type& from<ngraph::bfloat16>()
{
return bf16;
}
}
}
......
......@@ -26,6 +26,7 @@
#include <vector>
#include "ngraph/except.hpp"
#include "ngraph/type/bfloat16.hpp"
namespace ngraph
{
......@@ -35,6 +36,7 @@ namespace ngraph
extern const Type dynamic;
extern const Type boolean;
extern const Type bf16;
extern const Type f32;
extern const Type f64;
extern const Type i8;
......@@ -132,6 +134,8 @@ namespace ngraph
const Type& from<uint32_t>();
template <>
const Type& from<uint64_t>();
template <>
const Type& from<ngraph::bfloat16>();
std::ostream& operator<<(std::ostream& out, const ngraph::element::Type& obj);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment