Commit 36f5ab7a authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Add bfloat16 to make_constant builder (#3741)

parent b7d8c78b
......@@ -31,67 +31,75 @@ namespace ngraph
{
std::shared_ptr<Node> val = nullptr;
if (type == element::f32)
{
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (type)
{
case element::Type_t::f32:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<float>{static_cast<float>(num)});
}
else if (type == element::f64)
{
break;
case element::Type_t::f64:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<double>{static_cast<double>(num)});
}
else if (type == element::f16)
{
break;
case element::Type_t::f16:
val = std::make_shared<ngraph::op::Constant>(
type,
ngraph::Shape{},
std::vector<ngraph::float16>{ngraph::float16(static_cast<float>(num))});
}
else if (type == element::i64)
{
break;
case element::Type_t::bf16:
val = std::make_shared<ngraph::op::Constant>(
type,
ngraph::Shape{},
std::vector<ngraph::bfloat16>{ngraph::bfloat16(static_cast<float>(num))});
break;
case element::Type_t::i64:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int64_t>{static_cast<int64_t>(num)});
}
else if (type == element::i32)
{
break;
case element::Type_t::i32:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int32_t>{static_cast<int32_t>(num)});
}
else if (type == element::i16)
{
break;
case element::Type_t::i16:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int16_t>{static_cast<int16_t>(num)});
}
else if (type == element::i8)
{
break;
case element::Type_t::i8:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<int8_t>{static_cast<int8_t>(num)});
}
else if (type == element::u64)
{
break;
case element::Type_t::u64:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint64_t>{static_cast<uint64_t>(num)});
}
else if (type == element::u32)
{
break;
case element::Type_t::u32:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint32_t>{static_cast<uint32_t>(num)});
}
else if (type == element::u16)
{
break;
case element::Type_t::u16:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint16_t>{static_cast<uint16_t>(num)});
}
else if (type == element::u8)
{
break;
case element::Type_t::u8:
val = std::make_shared<ngraph::op::Constant>(
type, ngraph::Shape{}, std::vector<uint8_t>{static_cast<uint8_t>(num)});
}
else
{
throw ngraph_error("make_constant: Unsupported element type");
}
break;
case element::Type_t::dynamic:
throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
case element::Type_t::boolean:
throw ngraph_error("make_constant: Unsupported element type 'boolean'");
case element::Type_t::undefined:
throw ngraph_error("make_constant: Unsupported element type 'undefined'");
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
if (shape.size() > 0)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment