Commit ba8d13da authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Implement reshape in VM for tensors with rank<=2 (#184)

* Type checking for reshape op

* Formatting

* Implement reduce op for matrices and smaller in VM

* Formatting

* Missing include for Ubuntu

* Remove IndexBuiltin class
parent bbbb793e
...@@ -36,6 +36,7 @@ set (SRC ...@@ -36,6 +36,7 @@ set (SRC
ops/op.cpp ops/op.cpp
ops/parameter.cpp ops/parameter.cpp
ops/reduce.cpp ops/reduce.cpp
ops/reshape.cpp
ops/select.cpp ops/select.cpp
ops/tuple.cpp ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp ops/unary_elementwise_arithmetic.cpp
......
...@@ -77,6 +77,7 @@ ...@@ -77,6 +77,7 @@
#include "ngraph/ops/power.hpp" #include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/remainder.hpp" #include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/select.hpp" #include "ngraph/ops/select.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
......
...@@ -20,7 +20,7 @@ namespace ngraph ...@@ -20,7 +20,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Broadcast : public IndexBuiltin class Broadcast : public Builtin
{ {
public: public:
/// ///
...@@ -32,7 +32,7 @@ namespace ngraph ...@@ -32,7 +32,7 @@ namespace ngraph
Broadcast(const std::shared_ptr<Node>& arg, Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape, const Shape& shape,
const AxisSet& broadcast_axes) const AxisSet& broadcast_axes)
: IndexBuiltin(arg) : Builtin({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
......
...@@ -38,31 +38,6 @@ namespace ngraph ...@@ -38,31 +38,6 @@ namespace ngraph
} }
}; };
/// Index ops create a new way to index the same tensor elements
class IndexBuiltin : public Builtin
{
protected:
IndexBuiltin(const std::shared_ptr<Node>& arg)
: Builtin(Nodes{arg})
{
}
};
class Reshape : public IndexBuiltin
{
public:
Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: IndexBuiltin(arg0)
, m_shape(shape)
{
}
virtual std::string description() const override { return "Reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
};
/// Operations where the same element function is applied to each element /// Operations where the same element function is applied to each element
/// Op(X)[I] = op(X[I]) /// Op(X)[I] = op(X[I])
class UnaryElementwiseBuiltin : public Builtin class UnaryElementwiseBuiltin : public Builtin
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ops/reshape.hpp"
#include "ngraph/function.hpp"
#include <algorithm>
using namespace std;
using namespace ngraph::op;
void Reshape::propagate_types()
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to reshape is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_type)
{
throw ngraph_error("Argument to reshape is not a tensor view");
}
auto arg_shape = arg_tensor_view_type->get_shape();
auto arg_rank = arg_shape.size();
if (m_input_order.size() != arg_rank)
{
throw ngraph_error("Input axis order for reshape is not a permutation of argument's axes");
}
for (size_t i = 0; i < arg_rank; i++)
{
auto it = std::find(std::begin(m_input_order), std::end(m_input_order), i);
if (std::end(m_input_order) == it)
{
throw ngraph_error(
"Input axis order for reshape is not a permutation of argument's axes");
}
}
size_t arg_shape_product = 1;
for (auto i : arg_shape)
{
arg_shape_product *= i;
}
size_t output_shape_product = 1;
for (auto i : m_output_shape)
{
output_shape_product *= i;
}
if (arg_shape_product != output_shape_product)
{
throw ngraph_error(
"Product of output shape dimensions does not match product of argument shape "
"dimensions for reshape");
}
set_value_type_checked(
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_output_shape));
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
class Reshape : public Builtin
{
public:
///
/// @param arg The tensor view to be reshaped.
/// @param input_order The order in which to iterate over input axes. (TODO: that needs more explanation)
/// This must be a permutation of the sequence (0,...,n-1) where n is the rank of the input tensor.
/// @param output_shape The output shape. If the input shape is (a0,...,ak-1) then the output shape must
/// be of the form (b0,...,bj-1) where product(ai) == product(bi).
///
Reshape(const std::shared_ptr<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape)
: Builtin({arg})
, m_input_order(input_order)
, m_output_shape(output_shape)
{
}
virtual std::string description() const override { return "Reshape"; }
virtual void propagate_types() override;
const AxisVector& get_input_order() const { return m_input_order; }
const Shape& get_output_shape() const { return m_output_shape; }
protected:
const AxisVector m_input_order;
const Shape m_output_shape;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class MatrixTransposeInstruction : public Instruction
{
public:
MatrixTransposeInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out) =
EigenMatrix<ET>(call_frame, m_arg).transpose();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
...@@ -44,6 +44,7 @@ ...@@ -44,6 +44,7 @@
#include "ngraph/ops/negative.hpp" #include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not_equal.hpp" #include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/select.hpp" #include "ngraph/ops/select.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
...@@ -71,6 +72,7 @@ ...@@ -71,6 +72,7 @@
#include "ngraph/runtime/ngvm/eigen/less_than.hpp" #include "ngraph/runtime/ngvm/eigen/less_than.hpp"
#include "ngraph/runtime/ngvm/eigen/log.hpp" #include "ngraph/runtime/ngvm/eigen/log.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_mult.hpp" #include "ngraph/runtime/ngvm/eigen/matrix_mult.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_transpose.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_vector_product.hpp" #include "ngraph/runtime/ngvm/eigen/matrix_vector_product.hpp"
#include "ngraph/runtime/ngvm/eigen/maximum.hpp" #include "ngraph/runtime/ngvm/eigen/maximum.hpp"
#include "ngraph/runtime/ngvm/eigen/multiply.hpp" #include "ngraph/runtime/ngvm/eigen/multiply.hpp"
...@@ -777,6 +779,59 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -777,6 +779,59 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
} }
}; };
REGISTER_TO_OP_MAP(op::Reshape)
{
auto reshape = static_cast<const op::Reshape*>(n);
auto arg_type = reshape->get_arguments().at(0)->get_value_type();
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
assert(nullptr != arg_tensor_view_type);
auto arg_shape = arg_tensor_view_type->get_shape();
auto arg_rank = arg_shape.size();
auto result_type = reshape->get_value_type();
auto result_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(result_type);
assert(nullptr != result_tensor_view_type);
auto result_shape = result_tensor_view_type->get_shape();
auto& result_element_type = result_tensor_view_type->get_element_type();
auto input_order = reshape->get_input_order();
bool same_layout = std::is_sorted(input_order.begin(), input_order.end());
size_t result_shape_product = 1;
for (auto i : result_shape)
{
result_shape_product *= i;
}
// If there is no layout change or we are just going from 1^n to 1^m or a zero-size tensor, we can just copy.
if (same_layout || result_shape_product < 2)
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Reshape has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
}
// If there *is* a layout change in the 2D case, we transpose the input.
else if (arg_rank == 2)
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Reshape has unhandled element type",
runtime::ngvm::eigen::MatrixTransposeInstruction,
in[0],
out[0]);
}
// Other cases (reordering of axes for tensors with rank>2) are not handled yet.
else
{
throw ngraph_error(
"Axis permutation in reshape is not implemented yet for tensors with rank>2 in "
"VM");
}
};
initialized = true; initialized = true;
} }
return op_map; return op_map;
......
...@@ -1620,3 +1620,243 @@ TEST(execute, reduce_matrix_to_scalar_zero_by_zero) ...@@ -1620,3 +1620,243 @@ TEST(execute, reduce_matrix_to_scalar_zero_by_zero)
ASSERT_EQ((vector<float>{}), a->get_vector()); ASSERT_EQ((vector<float>{}), a->get_vector());
ASSERT_EQ((vector<float>{99}), b->get_vector()); ASSERT_EQ((vector<float>{99}), b->get_vector());
} }
TEST(type_prop, reshape_t2v_012)
{
auto shape_a = Shape{2, 2, 3};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{12};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{0, 1, 2}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), result->get_vector());
}
TEST(type_prop, reshape_t2s_012)
{
auto shape_a = Shape{1, 1, 1};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{0, 1, 2}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{6};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{6}), result->get_vector());
}
TEST(type_prop, reshape_t2s_120)
{
auto shape_a = Shape{1, 1, 1};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{1, 2, 0}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{6};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{6}), result->get_vector());
}
TEST(type_prop, reshape_s2t)
{
auto shape_a = Shape{};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{1, 1, 1, 1, 1, 1};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{42};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{42}), result->get_vector());
}
TEST(type_prop, reshape_v2m_col)
{
auto shape_a = Shape{3};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{3, 1};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{0}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{1, 2, 3};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{1, 2, 3}), result->get_vector());
}
TEST(type_prop, reshape_v2m_row)
{
auto shape_a = Shape{3};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{1, 3};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{0}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{1, 2, 3};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{1, 2, 3}), result->get_vector());
}
TEST(type_prop, reshape_v2t_middle)
{
auto shape_a = Shape{3};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{1, 3, 1};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{0}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{1, 2, 3};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{1, 2, 3}), result->get_vector());
}
TEST(type_prop, reshape_m2m_same)
{
auto shape_a = Shape{3, 3};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{3, 3};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{0, 1}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9}), result->get_vector());
}
TEST(type_prop, reshape_m2m_transpose)
{
auto shape_a = Shape{3, 3};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{3, 3};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{1, 0}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{1, 4, 7, 2, 5, 8, 3, 6, 9}), result->get_vector());
}
TEST(type_prop, reshape_m2m_dim_change_transpose)
{
auto shape_a = Shape{3, 2};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{2, 3};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{1, 0}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{1, 2, 3, 4, 5, 6};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{1, 3, 5, 2, 4, 6}), result->get_vector());
}
...@@ -1041,3 +1041,197 @@ TEST(type_prop, function_call_deduce) ...@@ -1041,3 +1041,197 @@ TEST(type_prop, function_call_deduce)
auto r_p_r_vt = r_p_r->get_value_type(); auto r_p_r_vt = r_p_r->get_value_type();
ASSERT_EQ(*r_p_r_vt, TensorViewType(element::Float32::element_type(), shape)); ASSERT_EQ(*r_p_r_vt, TensorViewType(element::Float32::element_type(), shape));
} }
TEST(type_prop, reshape_deduce_s2v)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto r = make_shared<op::Reshape>(param, AxisVector{}, Shape{1});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{1}));
}
TEST(type_prop, reshape_deduce_s2m)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto r = make_shared<op::Reshape>(param, AxisVector{}, Shape{1, 1});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{1, 1}));
}
TEST(type_prop, reshape_deduce_s2t)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto r = make_shared<op::Reshape>(param, AxisVector{}, Shape{1, 1, 1});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{1, 1, 1}));
}
TEST(type_prop, reshape_deduce_v2s)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{1}));
auto r = make_shared<op::Reshape>(param, AxisVector{0}, Shape{});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, reshape_deduce_m2s)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{1, 1}));
auto r = make_shared<op::Reshape>(param, AxisVector{0, 1}, Shape{});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, reshape_deduce_t2s)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{1, 1, 1}));
auto r = make_shared<op::Reshape>(param, AxisVector{0, 1, 2}, Shape{});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, reshape_deduce_m2v_01)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4}));
auto r = make_shared<op::Reshape>(param, AxisVector{0, 1}, Shape{12});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{12}));
}
TEST(type_prop, reshape_deduce_m2v_10)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 0}, Shape{12});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{12}));
}
TEST(type_prop, reshape_deduce_t2v_012)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{0, 1, 2}, Shape{60});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
TEST(type_prop, reshape_deduce_t2v_120)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{60});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
TEST(type_prop, reshape_deduce_correct_t2v_120)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{60});
r->set_value_type(make_shared<TensorViewType>(element::Float32::element_type(), Shape{60}));
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
TEST(type_prop, reshape_deduce_not_enough_axes)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 0}, Shape{60});
try
{
r->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Not enough axes not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(
error.what(),
std::string("Input axis order for reshape is not a permutation of argument's axes"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reshape_deduce_too_many_axes)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0, 3}, Shape{60});
try
{
r->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Too many axes not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(
error.what(),
std::string("Input axis order for reshape is not a permutation of argument's axes"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reshape_deduce_duplicate_axes)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 1, 0}, Shape{60});
try
{
r->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Too many axes not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(
error.what(),
std::string("Input axis order for reshape is not a permutation of argument's axes"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reshape_deduce_wrong_output_shape)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{3, 3, 3});
try
{
r->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Too many axes not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("Product of output shape dimensions does not match "
"product of argument shape dimensions for reshape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment