Commit eedea8d4 authored by Scott Cyphers's avatar Scott Cyphers Committed by Robert Kimball

Move Concat to the correct file (#539)

parent 13ac7882
......@@ -48,6 +48,7 @@ output/
*.mpg
*.cpio
*.wav
*.backup
doc/source/generated
.cache/
nervana_aeon.egg-info/
......
This diff is collapsed.
This diff is collapsed.
......@@ -38,7 +38,7 @@ set (SRC
ops/avg_pool.cpp
ops/batch_norm.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/concat.cpp
ops/constant.cpp
ops/convert.cpp
ops/convolution.cpp
......
......@@ -73,7 +73,7 @@
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/convolution.hpp"
......
......@@ -17,7 +17,7 @@
#include <cassert>
#include <memory>
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/ops/slice.hpp"
using namespace std;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cassert>
#include <memory>
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/slice.hpp"
using namespace std;
using namespace ngraph;
op::Concat::Concat(const Nodes& args, size_t concatenation_axis)
: RequiresTensorViewArgs("Concat", args)
, m_concatenation_axis(concatenation_axis)
{
if (m_inputs.size() < 1)
{
throw ngraph_error("At least one argument required");
}
auto& input_0 = get_inputs().at(0);
auto input_0_shape = input_0.get_shape();
if (m_concatenation_axis >= input_0_shape.size())
{
throw ngraph_error("Concatenation axis is out of bounds");
}
size_t concatenation_axis_length = input_0_shape.at(m_concatenation_axis);
auto& input_0_element_type = input_0.get_element_type();
for (auto i = 1; i < get_inputs().size(); i++)
{
auto& input_i = get_inputs().at(i);
auto input_i_shape = input_i.get_shape();
if (input_i_shape.size() != input_0_shape.size())
{
throw ngraph_error("Arguments to concat do not have same rank");
}
if (input_i.get_element_type() != input_0_element_type)
{
throw ngraph_error("Argument element types do not match");
}
for (auto j = 0; j < input_i_shape.size(); j++)
{
if (j != m_concatenation_axis && input_0_shape.at(j) != input_i_shape.at(j))
{
throw ngraph_error(
"Arguments to concat do not have same dimension on a non-concatenation axis");
}
else if (j == m_concatenation_axis)
{
concatenation_axis_length += input_i_shape.at(j);
}
}
}
vector<size_t> concatenated_shape = input_0_shape;
concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length;
set_value_type_checked(make_shared<TensorViewType>(input_0_element_type, concatenated_shape));
}
void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{
auto concat_result_shape = get_outputs().at(0).get_shape();
Coordinate arg_delta_slice_lower = Coordinate(concat_result_shape.size(), 0);
Coordinate arg_delta_slice_upper = concat_result_shape;
Coordinate arg_delta_slice_strides = Coordinate(concat_result_shape.size(), 1);
size_t pos = 0;
for (auto arg : get_input_ops())
{
auto arg_shape = arg->get_shape();
auto slice_width = arg_shape[m_concatenation_axis];
size_t next_pos = pos + slice_width;
arg_delta_slice_lower[m_concatenation_axis] = pos;
arg_delta_slice_upper[m_concatenation_axis] = next_pos;
adjoints.add_delta(
arg,
make_shared<op::Slice>(
delta, arg_delta_slice_lower, arg_delta_slice_upper, arg_delta_slice_strides));
pos = next_pos;
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
/// \brief Concatenation operation.
class Concat : public util::RequiresTensorViewArgs
{
public:
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
Concat(const Nodes& args, size_t concatenation_axis);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
return std::make_shared<Concat>(new_args, m_concatenation_axis);
}
/// \return The concatenation axis.
size_t get_concatenation_axis() const { return m_concatenation_axis; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
const size_t m_concatenation_axis;
};
}
}
......@@ -34,7 +34,7 @@
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/convolution.hpp"
......
......@@ -43,7 +43,7 @@
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/convolution.hpp"
......
......@@ -29,7 +29,7 @@
#include "ngraph/node.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/dot.hpp"
......
......@@ -45,7 +45,7 @@
#include "ngraph/ops/atan.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/convolution.hpp"
......
......@@ -25,7 +25,7 @@
#include "ngraph/node.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/dot.hpp"
......
......@@ -34,7 +34,7 @@
#include "ngraph/ops/asin.hpp"
#include "ngraph/ops/atan.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/cos.hpp"
......
......@@ -25,7 +25,7 @@
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/convolution.hpp"
......
......@@ -24,7 +24,7 @@
#include "ngraph/codegen/execution_engine.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/concat.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment