Commit e6dad531 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Concat operator (#1524)

* [ONNX] Concat operator

* Style fix
parent 9779dc81
......@@ -36,6 +36,8 @@ add_library(onnx_import STATIC
op/average_pool.hpp
op/batch_norm.cpp
op/batch_norm.hpp
op/concat.cpp
op/concat.hpp
op/constant.cpp
op/constant.hpp
op/conv.cpp
......
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "concat.hpp"
#include "ngraph/op/concat.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector concat(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto axis = node.get_attribute_value<int64_t>("axis");
return {std::make_shared<ngraph::op::Concat>(inputs, axis)};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
NodeVector concat(const Node& node);
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -21,6 +21,7 @@
#include "op/add.hpp"
#include "op/average_pool.hpp"
#include "op/batch_norm.hpp"
#include "op/concat.hpp"
#include "op/constant.hpp"
#include "op/conv.hpp"
#include "op/div.hpp"
......@@ -86,6 +87,7 @@ namespace ngraph
std::bind(op::average_pool, std::placeholders::_1));
m_map.emplace("BatchNormalization",
std::bind(op::batch_norm, std::placeholders::_1));
m_map.emplace("Concat", std::bind(op::concat, std::placeholders::_1));
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Div", std::bind(op::div, std::placeholders::_1));
......
......@@ -488,6 +488,22 @@ TEST(onnx, model_softmax)
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
}
TEST(onnx, model_concat)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/concat.onnx"));
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 1>({1, 2}).get_vector());
inputs.emplace_back(test::NDArray<float, 1>({3, 4}).get_vector());
Outputs expected_outputs{test::NDArray<float, 1>({1, 2, 3, 4}).get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_flatten)
{
auto function = onnx_import::import_onnx_function(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment