Commit 394b74fa authored by Artur Wojcik's avatar Artur Wojcik Committed by Michał Karzyński

[ONNX] re-enable unit tests for CentOS (#1517)

parent c4970542
...@@ -23,10 +23,10 @@ namespace ngraph ...@@ -23,10 +23,10 @@ namespace ngraph
{ {
std::vector<Graph> Attribute::get_graph_array() const std::vector<Graph> Attribute::get_graph_array() const
{ {
return {std::begin(m_attribute_proto.graphs()), std::end(m_attribute_proto.graphs())}; return {std::begin(m_attribute_proto->graphs()), std::end(m_attribute_proto->graphs())};
} }
Graph Attribute::get_graph() const { return Graph{m_attribute_proto.g()}; } Graph Attribute::get_graph() const { return Graph{m_attribute_proto->g()}; }
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -38,8 +38,8 @@ namespace ngraph ...@@ -38,8 +38,8 @@ namespace ngraph
{ {
struct Attribute : ngraph_error struct Attribute : ngraph_error
{ {
Attribute(std::string msg, onnx::AttributeProto_AttributeType type) Attribute(const std::string& msg, onnx::AttributeProto_AttributeType type)
: ngraph_error{std::move(msg) + ": " + : ngraph_error{msg + ": " +
onnx::AttributeProto_AttributeType_Name(type)} onnx::AttributeProto_AttributeType_Name(type)}
{ {
} }
...@@ -246,7 +246,7 @@ namespace ngraph ...@@ -246,7 +246,7 @@ namespace ngraph
Attribute() = delete; Attribute() = delete;
explicit Attribute(const onnx::AttributeProto& attribute_proto) explicit Attribute(const onnx::AttributeProto& attribute_proto)
: m_attribute_proto{attribute_proto} : m_attribute_proto{&attribute_proto}
{ {
} }
...@@ -256,8 +256,8 @@ namespace ngraph ...@@ -256,8 +256,8 @@ namespace ngraph
Attribute& operator=(Attribute&&) noexcept = delete; Attribute& operator=(Attribute&&) noexcept = delete;
Attribute& operator=(const Attribute&) = delete; Attribute& operator=(const Attribute&) = delete;
const std::string& get_name() const { return m_attribute_proto.name(); } const std::string& get_name() const { return m_attribute_proto->name(); }
Type get_type() const { return static_cast<Type>(m_attribute_proto.type()); } Type get_type() const { return static_cast<Type>(m_attribute_proto->type()); }
bool is_tensor() const { return get_type() == Type::tensor; } bool is_tensor() const { return get_type() == Type::tensor; }
bool is_tensor_array() const { return get_type() == Type::tensor_array; } bool is_tensor_array() const { return get_type() == Type::tensor_array; }
bool is_float() const { return get_type() == Type::float_point; } bool is_float() const { return get_type() == Type::float_point; }
...@@ -268,50 +268,50 @@ namespace ngraph ...@@ -268,50 +268,50 @@ namespace ngraph
bool is_string_array() const { return get_type() == Type::string_array; } bool is_string_array() const { return get_type() == Type::string_array; }
bool is_graph() const { return get_type() == Type::graph; } bool is_graph() const { return get_type() == Type::graph; }
bool is_graph_array() const { return get_type() == Type::graph_array; } bool is_graph_array() const { return get_type() == Type::graph_array; }
Tensor get_tensor() const { return Tensor{m_attribute_proto.t()}; } Tensor get_tensor() const { return Tensor{m_attribute_proto->t()}; }
float get_float() const { return m_attribute_proto.f(); } float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto.i(); } int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto.s(); } const std::string& get_string() const { return m_attribute_proto->s(); }
Graph get_graph() const; Graph get_graph() const;
std::vector<Tensor> get_tensor_array() const std::vector<Tensor> get_tensor_array() const
{ {
return {std::begin(m_attribute_proto.tensors()), return {std::begin(m_attribute_proto->tensors()),
std::end(m_attribute_proto.tensors())}; std::end(m_attribute_proto->tensors())};
} }
std::vector<float> get_float_array() const std::vector<float> get_float_array() const
{ {
return {std::begin(m_attribute_proto.floats()), return {std::begin(m_attribute_proto->floats()),
std::end(m_attribute_proto.floats())}; std::end(m_attribute_proto->floats())};
} }
std::vector<int64_t> get_integer_array() const std::vector<int64_t> get_integer_array() const
{ {
return {std::begin(m_attribute_proto.ints()), std::end(m_attribute_proto.ints())}; return {std::begin(m_attribute_proto->ints()), std::end(m_attribute_proto->ints())};
} }
std::vector<std::string> get_string_array() const std::vector<std::string> get_string_array() const
{ {
return {std::begin(m_attribute_proto.strings()), return {std::begin(m_attribute_proto->strings()),
std::end(m_attribute_proto.strings())}; std::end(m_attribute_proto->strings())};
} }
std::vector<Graph> get_graph_array() const; std::vector<Graph> get_graph_array() const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const /* explicit */ operator onnx::AttributeProto_AttributeType() const
{ {
return m_attribute_proto.type(); return m_attribute_proto->type();
} }
template <typename T> template <typename T>
T get_value() const T get_value() const
{ {
return detail::attribute::get_value<T>(m_attribute_proto); return detail::attribute::get_value<T>(*m_attribute_proto);
} }
private: private:
const onnx::AttributeProto& m_attribute_proto; const onnx::AttributeProto* m_attribute_proto;
}; };
} // namespace onnx_import } // namespace onnx_import
......
...@@ -22,9 +22,9 @@ namespace ngraph ...@@ -22,9 +22,9 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
Graph::Graph(const onnx::GraphProto& graph_proto) Graph::Graph(const onnx::GraphProto& graph_proto)
: m_graph_proto(graph_proto) : m_graph_proto{&graph_proto}
{ {
for (const auto& tensor : m_graph_proto.initializer()) for (const auto& tensor : m_graph_proto->initializer())
{ {
if (tensor.has_name()) if (tensor.has_name())
{ {
...@@ -33,20 +33,20 @@ namespace ngraph ...@@ -33,20 +33,20 @@ namespace ngraph
} }
// Process all ONNX graph inputs, convert them to nGraph nodes and store in cache // Process all ONNX graph inputs, convert them to nGraph nodes and store in cache
for (const auto& input : m_graph_proto.input()) for (const auto& input : m_graph_proto->input())
{ {
m_inputs.emplace_back(input); m_inputs.emplace_back(input);
m_ng_node_cache[input.name()] = m_ng_node_cache[input.name()] =
m_inputs.back().get_ng_node(m_parameters, m_initializers); m_inputs.back().get_ng_node(m_parameters, m_initializers);
} }
for (const auto& output : m_graph_proto.output()) for (const auto& output : m_graph_proto->output())
{ {
m_outputs.emplace_back(output); m_outputs.emplace_back(output);
} }
// Process ONNX graph nodes, convert to nGraph nodes // Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto.node()) for (const auto& node_proto : m_graph_proto->node())
{ {
m_nodes.emplace_back(node_proto, this); m_nodes.emplace_back(node_proto, this);
const Node& node{m_nodes.back()}; const Node& node{m_nodes.back()};
......
...@@ -42,9 +42,9 @@ namespace ngraph ...@@ -42,9 +42,9 @@ namespace ngraph
return m_ng_node_cache.at(name); return m_ng_node_cache.at(name);
} }
const std::string& get_name() const { return m_graph_proto.name(); } const std::string& get_name() const { return m_graph_proto->name(); }
private: private:
const onnx::GraphProto& m_graph_proto; const onnx::GraphProto* m_graph_proto;
std::vector<Node> m_nodes; std::vector<Node> m_nodes;
std::vector<ValueInfo> m_inputs; std::vector<ValueInfo> m_inputs;
std::vector<ValueInfo> m_outputs; std::vector<ValueInfo> m_outputs;
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
public: public:
Model() = delete; Model() = delete;
explicit Model(const onnx::ModelProto& model_proto) explicit Model(const onnx::ModelProto& model_proto)
: m_model_proto{model_proto} : m_model_proto{&model_proto}
{ {
} }
...@@ -38,16 +38,16 @@ namespace ngraph ...@@ -38,16 +38,16 @@ namespace ngraph
Model& operator=(Model&&) noexcept = delete; Model& operator=(Model&&) noexcept = delete;
Model& operator=(const Model&) = delete; Model& operator=(const Model&) = delete;
const std::string& get_producer_name() const { return m_model_proto.producer_name(); } const std::string& get_producer_name() const { return m_model_proto->producer_name(); }
const onnx::GraphProto& get_graph() const { return m_model_proto.graph(); } const onnx::GraphProto& get_graph() const { return m_model_proto->graph(); }
std::int64_t get_model_version() const { return m_model_proto.model_version(); } std::int64_t get_model_version() const { return m_model_proto->model_version(); }
const std::string& get_producer_version() const const std::string& get_producer_version() const
{ {
return m_model_proto.producer_version(); return m_model_proto->producer_version();
} }
private: private:
const onnx::ModelProto& m_model_proto; const onnx::ModelProto* m_model_proto;
}; };
inline std::ostream& operator<<(std::ostream& outs, const Model& model) inline std::ostream& operator<<(std::ostream& outs, const Model& model)
......
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
NodeVector Node::get_ng_inputs() const NodeVector Node::get_ng_inputs() const
{ {
NodeVector result; NodeVector result;
for (const auto& name : m_node_proto.input()) for (const auto& name : m_node_proto->input())
{ {
result.push_back(m_graph->get_ng_node_from_cache(name)); result.push_back(m_graph->get_ng_node_from_cache(name));
} }
......
...@@ -53,7 +53,7 @@ namespace ngraph ...@@ -53,7 +53,7 @@ namespace ngraph
public: public:
Node() = delete; Node() = delete;
Node(const onnx::NodeProto& node_proto, const Graph* graph) Node(const onnx::NodeProto& node_proto, const Graph* graph)
: m_node_proto{node_proto} : m_node_proto{&node_proto}
, m_graph{graph} , m_graph{graph}
, m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())} , m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())}
, m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())} , m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())}
...@@ -70,13 +70,13 @@ namespace ngraph ...@@ -70,13 +70,13 @@ namespace ngraph
NodeVector get_ng_nodes() const; NodeVector get_ng_nodes() const;
NodeVector get_ng_inputs() const; NodeVector get_ng_inputs() const;
const std::string& op_type() const { return m_node_proto.op_type(); } const std::string& op_type() const { return m_node_proto->op_type(); }
const std::string& get_name() const { return m_node_proto.name(); } const std::string& get_name() const { return m_node_proto->name(); }
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const
{ {
return m_output_names; return m_output_names;
} }
const std::string& output(int index) const { return m_node_proto.output(index); } const std::string& output(int index) const { return m_node_proto->output(index); }
template <typename T> template <typename T>
T get_attribute_value(const std::string& name, T default_value) const T get_attribute_value(const std::string& name, T default_value) const
{ {
...@@ -106,7 +106,7 @@ namespace ngraph ...@@ -106,7 +106,7 @@ namespace ngraph
} }
private: private:
const onnx::NodeProto& m_node_proto; const onnx::NodeProto* m_node_proto;
const Graph* m_graph; const Graph* m_graph;
std::vector<Attribute> m_attributes; std::vector<Attribute> m_attributes;
std::vector<std::reference_wrapper<const std::string>> m_output_names; std::vector<std::reference_wrapper<const std::string>> m_output_names;
......
...@@ -198,7 +198,7 @@ namespace ngraph ...@@ -198,7 +198,7 @@ namespace ngraph
Tensor() = delete; Tensor() = delete;
explicit Tensor(const onnx::TensorProto& tensor) explicit Tensor(const onnx::TensorProto& tensor)
: m_tensor_proto{tensor} : m_tensor_proto{&tensor}
, m_shape{std::begin(tensor.dims()), std::end(tensor.dims())} , m_shape{std::begin(tensor.dims()), std::end(tensor.dims())}
{ {
} }
...@@ -213,34 +213,34 @@ namespace ngraph ...@@ -213,34 +213,34 @@ namespace ngraph
template <typename T> template <typename T>
std::vector<T> get_data() const std::vector<T> get_data() const
{ {
return detail::tensor::get_data<T>(m_tensor_proto); return detail::tensor::get_data<T>(*m_tensor_proto);
} }
const std::string& get_name() const const std::string& get_name() const
{ {
if (!m_tensor_proto.has_name()) if (!m_tensor_proto->has_name())
{ {
throw error::tensor::unspecified_name{}; throw error::tensor::unspecified_name{};
} }
return m_tensor_proto.name(); return m_tensor_proto->name();
} }
Type get_type() const Type get_type() const
{ {
if (!m_tensor_proto.has_data_type()) if (!m_tensor_proto->has_data_type())
{ {
throw error::tensor::unspecified_data_type{}; throw error::tensor::unspecified_data_type{};
} }
return static_cast<Type>(m_tensor_proto.data_type()); return static_cast<Type>(m_tensor_proto->data_type());
} }
const element::Type& get_ng_type() const const element::Type& get_ng_type() const
{ {
if (!m_tensor_proto.has_data_type()) if (!m_tensor_proto->has_data_type())
{ {
throw error::tensor::unspecified_data_type{}; throw error::tensor::unspecified_data_type{};
} }
switch (m_tensor_proto.data_type()) switch (m_tensor_proto->data_type())
{ {
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return element::boolean; case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return element::boolean;
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
...@@ -254,13 +254,13 @@ namespace ngraph ...@@ -254,13 +254,13 @@ namespace ngraph
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: return element::u16; case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: return element::u16;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: return element::u32; case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: return element::u32;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64; case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64;
default: throw error::tensor::unsupported_data_type{m_tensor_proto.data_type()}; default: throw error::tensor::unsupported_data_type{m_tensor_proto->data_type()};
} }
} }
operator onnx::TensorProto_DataType() const { return m_tensor_proto.data_type(); } operator onnx::TensorProto_DataType() const { return m_tensor_proto->data_type(); }
private: private:
const onnx::TensorProto& m_tensor_proto; const onnx::TensorProto* m_tensor_proto;
Shape m_shape; Shape m_shape;
}; };
......
...@@ -60,7 +60,7 @@ namespace ngraph ...@@ -60,7 +60,7 @@ namespace ngraph
ValueInfo() = delete; ValueInfo() = delete;
explicit ValueInfo(const onnx::ValueInfoProto& value_info_proto) explicit ValueInfo(const onnx::ValueInfoProto& value_info_proto)
: m_value_info_proto{value_info_proto} : m_value_info_proto{&value_info_proto}
{ {
if (value_info_proto.type().has_tensor_type()) if (value_info_proto.type().has_tensor_type())
{ {
...@@ -74,15 +74,15 @@ namespace ngraph ...@@ -74,15 +74,15 @@ namespace ngraph
ValueInfo& operator=(const ValueInfo&) = delete; ValueInfo& operator=(const ValueInfo&) = delete;
ValueInfo& operator=(ValueInfo&&) = delete; ValueInfo& operator=(ValueInfo&&) = delete;
const std::string& get_name() const { return m_value_info_proto.name(); } const std::string& get_name() const { return m_value_info_proto->name(); }
const Shape& get_shape() const { return m_shape; } const Shape& get_shape() const { return m_shape; }
const element::Type& get_element_type() const const element::Type& get_element_type() const
{ {
if (!m_value_info_proto.type().tensor_type().has_elem_type()) if (!m_value_info_proto->type().tensor_type().has_elem_type())
{ {
throw error::value_info::unspecified_element_type{}; throw error::value_info::unspecified_element_type{};
} }
switch (m_value_info_proto.type().tensor_type().elem_type()) switch (m_value_info_proto->type().tensor_type().elem_type())
{ {
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return element::boolean; case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return element::boolean;
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
...@@ -98,7 +98,7 @@ namespace ngraph ...@@ -98,7 +98,7 @@ namespace ngraph
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64; case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64;
default: default:
throw error::value_info::unsupported_element_type{ throw error::value_info::unsupported_element_type{
m_value_info_proto.type().tensor_type().elem_type()}; m_value_info_proto->type().tensor_type().elem_type()};
} }
} }
...@@ -126,7 +126,7 @@ namespace ngraph ...@@ -126,7 +126,7 @@ namespace ngraph
std::shared_ptr<op::Constant> get_ng_constant(const Tensor& tensor) const std::shared_ptr<op::Constant> get_ng_constant(const Tensor& tensor) const
{ {
switch (m_value_info_proto.type().tensor_type().elem_type()) switch (m_value_info_proto->type().tensor_type().elem_type())
{ {
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return make_ng_constant<bool>(element::boolean, tensor); return make_ng_constant<bool>(element::boolean, tensor);
...@@ -153,7 +153,7 @@ namespace ngraph ...@@ -153,7 +153,7 @@ namespace ngraph
return make_ng_constant<uint64_t>(element::u64, tensor); return make_ng_constant<uint64_t>(element::u64, tensor);
default: default:
throw error::value_info::unsupported_element_type{ throw error::value_info::unsupported_element_type{
m_value_info_proto.type().tensor_type().elem_type()}; m_value_info_proto->type().tensor_type().elem_type()};
} }
} }
...@@ -165,7 +165,7 @@ namespace ngraph ...@@ -165,7 +165,7 @@ namespace ngraph
} }
private: private:
const onnx::ValueInfoProto& m_value_info_proto; const onnx::ValueInfoProto* m_value_info_proto;
Shape m_shape; Shape m_shape;
}; };
......
...@@ -47,15 +47,7 @@ set(SRC ...@@ -47,15 +47,7 @@ set(SRC
) )
if (NGRAPH_ONNX_IMPORT_ENABLE) if (NGRAPH_ONNX_IMPORT_ENABLE)
if (APPLE OR WIN32)
list(APPEND SRC onnx_import.cpp) list(APPEND SRC onnx_import.cpp)
else()
# ONNX unit tests temporarly disabled if CentOS detected
# (Protobuf issue with interpreting messages)
if (NOT ${DISTRIB_ID} STREQUAL "CentOS Linux")
list(APPEND SRC onnx_import.cpp)
endif()
endif()
endif() endif()
if (NGRAPH_INTERPRETER_ENABLE) if (NGRAPH_INTERPRETER_ENABLE)
......
...@@ -56,12 +56,12 @@ TEST(onnx, model_add_abc_initializers) ...@@ -56,12 +56,12 @@ TEST(onnx, model_add_abc_initializers)
TEST(onnx, model_addmul_abc) TEST(onnx, model_addmul_abc)
{ {
auto function = ngraph::onnx_import::import_onnx_function( auto function = onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx"));
std::vector<std::vector<float>> inputs; std::vector<std::vector<float>> inputs;
ngraph::Shape shape{1, 2, 2}; Shape shape{1, 2, 2};
inputs.emplace_back(test::NDArray<float, 3>({{{9, 10}}, {{11, 12}}}).get_vector()); inputs.emplace_back(test::NDArray<float, 3>({{{9, 10}}, {{11, 12}}}).get_vector());
inputs.emplace_back(test::NDArray<float, 3>({{{5, 6}}, {{7, 8}}}).get_vector()); inputs.emplace_back(test::NDArray<float, 3>({{{5, 6}}, {{7, 8}}}).get_vector());
inputs.emplace_back(test::NDArray<float, 3>({{{1, 2}}, {{3, 4}}}).get_vector()); inputs.emplace_back(test::NDArray<float, 3>({{{1, 2}}, {{3, 4}}}).get_vector());
...@@ -124,8 +124,7 @@ TEST(onnx, model_split_variable_parts_2d) ...@@ -124,8 +124,7 @@ TEST(onnx, model_split_variable_parts_2d)
namespace namespace
{ {
std::vector<std::vector<float>> std::vector<std::vector<float>> conv2d_execute(const std::shared_ptr<Function>& function)
conv2d_execute(const std::shared_ptr<ngraph::Function>& function)
{ {
std::vector<std::vector<float>> args; std::vector<std::vector<float>> args;
...@@ -151,8 +150,8 @@ namespace ...@@ -151,8 +150,8 @@ namespace
TEST(onnx, model_conv2d_strides_padding) TEST(onnx, model_conv2d_strides_padding)
{ {
// Convolution with strides=2 and padding=1 // Convolution with strides=2 and padding=1
auto function = ngraph::onnx_import::import_onnx_function( auto function = onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_padding.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_padding.onnx"));
// (1, 1, 4, 3) // (1, 1, 4, 3)
auto expected_output = test::NDArray<float, 4>({{{{12.f, 27.f, 24.f}, auto expected_output = test::NDArray<float, 4>({{{{12.f, 27.f, 24.f},
...@@ -168,8 +167,8 @@ TEST(onnx, model_conv2d_strides_padding) ...@@ -168,8 +167,8 @@ TEST(onnx, model_conv2d_strides_padding)
TEST(onnx, model_conv2d_strides_no_padding) TEST(onnx, model_conv2d_strides_no_padding)
{ {
// Convolution with strides=2 and padding=1 // Convolution with strides=2 and padding=1
auto function = ngraph::onnx_import::import_onnx_function( auto function = onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_no_padding.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_no_padding.onnx"));
// (1, 1, 3, 2) // (1, 1, 3, 2)
auto expected_output = auto expected_output =
...@@ -182,8 +181,8 @@ TEST(onnx, model_conv2d_strides_no_padding) ...@@ -182,8 +181,8 @@ TEST(onnx, model_conv2d_strides_no_padding)
TEST(onnx, model_conv2d_strides_assymetric_padding) TEST(onnx, model_conv2d_strides_assymetric_padding)
{ {
// Convolution with strides=2 and padding=1 // Convolution with strides=2 and padding=1
auto function = ngraph::onnx_import::import_onnx_function(ngraph::file_util::path_join( auto function = onnx_import::import_onnx_function(
SERIALIZED_ZOO, "onnx/conv_with_strides_and_asymmetric_padding.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_and_asymmetric_padding.onnx"));
// (1, 1, 4, 2) // (1, 1, 4, 2)
auto expected_output = auto expected_output =
...@@ -297,8 +296,8 @@ TEST(onnx, model_batchnorm_default) ...@@ -297,8 +296,8 @@ TEST(onnx, model_batchnorm_default)
TEST(onnx, model_relu) TEST(onnx, model_relu)
{ {
// Simple ReLU test // Simple ReLU test
auto function = ngraph::onnx_import::import_onnx_function( auto function =
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/relu.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/relu.onnx"));
Inputs inputs{{-1, -2, 0, 1, 2, 3}}; Inputs inputs{{-1, -2, 0, 1, 2, 3}};
Outputs expected_outputs{{0, 0, 0, 1, 2, 3}}; Outputs expected_outputs{{0, 0, 0, 1, 2, 3}};
...@@ -385,11 +384,10 @@ TEST(onnx, model_mean) ...@@ -385,11 +384,10 @@ TEST(onnx, model_mean)
TEST(onnx, model_gemm_abc) TEST(onnx, model_gemm_abc)
{ {
auto function = ngraph::onnx_import::import_onnx_function( auto function = onnx_import::import_onnx_function(
ngraph::file_util::path_join(SERIALIZED_ZOO, "onnx/gemm_abc.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/gemm_abc.onnx"));
std::vector<std::vector<float>> inputs;
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 2>( inputs.emplace_back(test::NDArray<float, 2>(
{{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}, {13, 14, 15, 16, 17, 18}}) {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}, {13, 14, 15, 16, 17, 18}})
.get_vector()); .get_vector());
...@@ -405,13 +403,13 @@ TEST(onnx, model_gemm_abc) ...@@ -405,13 +403,13 @@ TEST(onnx, model_gemm_abc)
inputs.emplace_back( inputs.emplace_back(
test::NDArray<float, 2>({{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}).get_vector()); test::NDArray<float, 2>({{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}}).get_vector());
auto expected_output = Outputs expected_outputs{
test::NDArray<float, 2>( test::NDArray<float, 2>(
{{340, 350.5, 361, 371.5}, {862, 890.5, 919, 947.5}, {1384, 1430.5, 1477, 1523.5}}) {{340, 350.5, 361, 371.5}, {862, 890.5, 919, 947.5}, {1384, 1430.5, 1477, 1523.5}})
.get_vector(); .get_vector()};
auto result_vectors = execute(function, inputs, "INTERPRETER"); Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front())); EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
} }
TEST(onnx, model_matmul) TEST(onnx, model_matmul)
...@@ -428,11 +426,11 @@ TEST(onnx, model_matmul) ...@@ -428,11 +426,11 @@ TEST(onnx, model_matmul)
test::NDArray<float, 2>({{13, 14, 15}, {16, 17, 18}, {19, 20, 21}, {22, 23, 24}}) test::NDArray<float, 2>({{13, 14, 15}, {16, 17, 18}, {19, 20, 21}, {22, 23, 24}})
.get_vector()); .get_vector());
auto expected_output = Outputs expected_outputs{
test::NDArray<float, 2>({{190, 200, 210}, {470, 496, 522}, {750, 792, 834}}).get_vector(); test::NDArray<float, 2>({{190, 200, 210}, {470, 496, 522}, {750, 792, 834}}).get_vector()};
auto result_vectors = execute(function, inputs, "INTERPRETER"); Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front())); EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
} }
TEST(onnx, model_softmax) TEST(onnx, model_softmax)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment