Unverified Commit 833a8f14 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Remove ParameterizedConstant (#309)

* remove use of ParameterizedConstant from unit test. Now using Constant instead. Constant is not a templated class.

* move ParameterizedTensorView to NGVM directory

* element_type cleanup
parent 8fc6473e
...@@ -34,7 +34,7 @@ using namespace ngraph; ...@@ -34,7 +34,7 @@ using namespace ngraph;
std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape) std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape)
{ {
std::shared_ptr<Node> zero = std::make_shared<op::Constant>(element_type, Shape{}, "0"); std::shared_ptr<Node> zero = op::Constant::create(element_type, Shape{}, {0.0});
if (shape.size() > 0) if (shape.size() > 0)
{ {
AxisSet axes; AxisSet axes;
......
...@@ -43,7 +43,7 @@ namespace ngraph ...@@ -43,7 +43,7 @@ namespace ngraph
auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes); auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes);
// TODO(mbrookhart): Use Sqrt instead of Power // TODO(mbrookhart): Use Sqrt instead of Power
auto half = std::make_shared<op::Constant>(et, x2sum->get_shape(), "0.5"); auto half = op::Constant::create(et, x2sum->get_shape(), {0.5});
return std::make_shared<op::Power>(x2sum, half); return std::make_shared<op::Power>(x2sum, half);
} }
...@@ -54,7 +54,7 @@ namespace ngraph ...@@ -54,7 +54,7 @@ namespace ngraph
auto N = get_num_elements(node->get_shape(), reduction_axes); auto N = get_num_elements(node->get_shape(), reduction_axes);
const auto& et = node->get_element_type(); const auto& et = node->get_element_type();
auto divisor = std::make_shared<op::Constant>(et, xsum->get_shape(), std::to_string(N)); auto divisor = op::Constant::create(et, xsum->get_shape(), {N});
return xsum / divisor; return xsum / divisor;
} }
...@@ -67,7 +67,7 @@ namespace ngraph ...@@ -67,7 +67,7 @@ namespace ngraph
const auto& et = node->get_element_type(); const auto& et = node->get_element_type();
// TODO(mbrookhart): Use Sqrt instead of Power // TODO(mbrookhart): Use Sqrt instead of Power
auto half = std::make_shared<op::Constant>(et, var->get_shape(), "0.5"); auto half = op::Constant::create(et, var->get_shape(), {0.5});
return std::make_shared<op::Power>(var, half); return std::make_shared<op::Power>(var, half);
} }
...@@ -88,15 +88,14 @@ namespace ngraph ...@@ -88,15 +88,14 @@ namespace ngraph
const auto& et = node->get_element_type(); const auto& et = node->get_element_type();
auto N = get_num_elements(node->get_shape(), reduction_axes); auto N = get_num_elements(node->get_shape(), reduction_axes);
auto Nconst = std::make_shared<op::Constant>(et, xsum->get_shape(), std::to_string(N)); auto Nconst = op::Constant::create(et, xsum->get_shape(), {N});
auto xbar2 = (xsum * xsum) / Nconst; auto xbar2 = (xsum * xsum) / Nconst;
auto diff = x2sum - xbar2; auto diff = x2sum - xbar2;
if (bessel_correction) if (bessel_correction)
{ {
auto N1const = auto N1const = op::Constant::create(et, xsum->get_shape(), {N - 1});
std::make_shared<op::Constant>(et, xsum->get_shape(), std::to_string(N - 1));
return diff / N1const; return diff / N1const;
} }
else else
......
...@@ -115,7 +115,6 @@ ...@@ -115,7 +115,6 @@
#include "ngraph/runtime/manager.hpp" #include "ngraph/runtime/manager.hpp"
#include "ngraph/runtime/ngvm/ngvm_backend.hpp" #include "ngraph/runtime/ngvm/ngvm_backend.hpp"
#include "ngraph/runtime/ngvm/ngvm_manager.hpp" #include "ngraph/runtime/ngvm/ngvm_manager.hpp"
#include "ngraph/runtime/parameterized_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tuple.hpp" #include "ngraph/runtime/tuple.hpp"
#include "ngraph/runtime/value.hpp" #include "ngraph/runtime/value.hpp"
......
...@@ -13,81 +13,115 @@ ...@@ -13,81 +13,115 @@
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ops/constant.hpp" #include "ngraph/ops/constant.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std;
namespace op::Constant::~Constant()
{ {
template <typename ET> if (m_data)
void check_value_strings(const std::vector<std::string>& value_strings)
{ {
auto result = parse_string<typename ET::type>(value_strings); free(m_data);
} }
} }
// std::vector<std::string> op::Constant::get_value_strings() const
// Utility macro for dispatching an element type-templated function at runtime.
//
// clang-format off
// Sorry, but you really don't want to see what clang-format does to this thing. :)
#define FUNCTION_ON_ELEMENT_TYPE(et, err_msg, f, ...) \
( \
((et) == element::Bool::element_type()) ? (f<element::Bool>(__VA_ARGS__)) : \
((et) == element::Float32::element_type()) ? (f<element::Float32>(__VA_ARGS__)) : \
((et) == element::Int8::element_type()) ? (f<element::Int8>(__VA_ARGS__)) : \
((et) == element::Int16::element_type()) ? (f<element::Int16>(__VA_ARGS__)) : \
((et) == element::Int32::element_type()) ? (f<element::Int32>(__VA_ARGS__)) : \
((et) == element::Int64::element_type()) ? (f<element::Int64>(__VA_ARGS__)) : \
((et) == element::UInt8::element_type()) ? (f<element::UInt8>(__VA_ARGS__)) : \
((et) == element::UInt16::element_type()) ? (f<element::UInt16>(__VA_ARGS__)) : \
((et) == element::UInt32::element_type()) ? (f<element::UInt32>(__VA_ARGS__)) : \
((et) == element::UInt64::element_type()) ? (f<element::UInt64>(__VA_ARGS__)) : \
(throw ngraph_error(err_msg)) \
)
// clang-format on
op::Constant::Constant(const element::Type& et,
const Shape& shape,
const std::vector<std::string>& value_strings)
: ConstantBase("Constant", std::make_shared<TensorViewType>(et, shape))
, m_value_strings(value_strings)
{ {
check_args(); vector<string> rc;
}
/// \brief Constructs a tensor constant with the same initialization value copied across the tensor.
///
/// \param et The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param value_string A literal for initializing each tensor constant.
op::Constant::Constant(const element::Type& et, const Shape& shape, const std::string& value_string)
: ConstantBase("Constant", std::make_shared<TensorViewType>(et, shape))
, m_value_strings(ngraph::shape_size(shape), value_string)
{
check_args();
}
void op::Constant::check_args() if (m_element_type == element::boolean)
{
// We check the number of value strings and
// also call check_value_strings just to make sure the result will be parseable at compile
// time. (It will throw an exception if not.)
auto tvt = std::dynamic_pointer_cast<const TensorViewType>(m_value_type);
if (nullptr == tvt)
{ {
throw ngraph_error("Constant does not have tensor view type"); for (int value : get_vector<char>())
{
rc.push_back(to_string(value));
}
} }
auto shape = tvt->get_shape(); else if (m_element_type == element::f32)
{
if (ngraph::shape_size(shape) != m_value_strings.size()) for (float value : get_vector<float>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::f64)
{
for (double value : get_vector<double>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::i8)
{
for (int value : get_vector<int8_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::i16)
{ {
throw ngraph_error("Constant does not have the expected number of literals"); for (int value : get_vector<int16_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::i32)
{
for (int32_t value : get_vector<int32_t>())
{
NGRAPH_INFO << value;
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::i64)
{
for (int64_t value : get_vector<int64_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::u8)
{
for (uint value : get_vector<uint8_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::u16)
{
for (uint value : get_vector<uint16_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::u32)
{
for (uint32_t value : get_vector<uint32_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::u64)
{
for (uint64_t value : get_vector<uint64_t>())
{
rc.push_back(to_string(value));
}
}
else
{
throw std::runtime_error("unsupported type");
} }
auto& et = tvt->get_element_type(); return rc;
}
FUNCTION_ON_ELEMENT_TYPE( template <>
et, "Constant has unhandled element type", check_value_strings, m_value_strings); void op::Constant::write_to_buffer<std::string>(const element::Type& target_type,
const Shape& target_shape,
const std::vector<std::string>& source,
void* target,
size_t target_element_count)
{
} }
This diff is collapsed.
...@@ -124,7 +124,7 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -124,7 +124,7 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
auto& y_element_type = y_input.get_element_type(); auto& y_element_type = y_input.get_element_type();
auto y_shape = y_input.get_shape(); auto y_shape = y_input.get_shape();
auto zeros_shaped_like_y = std::make_shared<op::Constant>(y_element_type, y_shape, "0"); auto zeros_shaped_like_y = op::Constant::create(y_element_type, y_shape, {0.0});
adjoints.add_delta(x, adjoints.add_delta(x,
std::make_shared<op::ReplaceSlice>( std::make_shared<op::ReplaceSlice>(
......
This diff is collapsed.
...@@ -69,14 +69,6 @@ namespace ngraph ...@@ -69,14 +69,6 @@ namespace ngraph
void EMITTER_DECL(EmitNotEqual); void EMITTER_DECL(EmitNotEqual);
void EMITTER_DECL(EmitSelect); void EMITTER_DECL(EmitSelect);
void EMITTER_DECL(EmitSubtract); void EMITTER_DECL(EmitSubtract);
void EMITTER_DECL(EmitParameterizedConstantBool);
void EMITTER_DECL(EmitParameterizedConstantFloat32);
void EMITTER_DECL(EmitParameterizedConstantInt8);
void EMITTER_DECL(EmitParameterizedConstantInt32);
void EMITTER_DECL(EmitParameterizedConstantInt64);
void EMITTER_DECL(EmitParameterizedConstantUInt8);
void EMITTER_DECL(EmitParameterizedConstantUInt32);
void EMITTER_DECL(EmitParameterizedConstantUInt64);
void EMITTER_DECL(EmitBroadcast); void EMITTER_DECL(EmitBroadcast);
void EMITTER_DECL(EmitConvert); void EMITTER_DECL(EmitConvert);
void EMITTER_DECL(EmitConstant); void EMITTER_DECL(EmitConstant);
......
...@@ -127,22 +127,6 @@ static const runtime::cpu::OpMap dispatcher{ ...@@ -127,22 +127,6 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Power), &runtime::cpu::CPU_Emitter::EmitPower}, {TI(ngraph::op::Power), &runtime::cpu::CPU_Emitter::EmitPower},
{TI(ngraph::op::Select), &runtime::cpu::CPU_Emitter::EmitSelect}, {TI(ngraph::op::Select), &runtime::cpu::CPU_Emitter::EmitSelect},
{TI(ngraph::op::Subtract), &runtime::cpu::CPU_Emitter::EmitSubtract}, {TI(ngraph::op::Subtract), &runtime::cpu::CPU_Emitter::EmitSubtract},
{TI(ngraph::op::ParameterizedConstant<ngraph::element::Bool>),
&runtime::cpu::CPU_Emitter::EmitParameterizedConstantBool},
{TI(ngraph::op::ParameterizedConstant<ngraph::element::Float32>),
&runtime::cpu::CPU_Emitter::EmitParameterizedConstantFloat32},
{TI(ngraph::op::ParameterizedConstant<ngraph::element::Int8>),
&runtime::cpu::CPU_Emitter::EmitParameterizedConstantInt8},
{TI(ngraph::op::ParameterizedConstant<ngraph::element::Int32>),
&runtime::cpu::CPU_Emitter::EmitParameterizedConstantInt32},
{TI(ngraph::op::ParameterizedConstant<ngraph::element::Int64>),
&runtime::cpu::CPU_Emitter::EmitParameterizedConstantInt64},
{TI(ngraph::op::ParameterizedConstant<ngraph::element::UInt8>),
&runtime::cpu::CPU_Emitter::EmitParameterizedConstantUInt8},
{TI(ngraph::op::ParameterizedConstant<ngraph::element::UInt32>),
&runtime::cpu::CPU_Emitter::EmitParameterizedConstantUInt32},
{TI(ngraph::op::ParameterizedConstant<ngraph::element::UInt64>),
&runtime::cpu::CPU_Emitter::EmitParameterizedConstantUInt64},
{TI(ngraph::op::Broadcast), &runtime::cpu::CPU_Emitter::EmitBroadcast}, {TI(ngraph::op::Broadcast), &runtime::cpu::CPU_Emitter::EmitBroadcast},
{TI(ngraph::op::Convert), &runtime::cpu::CPU_Emitter::EmitConvert}, {TI(ngraph::op::Convert), &runtime::cpu::CPU_Emitter::EmitConvert},
{TI(ngraph::op::Constant), &runtime::cpu::CPU_Emitter::EmitConstant}, {TI(ngraph::op::Constant), &runtime::cpu::CPU_Emitter::EmitConstant},
......
...@@ -259,9 +259,8 @@ private: ...@@ -259,9 +259,8 @@ private:
} }
else if (node_op == "Constant") else if (node_op == "Constant")
{ {
auto c = static_cast<const op::Constant*>(&node); const op::Constant* c = static_cast<const op::Constant*>(&node);
std::vector<T> input = ngraph::parse_string<T>(c->get_value_strings()); kernel::constant<T>(reinterpret_cast<const T*>(c->get_data_ptr()),
kernel::constant<T>(reinterpret_cast<T*>(input.data()),
reinterpret_cast<T*>(out[0]->get_data_ptr()), reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count()); out[0]->get_element_count());
} }
...@@ -426,98 +425,6 @@ private: ...@@ -426,98 +425,6 @@ private:
else if (node_op == "Parameter") else if (node_op == "Parameter")
{ {
} }
else if (node_op == "ParameterizedConstant")
{
// I would like to appologize for this...
element::Type type = element::from<T>();
const void* data;
if (type == element::boolean)
{
data = dynamic_cast<const op::ParameterizedConstant<element::Bool>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::f32)
{
data = dynamic_cast<const op::ParameterizedConstant<element::Float32>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::f64)
{
data = dynamic_cast<const op::ParameterizedConstant<element::Float64>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::i8)
{
data = dynamic_cast<const op::ParameterizedConstant<element::Int8>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::i16)
{
data = dynamic_cast<const op::ParameterizedConstant<element::Int16>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::i32)
{
data = dynamic_cast<const op::ParameterizedConstant<element::Int32>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::i64)
{
data = dynamic_cast<const op::ParameterizedConstant<element::Int64>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::u8)
{
data = dynamic_cast<const op::ParameterizedConstant<element::UInt8>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::u16)
{
data = dynamic_cast<const op::ParameterizedConstant<element::UInt16>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::u32)
{
data = dynamic_cast<const op::ParameterizedConstant<element::UInt32>*>(&node)
->get_value()
->get_vector()
.data();
}
else if (type == element::u64)
{
data = dynamic_cast<const op::ParameterizedConstant<element::UInt64>*>(&node)
->get_value()
->get_vector()
.data();
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op " << node.get_name();
throw std::runtime_error(ss.str());
}
kernel::copy<T>(reinterpret_cast<T*>(const_cast<void*>(data)),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Power") else if (node_op == "Power")
{ {
kernel::power<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()), kernel::power<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
......
...@@ -21,7 +21,7 @@ namespace ngraph ...@@ -21,7 +21,7 @@ namespace ngraph
namespace kernel namespace kernel
{ {
template <typename T> template <typename T>
void constant(T* arg0, T* out, size_t count) void constant(const T* arg0, T* out, size_t count)
{ {
for (size_t i = 0; i < count; i++) for (size_t i = 0; i < count; i++)
{ {
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/parameterized_tensor_view.hpp" #include "ngraph/runtime/ngvm/parameterized_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
......
...@@ -247,12 +247,6 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -247,12 +247,6 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
} \ } \
} }
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
REGISTER_TO_OP_MAP(op_class) \
{ \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \
}
#define M_REGISTER_NUMERIC_UNOP(T, instr_class) \ #define M_REGISTER_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0])); ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_NUMERIC_UNOP(op_class, instr_class) \ #define REGISTER_NUMERIC_UNOP(op_class, instr_class) \
...@@ -316,27 +310,6 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -316,27 +310,6 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
instr_class); \ instr_class); \
} }
template <typename ET>
std::vector<typename ET::type>
get_vector(std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>> ptv)
{
std::vector<typename ET::type> rc;
rc = ptv->get_vector();
return rc;
}
#define REGISTER_CONSTANT_INSTRUCTIONS(T) \
{ \
REGISTER_INSTRUCTION( \
op::ParameterizedConstant<T>, \
instruction::ConstantInstruction<T>, \
std::vector<T::type>{ \
get_vector<T>(dynamic_cast<const op::ParameterizedConstant<T>*>(n)->get_value())}, \
out[0]); \
}
#define PUSH_INSTRUCTION(T, instr, ...) \ #define PUSH_INSTRUCTION(T, instr, ...) \
{ \ { \
ef->get_instructions()->push_back(make_shared<instr<T>>(__VA_ARGS__)); \ ef->get_instructions()->push_back(make_shared<instr<T>>(__VA_ARGS__)); \
...@@ -408,18 +381,6 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -408,18 +381,6 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_POLYMORPHIC_TERNOP(op::Select, instruction::SelectInstruction); REGISTER_POLYMORPHIC_TERNOP(op::Select, instruction::SelectInstruction);
REGISTER_CONSTANT_INSTRUCTIONS(element::Bool);
REGISTER_CONSTANT_INSTRUCTIONS(element::Float32);
REGISTER_CONSTANT_INSTRUCTIONS(element::Float64);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int8);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int16);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int32);
REGISTER_CONSTANT_INSTRUCTIONS(element::Int64);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt8);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt16);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt32);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt64);
REGISTER_TO_OP_MAP(op::Broadcast) REGISTER_TO_OP_MAP(op::Broadcast)
{ {
auto broadcast = static_cast<const op::Broadcast*>(n); auto broadcast = static_cast<const op::Broadcast*>(n);
...@@ -658,9 +619,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -658,9 +619,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
std::shared_ptr<CallFrame> cf = \ std::shared_ptr<CallFrame> cf = \
std::dynamic_pointer_cast<CallFrame>(external->make_call_frame()); \ std::dynamic_pointer_cast<CallFrame>(external->make_call_frame()); \
\ \
auto tx = ngraph::runtime::make_tensor<ET>(Shape{}, {x}); \ auto tx = ngraph::runtime::ngvm::make_tensor<ET>(Shape{}, {x}); \
auto ty = ngraph::runtime::make_tensor<ET>(Shape{}, {y}); \ auto ty = ngraph::runtime::ngvm::make_tensor<ET>(Shape{}, {y}); \
auto tr = ngraph::runtime::make_tensor<ET>(Shape{}); \ auto tr = ngraph::runtime::ngvm::make_tensor<ET>(Shape{}); \
\ \
cf->call({tx, ty}, {tr}); \ cf->call({tx, ty}, {tr}); \
return tr->get_vector()[0]; \ return tr->get_vector()[0]; \
...@@ -943,7 +904,7 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame(Functio ...@@ -943,7 +904,7 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame(Functio
auto& et = tv->get_tensor_view_type()->get_element_type(); auto& et = tv->get_tensor_view_type()->get_element_type();
auto shape = tv->get_tensor_view_type()->get_shape(); auto shape = tv->get_tensor_view_type()->get_shape();
#define M(T) temps.push_back(ngraph::runtime::make_tensor<T>(shape)); #define M(T) temps.push_back(ngraph::runtime::ngvm::make_tensor<T>(shape));
DO_ON_ELEMENT_TYPE( DO_ON_ELEMENT_TYPE(
et, "Internal error: tried to create temporary for unhandled element type", M); et, "Internal error: tried to create temporary for unhandled element type", M);
#undef M #undef M
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "ngraph/runtime/ngvm/ngvm_backend.hpp" #include "ngraph/runtime/ngvm/ngvm_backend.hpp"
#include "ngraph/runtime/external_function.hpp" #include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/parameterized_tensor_view.hpp" #include "ngraph/runtime/ngvm/parameterized_tensor_view.hpp"
using namespace ngraph::runtime::ngvm; using namespace ngraph::runtime::ngvm;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp" #include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/parameterized_tensor_view.hpp"
#include "ngraph/runtime/ngvm/tensor_view_info.hpp" #include "ngraph/runtime/ngvm/tensor_view_info.hpp"
namespace ngraph namespace ngraph
...@@ -44,6 +45,24 @@ namespace ngraph ...@@ -44,6 +45,24 @@ namespace ngraph
.get_layout<ngraph::descriptor::layout::DenseTensorViewLayout>() .get_layout<ngraph::descriptor::layout::DenseTensorViewLayout>()
->get_size(); ->get_size();
} }
/// @brief Framework constructor of a tensor of a specific element type and shape.
template <typename ET>
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>
make_tensor(const Shape& shape)
{
return std::make_shared<runtime::ParameterizedTensorView<ET>>(shape);
}
/// @brief Framework constructor of a tensor of a specific element type and shape.
template <typename ET>
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>
make_tensor(const Shape& shape, const std::vector<typename ET::type>& data)
{
auto rc = std::make_shared<runtime::ParameterizedTensorView<ET>>(shape);
rc->write(data.data(), 0, data.size() * sizeof(typename ET::type));
return rc;
}
} }
} }
} }
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/runtime/parameterized_tensor_view.hpp"
#include "ngraph/runtime/tuple.hpp" #include "ngraph/runtime/tuple.hpp"
#include "ngraph/runtime/value.hpp" #include "ngraph/runtime/value.hpp"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
...@@ -26,24 +25,6 @@ namespace ngraph ...@@ -26,24 +25,6 @@ namespace ngraph
{ {
namespace runtime namespace runtime
{ {
/// @brief Framework constructor of a tensor of a specific element type and shape.
template <typename ET>
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>
make_tensor(const Shape& shape)
{
return std::make_shared<runtime::ParameterizedTensorView<ET>>(shape);
}
/// @brief Framework constructor of a tensor of a specific element type and shape.
template <typename ET>
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>
make_tensor(const Shape& shape, const std::vector<typename ET::type>& data)
{
auto rc = std::make_shared<runtime::ParameterizedTensorView<ET>>(shape);
rc->write(data.data(), 0, data.size() * sizeof(typename ET::type));
return rc;
}
/// @brief Framework constructor of a tuple from a sequence of values. /// @brief Framework constructor of a tuple from a sequence of values.
std::shared_ptr<ngraph::runtime::Tuple> std::shared_ptr<ngraph::runtime::Tuple>
make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements); make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements);
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <cmath> #include <cmath>
#include <iostream> #include <iostream>
#include "ngraph/log.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
...@@ -86,6 +87,73 @@ size_t element::Type::hash() const ...@@ -86,6 +87,73 @@ size_t element::Type::hash() const
return h1 ^ ((h2 ^ (h3 << 1)) << 1); return h1 ^ ((h2 ^ (h3 << 1)) << 1);
} }
namespace ngraph
{
namespace element
{
template <>
const Type& from<char>()
{
return boolean;
}
template <>
const Type& from<bool>()
{
return boolean;
}
template <>
const Type& from<float>()
{
return f32;
}
template <>
const Type& from<double>()
{
return f64;
}
template <>
const Type& from<int8_t>()
{
return i8;
}
template <>
const Type& from<int16_t>()
{
return i16;
}
template <>
const Type& from<int32_t>()
{
return i32;
}
template <>
const Type& from<int64_t>()
{
return i64;
}
template <>
const Type& from<uint8_t>()
{
return u8;
}
template <>
const Type& from<uint16_t>()
{
return u16;
}
template <>
const Type& from<uint32_t>()
{
return u32;
}
template <>
const Type& from<uint64_t>()
{
return u64;
}
}
}
std::ostream& element::operator<<(std::ostream& out, const element::Type& obj) std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
{ {
out << "element::Type(" << obj.m_bitwidth << ", " << obj.m_is_real << ", " << obj.m_is_signed out << "element::Type(" << obj.m_bitwidth << ", " << obj.m_is_real << ", " << obj.m_is_signed
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "ngraph/common.hpp" #include "ngraph/common.hpp"
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "ngraph/log.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -76,55 +77,32 @@ namespace ngraph ...@@ -76,55 +77,32 @@ namespace ngraph
template <typename T> template <typename T>
const Type& from() const Type& from()
{ {
if (typeid(T) == typeid(char) || typeid(T) == typeid(bool)) throw std::invalid_argument("Unknown type");
{
return boolean;
}
else if (typeid(T) == typeid(float))
{
return f32;
}
else if (typeid(T) == typeid(double))
{
return f64;
}
else if (typeid(T) == typeid(int8_t))
{
return i8;
}
else if (typeid(T) == typeid(int16_t))
{
return i16;
}
else if (typeid(T) == typeid(int32_t))
{
return i32;
}
else if (typeid(T) == typeid(int64_t))
{
return i64;
}
else if (typeid(T) == typeid(uint8_t))
{
return u8;
}
else if (typeid(T) == typeid(uint16_t))
{
return u16;
}
else if (typeid(T) == typeid(uint32_t))
{
return u32;
}
else if (typeid(T) == typeid(uint64_t))
{
return u64;
}
else
{
throw std::invalid_argument("Unknown type");
}
} }
template <>
const Type& from<char>();
template <>
const Type& from<bool>();
template <>
const Type& from<float>();
template <>
const Type& from<double>();
template <>
const Type& from<int8_t>();
template <>
const Type& from<int16_t>();
template <>
const Type& from<int32_t>();
template <>
const Type& from<int64_t>();
template <>
const Type& from<uint8_t>();
template <>
const Type& from<uint16_t>();
template <>
const Type& from<uint32_t>();
template <>
const Type& from<uint64_t>();
std::ostream& operator<<(std::ostream& out, const ngraph::element::Type& obj); std::ostream& operator<<(std::ostream& out, const ngraph::element::Type& obj);
......
This diff is collapsed.
...@@ -76,19 +76,19 @@ TEST(build_graph, literal) ...@@ -76,19 +76,19 @@ TEST(build_graph, literal)
{ {
// float scalar from a float // float scalar from a float
//auto float0 = FloatConstant::make(3.0); //auto float0 = FloatConstant::make(3.0);
auto float_t = ngraph::runtime::make_tensor<element::Float32>(Shape{}, {3.0}); vector<float> float_t{3.0};
auto float0 = make_shared<op::Float32Constant>(Shape{}, float_t); auto float0 = make_shared<op::Constant>(element::f32, Shape{}, float_t);
auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(float0->get_value()->get_vector(), std::vector<float>{3.0}); ASSERT_EQ(float0->get_vector<float>(), std::vector<float>{3.0});
ASSERT_EQ(*float0->get_value_type(), *float_scalar_type); ASSERT_EQ(*float0->get_value_type(), *float_scalar_type);
auto d = make_shared<op::Dot>(float0, float0); auto d = make_shared<op::Dot>(float0, float0);
ASSERT_EQ(d->get_input_ops().at(0), float0); ASSERT_EQ(d->get_input_ops().at(0), float0);
ASSERT_EQ(d->get_input_ops().at(1), float0); ASSERT_EQ(d->get_input_ops().at(1), float0);
auto int32_t = ngraph::runtime::make_tensor<element::Int32>(Shape{}, {3}); vector<int32_t> int32{3};
auto int32_0 = make_shared<op::Int32Constant>(Shape{}, int32_t); auto int32_0 = make_shared<op::Constant>(element::i32, Shape{}, int32);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->get_value()->get_vector(), std::vector<int>{3}); ASSERT_EQ(int32_0->get_vector<int32_t>(), std::vector<int>{3});
ASSERT_EQ(*int32_0->get_value_type(), *int32_scalar_type); ASSERT_EQ(*int32_0->get_value_type(), *int32_scalar_type);
ASSERT_NE(*int32_0->get_value_type(), *float_scalar_type); ASSERT_NE(*int32_0->get_value_type(), *float_scalar_type);
} }
...@@ -97,19 +97,19 @@ TEST(build_graph, tensor) ...@@ -97,19 +97,19 @@ TEST(build_graph, tensor)
{ {
// float scalar from a float // float scalar from a float
//auto float0 = FloatConstant::make(3.0); //auto float0 = FloatConstant::make(3.0);
auto float_t = ngraph::runtime::make_tensor<element::Float32>(Shape{2, 3}); Shape shape{2, 3};
auto float0 = make_shared<op::Float32Constant>(Shape{2, 3}, float_t); vector<float> float_t(shape_size(shape), 0);
auto float_tensor_type = auto float0 = make_shared<op::Constant>(element::f32, shape, float_t);
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3}); auto float_tensor_type = make_shared<TensorViewType>(element::Float32::element_type(), shape);
ASSERT_EQ(*float0->get_value_type(), *float_tensor_type); ASSERT_EQ(*float0->get_value_type(), *float_tensor_type);
auto d = make_shared<op::Add>(float0, float0); auto d = make_shared<op::Add>(float0, float0);
ASSERT_EQ(d->get_input_ops().at(0), float0); ASSERT_EQ(d->get_input_ops().at(0), float0);
ASSERT_EQ(d->get_input_ops().at(1), float0); ASSERT_EQ(d->get_input_ops().at(1), float0);
auto int32_t = ngraph::runtime::make_tensor<element::Int32>(Shape{3, 5}); Shape ishape{3, 5};
auto int32_0 = make_shared<op::Int32Constant>(Shape{3, 5}, int32_t); vector<int32_t> idata(shape_size(ishape), 0);
auto int32_tensor_type = auto int32_0 = make_shared<op::Constant>(element::i32, ishape, idata);
make_shared<TensorViewType>(element::Int32::element_type(), Shape{3, 5}); auto int32_tensor_type = make_shared<TensorViewType>(element::Int32::element_type(), ishape);
ASSERT_EQ(*int32_0->get_value_type(), *int32_tensor_type); ASSERT_EQ(*int32_0->get_value_type(), *int32_tensor_type);
ASSERT_NE(*int32_0->get_value_type(), *float_tensor_type); ASSERT_NE(*int32_0->get_value_type(), *float_tensor_type);
} }
......
...@@ -130,40 +130,18 @@ TEST(copy, concat) ...@@ -130,40 +130,18 @@ TEST(copy, concat)
ASSERT_TRUE(node_cast->get_concatenation_axis() == axis); ASSERT_TRUE(node_cast->get_concatenation_axis() == axis);
} }
TEST(copy, parameterized_constant)
{
auto manager = runtime::Manager::get("NGVM");
auto backend = manager->allocate_backend();
// Create some tensors for input/output
auto c = backend->make_primary_tensor_view(element::Float32::element_type(), Shape{2, 2});
copy_data(c, test::NDArray<float, 2>({{1, 2}, {3, 4}}).get_vector());
Shape shape{2, 2};
auto cptv = dynamic_pointer_cast<ngraph::runtime::ParameterizedTensorView<element::Float32>>(c);
ASSERT_NE(cptv, nullptr);
auto node = make_shared<op::ParameterizedConstant<element::Float32>>(shape, cptv);
auto new_node = node->copy_with_new_args(Nodes{});
auto node_cast = dynamic_pointer_cast<op::ParameterizedConstant<element::Float32>>(new_node);
ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(Nodes{} == new_node->get_input_ops());
ASSERT_TRUE(node_cast->get_value() == c);
ASSERT_TRUE(node_cast->get_shape() == shape);
}
TEST(copy, constant) TEST(copy, constant)
{ {
Shape shape{}; Shape shape{};
vector<string> c{"2.4"}; vector<float> c{2.4f};
auto& et = element::Float32::element_type(); auto& et = element::Float32::element_type();
auto node = make_shared<op::Constant>(et, shape, c); auto node = op::Constant::create(et, shape, c);
auto new_node = node->copy_with_new_args(Nodes{}); auto new_node = node->copy_with_new_args(Nodes{});
auto node_cast = dynamic_pointer_cast<op::Constant>(new_node); auto node_cast = dynamic_pointer_cast<op::Constant>(new_node);
ASSERT_NE(node_cast, nullptr); ASSERT_NE(node_cast, nullptr);
ASSERT_TRUE(nullptr != new_node); ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(Nodes{} == new_node->get_input_ops()); ASSERT_TRUE(Nodes{} == new_node->get_input_ops());
ASSERT_TRUE(node_cast->get_value_strings() == c); ASSERT_TRUE(node_cast->get_vector<float>() == c);
ASSERT_TRUE(node_cast->get_shape() == shape); ASSERT_TRUE(node_cast->get_shape() == shape);
ASSERT_TRUE(node_cast->get_element_type() == et); ASSERT_TRUE(node_cast->get_element_type() == et);
} }
......
...@@ -74,7 +74,7 @@ TEST(element_type, size) ...@@ -74,7 +74,7 @@ TEST(element_type, size)
EXPECT_EQ(1, t1.size()); EXPECT_EQ(1, t1.size());
} }
{ {
element::Type t1{2, false, false, ""}; element::Type t1{8, false, false, ""};
EXPECT_EQ(1, t1.size()); EXPECT_EQ(1, t1.size());
} }
{ {
......
...@@ -37,7 +37,7 @@ namespace ng = ngraph; ...@@ -37,7 +37,7 @@ namespace ng = ngraph;
TEST(liveness, constant) TEST(liveness, constant)
{ {
auto shape = Shape{1}; auto shape = Shape{1};
auto c = make_shared<op::Constant>(element::i32, shape, "5"); auto c = op::Constant::create(element::i32, shape, {5});
auto rt = make_shared<TensorViewType>(element::i32, shape); auto rt = make_shared<TensorViewType>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::Negative>(c), rt, op::Parameters{}); auto f = make_shared<Function>(make_shared<op::Negative>(c), rt, op::Parameters{});
......
...@@ -229,7 +229,7 @@ TEST(memory_layout, constant) ...@@ -229,7 +229,7 @@ TEST(memory_layout, constant)
pass_manager.register_pass<pass::DumpSorted>(dump_file); pass_manager.register_pass<pass::DumpSorted>(dump_file);
auto shape = Shape{1}; auto shape = Shape{1};
auto c = make_shared<op::Constant>(element::i32, shape, "5"); auto c = op::Constant::create(element::i32, shape, {5});
auto rt = make_shared<TensorViewType>(element::i32, shape); auto rt = make_shared<TensorViewType>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::Negative>(c), rt, op::Parameters{}); auto f = make_shared<Function>(make_shared<op::Negative>(c), rt, op::Parameters{});
......
...@@ -72,8 +72,7 @@ public: ...@@ -72,8 +72,7 @@ public:
static std::shared_ptr<Node> construct_constant_node(int n) static std::shared_ptr<Node> construct_constant_node(int n)
{ {
auto int_t = ngraph::runtime::make_tensor<element::Int32>(Shape{1}, {n}); return op::Constant::create(element::i32, Shape{1}, {n});
return make_shared<op::Int32Constant>(Shape{1}, int_t);
} }
class TestGraphRewrite : public ngraph::pass::GraphRewrite class TestGraphRewrite : public ngraph::pass::GraphRewrite
...@@ -94,19 +93,19 @@ public: ...@@ -94,19 +93,19 @@ public:
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_input_ops().at(0) == pattern_map[pattern]; size_t const_node_index = m.match_root()->get_input_ops().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>( auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_input_ops().at(const_node_index)); m.match_root()->get_input_ops().at(const_node_index));
auto second_node = m.match_root()->get_input_ops().at(const_node_index); auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node; NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , " NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern_map[pattern]; << pattern_map[pattern];
assert(const_node); ASSERT_TRUE(const_node);
auto pattern_value_type = auto pattern_value_type =
dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type()); dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type());
auto const_node_value_type = auto const_node_value_type =
dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type()); dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type());
assert(pattern_value_type && const_node); ASSERT_TRUE(pattern_value_type && const_node);
if (pattern_value_type->get_element_type() != if (pattern_value_type->get_element_type() !=
const_node_value_type->get_element_type() || const_node_value_type->get_element_type() ||
...@@ -116,7 +115,7 @@ public: ...@@ -116,7 +115,7 @@ public:
return; return;
} }
auto const_values = const_node->get_value()->get_vector(); auto const_values = const_node->get_vector<int32_t>();
bool all_ones = bool all_ones =
std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; }); std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });
...@@ -149,19 +148,19 @@ public: ...@@ -149,19 +148,19 @@ public:
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_input_ops().at(0) == pattern_map[pattern]; size_t const_node_index = m.match_root()->get_input_ops().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>( auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_input_ops().at(const_node_index)); m.match_root()->get_input_ops().at(const_node_index));
auto second_node = m.match_root()->get_input_ops().at(const_node_index); auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node; NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , " NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern_map[pattern]; << pattern_map[pattern];
assert(const_node); ASSERT_NE(nullptr, const_node);
auto pattern_value_type = auto pattern_value_type =
dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type()); dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type());
auto const_node_value_type = auto const_node_value_type =
dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type()); dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type());
assert(pattern_value_type && const_node); ASSERT_TRUE(pattern_value_type && const_node);
if (pattern_value_type->get_element_type() != if (pattern_value_type->get_element_type() !=
const_node_value_type->get_element_type() || const_node_value_type->get_element_type() ||
...@@ -171,7 +170,7 @@ public: ...@@ -171,7 +170,7 @@ public:
return; return;
} }
auto const_values = const_node->get_value()->get_vector(); auto const_values = const_node->get_vector<int>();
bool all_zeros = bool all_zeros =
std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; }); std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });
...@@ -342,10 +341,9 @@ TEST(pattern, matcher) ...@@ -342,10 +341,9 @@ TEST(pattern, matcher)
auto iconst1_1 = construct_constant_node(1); auto iconst1_1 = construct_constant_node(1);
ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst
ASSERT_EQ(n.get_pattern_map()[pattern], a); ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto fconst1_0 = auto fconst1_0 = op::Constant::create(element::Float32::element_type(), Shape{1}, {1});
make_shared<op::Constant>(element::Float32::element_type(), Shape{1}, std::to_string(1));
auto patternf = pattern::op::Label::make_from_node(fconst1_0); auto patternf = pattern::op::Label::make_from_node(fconst1_0);
ASSERT_FALSE(n.match(patternf * fconst1_0, a * iconst1_1)); //different iconst ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); //different iconst
//Subgraph labels //Subgraph labels
auto add = a + b; auto add = a + b;
......
...@@ -1367,78 +1367,35 @@ TEST(type_prop, slice_deduce_matrix_upper_extra) ...@@ -1367,78 +1367,35 @@ TEST(type_prop, slice_deduce_matrix_upper_extra)
TEST(type_prop, scalar_constant_deduce_float32) TEST(type_prop, scalar_constant_deduce_float32)
{ {
auto c = make_shared<op::Constant>(element::Float32::element_type(), Shape{}, "208"); auto c = op::Constant::create(element::Float32::element_type(), Shape{}, {208});
ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{})); ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
} }
TEST(type_prop, scalar_constant_deduce_bool) TEST(type_prop, scalar_constant_deduce_bool)
{ {
auto c = make_shared<op::Constant>(element::Bool::element_type(), Shape{}, "1"); auto c = op::Constant::create(element::Bool::element_type(), Shape{}, {1});
ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Bool::element_type(), Shape{})); ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Bool::element_type(), Shape{}));
} }
TEST(type_prop, tensor_constant_deduce_float32) TEST(type_prop, tensor_constant_deduce_float32)
{ {
auto c = make_shared<op::Constant>(element::Float32::element_type(), auto c =
Shape{2, 2}, op::Constant::create(element::Float32::element_type(), Shape{2, 2}, {208, 208, 208, 208});
std::vector<std::string>{"208", "208", "208", "208"});
ASSERT_EQ(*(c->get_value_type()), ASSERT_EQ(*(c->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{2, 2})); TensorViewType(element::Float32::element_type(), Shape{2, 2}));
} }
TEST(type_prop, tensor_constant_deduce_bool) TEST(type_prop, tensor_constant_deduce_bool)
{ {
auto c = make_shared<op::Constant>( auto c = op::Constant::create(element::Bool::element_type(), Shape{2, 2}, {1, 1, 1, 1});
element::Bool::element_type(), Shape{2, 2}, std::vector<std::string>{"1", "1", "1", "1"});
ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Bool::element_type(), Shape{2, 2})); ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Bool::element_type(), Shape{2, 2}));
} }
TEST(type_prop, tensor_constant_bad_parse)
{
try
{
auto c = make_shared<op::Constant>(element::Bool::element_type(),
Shape{2, 2},
std::vector<std::string>{"1", "grunk", "1", "1"});
// Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected";
}
catch (const runtime_error& error)
{
EXPECT_TRUE(string(error.what()).find("Could not parse literal") != string::npos);
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, tensor_constant_bad_parse_float_for_int)
{
try
{
auto c = make_shared<op::Constant>(element::Int32::element_type(),
Shape{2, 2},
std::vector<std::string>{"1", "2.7", "1", "1"});
// Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected";
}
catch (const runtime_error& error)
{
EXPECT_TRUE(string(error.what()).find("Could not parse literal") != string::npos);
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, tensor_constant_bad_count) TEST(type_prop, tensor_constant_bad_count)
{ {
try try
{ {
auto c = make_shared<op::Constant>( auto c = op::Constant::create(element::Bool::element_type(), Shape{2, 2}, {1, 1, 1});
element::Bool::element_type(), Shape{2, 2}, std::vector<std::string>{"1", "1", "1"});
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Incorrect number of literals not detected"; FAIL() << "Incorrect number of literals not detected";
} }
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/runtime/parameterized_tensor_view.hpp"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
namespace ngraph namespace ngraph
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include <memory> #include <memory>
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/runtime/parameterized_tensor_view.hpp"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "ngraph/runtime/parameterized_tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
class Node; class Node;
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/manager.hpp" #include "ngraph/runtime/manager.hpp"
#include "ngraph/runtime/parameterized_tensor_view.hpp"
#include "ngraph/runtime/tuple.hpp" #include "ngraph/runtime/tuple.hpp"
#include "ngraph/runtime/value.hpp" #include "ngraph/runtime/value.hpp"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment