Commit 36f5ab7a authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Add bfloat16 to make_constant builder (#3741)

parent b7d8c78b
...@@ -31,67 +31,75 @@ namespace ngraph ...@@ -31,67 +31,75 @@ namespace ngraph
{ {
std::shared_ptr<Node> val = nullptr; std::shared_ptr<Node> val = nullptr;
if (type == element::f32) #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (type)
{ {
case element::Type_t::f32:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)}); type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)});
} break;
else if (type == element::f64) case element::Type_t::f64:
{
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<double>{static_cast<double>(num)}); type, ngraph::Shape{}, std::vector<double>{static_cast<double>(num)});
} break;
else if (type == element::f16) case element::Type_t::f16:
{
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, type,
ngraph::Shape{}, ngraph::Shape{},
std::vector<ngraph::float16>{ngraph::float16(static_cast<float>(num))}); std::vector<ngraph::float16>{ngraph::float16(static_cast<float>(num))});
} break;
else if (type == element::i64) case element::Type_t::bf16:
{ val = std::make_shared<ngraph::op::Constant>(
type,
ngraph::Shape{},
std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))});
break;
case element::Type_t::i64:
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int64_t>{static_cast<int64_t>(num)}); type, ngraph::Shape{}, std::vector<int64_t>{static_cast<int64_t>(num)});
} break;
else if (type == element::i32) case element::Type_t::i32:
{
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int32_t>{static_cast<int32_t>(num)}); type, ngraph::Shape{}, std::vector<int32_t>{static_cast<int32_t>(num)});
} break;
else if (type == element::i16) case element::Type_t::i16:
{
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int16_t>{static_cast<int16_t>(num)}); type, ngraph::Shape{}, std::vector<int16_t>{static_cast<int16_t>(num)});
} break;
else if (type == element::i8) case element::Type_t::i8:
{
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int8_t>{static_cast<int8_t>(num)}); type, ngraph::Shape{}, std::vector<int8_t>{static_cast<int8_t>(num)});
} break;
else if (type == element::u64) case element::Type_t::u64:
{
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint64_t>{static_cast<uint64_t>(num)}); type, ngraph::Shape{}, std::vector<uint64_t>{static_cast<uint64_t>(num)});
} break;
else if (type == element::u32) case element::Type_t::u32:
{
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint32_t>{static_cast<uint32_t>(num)}); type, ngraph::Shape{}, std::vector<uint32_t>{static_cast<uint32_t>(num)});
} break;
else if (type == element::u16) case element::Type_t::u16:
{
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint16_t>{static_cast<uint16_t>(num)}); type, ngraph::Shape{}, std::vector<uint16_t>{static_cast<uint16_t>(num)});
} break;
else if (type == element::u8) case element::Type_t::u8:
{
val = std::make_shared<ngraph::op::Constant>( val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint8_t>{static_cast<uint8_t>(num)}); type, ngraph::Shape{}, std::vector<uint8_t>{static_cast<uint8_t>(num)});
break;
case element::Type_t::dynamic:
throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
case element::Type_t::boolean:
throw ngraph_error("make_constant: Unsupported element type 'boolean'");
case element::Type_t::undefined:
throw ngraph_error("make_constant: Unsupported element type 'undefined'");
} }
else #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
{ #pragma GCC diagnostic pop
throw ngraph_error("make_constant: Unsupported element type"); #endif
}
if (shape.size() > 0) if (shape.size() > 0)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment