Unverified Commit cba4e54e authored by Adam Procter's avatar Adam Procter Committed by GitHub

Generalized dot (#291)

* WIP generalized dot

* Add some multi-axis 3D, 4D, and 5D tests

* Add test on some 'pretty big' tensors

* Reworked dot to have less flexible axis-pairing behavior

* Backprop for dot... and a fix for a dumb bug in CoordinateTransform

* Forgot to commit some stuff in merge

* Disable tests that currently don't work on CPU

* Fix temporarily disabled test that should pass on NGVM and INTERPRETER but wasn't due to new axis-selection convention for dot

* Remove obsolete ScalarTensorProduct kernel/instruction

* Review comment

* s/n_dot_axes/dot_axis_count/

* s/dot_axis_count/reduction_axes_count/

* Adapt CPU emitter dot fallback to new kernel
parent a960f07e
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "ngraph/common.hpp"
using namespace ngraph;
......@@ -37,31 +39,50 @@ Shape ngraph::project_shape(const Shape& shape, const AxisSet& deleted_axes)
return project_coordinate(shape, deleted_axes);
}
// TODO: for the moment, just one axis at a time, please. Later could pass in an std::map from axis positions to axis lengths.
// TODO: check validity, i.e. that the new axis is < coord_size+1.
Coordinate
ngraph::inject_coordinate(const Coordinate& coord, size_t new_axis_pos, size_t new_axis_val)
// TODO: check validity, i.e. that the new axis indices are all < coord_size+num_new_axes.
Coordinate ngraph::inject_coordinate(const Coordinate& coord,
std::vector<std::pair<size_t, size_t>> new_axis_pos_val_pairs)
{
Coordinate result;
size_t original_pos = 0;
for (size_t result_pos = 0; result_pos < coord.size() + 1; result_pos++)
for (size_t result_pos = 0; result_pos < coord.size() + new_axis_pos_val_pairs.size();
result_pos++)
{
if (result_pos == new_axis_pos)
auto search_it = std::find_if(
new_axis_pos_val_pairs.begin(),
new_axis_pos_val_pairs.end(),
[result_pos](std::pair<size_t, size_t> p) { return p.first == result_pos; });
if (search_it == new_axis_pos_val_pairs.end())
{
result.push_back(new_axis_val);
result.push_back(coord[original_pos++]);
}
else
{
result.push_back(coord[original_pos++]);
result.push_back(search_it->second);
}
}
return result;
}
Coordinate
ngraph::inject_coordinate(const Coordinate& coord, size_t new_axis_pos, size_t new_axis_val)
{
return inject_coordinate(coord,
std::vector<std::pair<size_t, size_t>>{
std::pair<size_t, size_t>(new_axis_pos, new_axis_val)});
}
Shape ngraph::inject_shape(const Shape& shape, size_t new_axis_pos, size_t new_axis_length)
{
return inject_coordinate(shape, new_axis_pos, new_axis_length);
}
Shape inject_shape(const Shape& shape,
std::vector<std::pair<size_t, size_t>> new_axis_pos_length_pairs)
{
return inject_coordinate(shape, new_axis_pos_length_pairs);
}
......@@ -16,6 +16,7 @@
#include <memory>
#include <set>
#include <utility>
#include <vector>
// Names for types that aren't worth giving their own classes
......@@ -56,5 +57,9 @@ namespace ngraph
Shape project_shape(const Shape& shape, const AxisSet& deleted_axes);
Coordinate inject_coordinate(const Coordinate& coord, size_t new_axis_pos, size_t new_axis_val);
Coordinate inject_coordinate(const Coordinate& coord,
std::vector<std::pair<size_t, size_t>> new_axis_pos_val_pairs);
Shape inject_shape(const Shape& shape, size_t new_axis_pos, size_t new_axis_length);
Shape inject_shape(const Shape& shape,
std::vector<std::pair<size_t, size_t>> new_axis_pos_length_pairs);
}
......@@ -209,8 +209,8 @@ Coordinate CoordinateTransform::to_source_coordinate(const Coordinate& c) const
for (size_t axis = 0; axis < m_n_axes; axis++)
{
result[axis] = c[m_source_axis_order[axis]] * m_source_strides[m_source_axis_order[axis]] +
m_source_start_corner[m_source_axis_order[axis]];
result[m_source_axis_order[axis]] =
c[axis] * m_source_strides[axis] + m_source_start_corner[axis];
}
return result;
......
This diff is collapsed.
This diff is collapsed.
......@@ -24,6 +24,7 @@
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/one_hot.hpp"
......@@ -137,29 +138,7 @@ void runtime::cpu::CPU_Emitter::EmitDot(const ngraph::Node* n,
}
else
{
size_t arg0_dot_axis;
size_t arg1_dot_axis;
if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
{
arg0_dot_axis = 0;
arg1_dot_axis = 0;
}
// If arg0 is a matrix and arg1 is a vector, dot on axes 1 and 0 respectively.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
{
arg0_dot_axis = 1;
arg1_dot_axis = 0;
}
// If arg0 is rank n and arg1 is rank m, dot on axes n-1 and m-2, respectively.
//
// Note that this happens to handle the vector-matrix and matrix-matrix cases.
else
{
arg0_dot_axis = arg0_shape.size() - 1;
arg1_dot_axis = arg1_shape.size() - 2;
}
const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(n);
m_out << "kernel::dot(" << args[0].get_name() << ",\n";
m_out << " " << args[1].get_name() << ",\n";
......@@ -167,8 +146,7 @@ void runtime::cpu::CPU_Emitter::EmitDot(const ngraph::Node* n,
m_out << " {" << join(args[0].get_shape()) << "},\n";
m_out << " {" << join(args[1].get_shape()) << "},\n";
m_out << " {" << join(out[0].get_shape()) << "},\n";
m_out << " " << arg0_dot_axis << ",\n";
m_out << " " << arg1_dot_axis << ");\n";
m_out << " " << dot->get_reduction_axes_count() << ");\n";
}
}
......
......@@ -23,6 +23,7 @@
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/reduce.hpp"
......@@ -66,7 +67,6 @@
#include "ngraph/runtime/kernel/reduce.hpp"
#include "ngraph/runtime/kernel/replace_slice.hpp"
#include "ngraph/runtime/kernel/reshape.hpp"
#include "ngraph/runtime/kernel/scalar_tensor_product.hpp"
#include "ngraph/runtime/kernel/select.hpp"
#include "ngraph/runtime/kernel/sign.hpp"
#include "ngraph/runtime/kernel/sin.hpp"
......@@ -290,54 +290,15 @@ private:
}
else if (node_op == "Dot")
{
if (args[0]->get_shape().size() == 0)
{
kernel::scalar_tensor_product(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (args[1]->get_shape().size() == 0)
{
kernel::scalar_tensor_product(reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else
{
size_t arg0_dot_axis;
size_t arg1_dot_axis;
if (args[0]->get_shape().size() == 1 && args[1]->get_shape().size() == 1)
{
arg0_dot_axis = 0;
arg1_dot_axis = 0;
}
ngraph::op::Dot* dot = dynamic_cast<ngraph::op::Dot*>(&node);
// If arg0 is a matrix and arg1 is a vector, dot on axes 1 and 0 respectively.
else if (args[0]->get_shape().size() == 2 && args[1]->get_shape().size() == 1)
{
arg0_dot_axis = 1;
arg1_dot_axis = 0;
}
// If arg0 is rank n and arg1 is rank m, dot on axes n-1 and m-2, respectively.
//
// Note that this happens to handle the vector-matrix and matrix-matrix cases.
else
{
arg0_dot_axis = args[0]->get_shape().size() - 1;
arg1_dot_axis = args[1]->get_shape().size() - 2;
}
kernel::dot(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
arg0_dot_axis,
arg1_dot_axis);
}
kernel::dot(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
dot->get_reduction_axes_count());
}
else if (node_op == "Equal")
......
......@@ -15,6 +15,7 @@
#pragma once
#include <cmath>
#include <utility>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_transform.hpp"
......@@ -32,49 +33,86 @@ namespace ngraph
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& out_shape,
size_t arg0_dot_axis,
size_t arg1_dot_axis)
size_t reduction_axes_count)
{
CoordinateTransform output_transform(out_shape);
for (Coordinate out_coord : output_transform)
{
out[output_transform.index(out_coord)] = 0;
}
// Get the sizes of the dot axes. It's easiest to pull them from arg1 because they're
// right up front.
Shape dot_axis_sizes(reduction_axes_count);
std::copy(arg1_shape.begin(),
arg1_shape.begin() + reduction_axes_count,
dot_axis_sizes.begin());
CoordinateTransform arg0_transform(arg0_shape);
CoordinateTransform arg1_transform(arg1_shape);
CoordinateTransform output_transform(out_shape);
// Create coordinate transforms for arg0 and arg1 that throw away the dotted axes.
size_t arg0_projected_rank = arg0_shape.size() - reduction_axes_count;
size_t arg1_projected_rank = arg1_shape.size() - reduction_axes_count;
Shape arg0_projected_shape(arg0_projected_rank);
std::copy(arg0_shape.begin(),
arg0_shape.begin() + arg0_projected_rank,
arg0_projected_shape.begin());
Shape arg1_projected_shape(arg1_projected_rank);
std::copy(arg1_shape.begin() + reduction_axes_count,
arg1_shape.end(),
arg1_projected_shape.begin());
CoordinateTransform arg0_projected_transform(arg0_projected_shape);
CoordinateTransform arg1_projected_transform(arg1_projected_shape);
CoordinateTransform arg0_projected_transform(
project_shape(arg0_shape, AxisSet{arg0_dot_axis}));
CoordinateTransform arg1_projected_transform(
project_shape(arg1_shape, AxisSet{arg1_dot_axis}));
// Create a coordinate transform that allows us to iterate over all possible values
// for the dotted axes.
CoordinateTransform dot_axes_transform(dot_axis_sizes);
for (Coordinate arg0_projected_coord : arg0_projected_transform)
{
for (Coordinate arg1_projected_coord : arg1_projected_transform)
{
for (size_t i = 0; i < arg0_shape[arg0_dot_axis]; i++)
// The output coordinate is just the concatenation of the projected coordinates.
Coordinate out_coord(arg0_projected_coord.size() +
arg1_projected_coord.size());
auto out_coord_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
out_coord.begin());
std::copy(
arg1_projected_coord.begin(), arg1_projected_coord.end(), out_coord_it);
// Zero out to start the sum.
T sum = 0;
size_t out_index = output_transform.index(out_coord);
// Walk along the dotted axes.
for (Coordinate dot_axis_positions : dot_axes_transform)
{
Coordinate arg0_coord =
inject_coordinate(arg0_projected_coord, arg0_dot_axis, i);
Coordinate arg1_coord =
inject_coordinate(arg1_projected_coord, arg1_dot_axis, i);
Coordinate out_coord(arg0_projected_coord.size() +
arg1_projected_coord.size());
std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
out_coord.begin());
std::copy(arg1_projected_coord.begin(),
arg1_projected_coord.end(),
out_coord.begin() + arg0_projected_coord.size());
out[output_transform.index(out_coord)] +=
arg0[arg0_transform.index(arg0_coord)] *
arg1[arg1_transform.index(arg1_coord)];
// In order to find the points to multiply together, we need to inject our current
// positions along the dotted axes back into the projected arg0 and arg1 coordinates.
Coordinate arg0_coord(arg0_shape.size());
Coordinate arg1_coord(arg1_shape.size());
auto arg0_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
arg0_coord.begin());
std::copy(
dot_axis_positions.begin(), dot_axis_positions.end(), arg0_it);
auto arg1_it = std::copy(dot_axis_positions.begin(),
dot_axis_positions.end(),
arg1_coord.begin());
std::copy(
arg1_projected_coord.begin(), arg1_projected_coord.end(), arg1_it);
// Multiply and add to the sum.
sum += arg0[arg0_transform.index(arg0_coord)] *
arg1[arg1_transform.index(arg1_coord)];
}
// Write the sum back.
out[out_index] = sum;
}
}
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void scalar_tensor_product(T* arg0, // the scalar (TODO: just pass as T?)
T* arg1, // the tensor
T* out,
size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = (*arg0) * arg1[i];
}
}
}
}
}
......@@ -109,7 +109,6 @@
#include "ngraph/runtime/ngvm/instruction/replace_slice.hpp"
#include "ngraph/runtime/ngvm/instruction/reshape.hpp"
#include "ngraph/runtime/ngvm/instruction/return.hpp"
#include "ngraph/runtime/ngvm/instruction/scalar_tensor_product.hpp"
#include "ngraph/runtime/ngvm/instruction/select.hpp"
#include "ngraph/runtime/ngvm/instruction/sign.hpp"
#include "ngraph/runtime/ngvm/instruction/sin.hpp"
......@@ -352,8 +351,6 @@ std::vector<typename ET::type>
}
#define PUSH_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_ELEMENT_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
#define PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_NUMERIC_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
// Turn off complaint suppression (see above)
#pragma clang diagnostic pop
......@@ -550,6 +547,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_TO_OP_MAP(op::Dot)
{
auto dot = static_cast<const op::Dot*>(n);
auto& arg_nodes = n->get_arguments();
assert(arg_nodes.size() == 2);
......@@ -566,81 +565,24 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto arg1_shape = arg1_tensor_type->get_shape();
auto& arg0_element_type = arg0_tensor_type->get_element_type();
auto reduction_axes_count = dot->get_reduction_axes_count();
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto result_shape = result_tensor_type->get_shape();
// If arg0 or arg1 is a scalar, emit a scalar-tensor product.
if (arg0_shape.size() == 0)
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::ScalarTensorProductInstruction,
in[0],
in[1],
out[0]);
}
else if (arg1_shape.size() == 0)
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::ScalarTensorProductInstruction,
in[1],
in[0],
out[0]);
}
// If arg0 and arg1 are both vectors, dot both on axis 0.
else if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::DotInstruction,
in[0],
in[1],
out[0],
arg0_shape,
arg1_shape,
result_shape,
0,
0);
}
// If arg0 is a matrix and arg1 is a vector, dot on axes 1 and 0 respectively.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::DotInstruction,
in[0],
in[1],
out[0],
arg0_shape,
arg1_shape,
result_shape,
1,
0);
}
// If arg0 is rank n and arg1 is rank m, dot on axes n-1 and m-2, respectively.
//
// Note that this happens to handle the vector-matrix and matrix-matrix cases.
else
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::DotInstruction,
in[0],
in[1],
out[0],
arg0_shape,
arg1_shape,
result_shape,
arg0_shape.size() - 1,
arg1_shape.size() - 2);
}
PUSH_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::DotInstruction,
in[0],
in[1],
out[0],
arg0_shape,
arg1_shape,
result_shape,
reduction_axes_count);
};
// Parameter is a "runtime no-op" because the output tensor has already been filled.
......
......@@ -38,16 +38,14 @@ namespace ngraph
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& out_shape,
size_t arg0_dot_axis,
size_t arg1_dot_axis)
size_t reduction_axes_count)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
, m_arg0_shape(arg0_shape)
, m_arg1_shape(arg1_shape)
, m_out_shape(out_shape)
, m_arg0_dot_axis(arg0_dot_axis)
, m_arg1_dot_axis(arg1_dot_axis)
, m_reduction_axes_count(reduction_axes_count)
{
}
......@@ -63,8 +61,7 @@ namespace ngraph
m_arg0_shape,
m_arg1_shape,
m_out_shape,
m_arg0_dot_axis,
m_arg1_dot_axis);
m_reduction_axes_count);
}
protected:
......@@ -74,8 +71,7 @@ namespace ngraph
Shape m_arg0_shape;
Shape m_arg1_shape;
Shape m_out_shape;
size_t m_arg0_dot_axis;
size_t m_arg1_dot_axis;
size_t m_reduction_axes_count;
};
}
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/scalar_tensor_product.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ET>
class ScalarTensorProductInstruction : public Instruction
{
public:
ScalarTensorProductInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg1);
kernel::scalar_tensor_product<typename ET::type>(arg0, arg1, out, count);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -521,6 +521,28 @@ TEST(${BACKEND_NAME}, backwards_dot_tensor2_tensor2)
autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1}, .01f, .01f));
}
TEST(${BACKEND_NAME}, backwards_dot_tensor3_tensor3)
{
auto manager = runtime::Manager::get("NGVM");
auto backend = manager->allocate_backend();
test::Uniform<float> rng(-1.0f, 1.0f);
auto shape0 = Shape{2, 4, 3};
auto shape1 = Shape{4, 3, 3};
auto x0 = rng.initialize(backend->make_primary_tensor_view<float>(shape0));
auto x1 = rng.initialize(backend->make_primary_tensor_view<float>(shape1));
auto make_graph = [shape0, shape1]() {
auto X0 = make_shared<op::Parameter>(element::Float32::element_type(), shape0);
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape1);
return make_shared<Function>(make_shared<op::Dot>(X0, X1, 2),
nullptr,
std::vector<std::shared_ptr<op::Parameter>>{X0, X1});
};
EXPECT_TRUE(
autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1}, .01f, .01f));
}
TEST(${BACKEND_NAME}, backwards_exp)
{
auto manager = runtime::Manager::get("${BACKEND_NAME}");
......
This diff is collapsed.
......@@ -32,15 +32,6 @@ TEST(type_prop, broadcast_deduce)
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, broadcast_deduce_correct)
{
// Check deduced type against correctly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
auto bc = make_shared<op::Broadcast>(param, Shape{2, 3, 4}, AxisSet{1});
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, broadcast_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
......@@ -228,15 +219,6 @@ TEST(type_prop, convert_deduce)
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_correct)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
......@@ -322,17 +304,7 @@ TEST(type_prop, dot_deduce_different_rank)
{
// Deduce type for different-rank tensor arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 8, 4, 2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1, 2, 3});
auto bc = make_shared<op::Dot>(param1, param2);
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 8, 4, 1, 3}));
}
TEST(type_prop, dot_deduce_different_rank_correct)
{
// Deduced type matches explicitly set type
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 8, 4, 2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1, 2, 3});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 1, 3});
auto bc = make_shared<op::Dot>(param1, param2);
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 8, 4, 1, 3}));
......@@ -372,7 +344,7 @@ TEST(type_prop, dot_deduce_reduction_axes_size_mismatch)
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Dot reduction axes not compatible"));
EXPECT_EQ(error.what(), std::string("Dot axes do not have same length"));
}
catch (...)
{
......@@ -571,19 +543,6 @@ TEST(type_prop, select_deduce)
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
TEST(type_prop, select_deduce_correct)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
TEST(type_prop, select_shape_mismatch_a)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
......@@ -735,24 +694,6 @@ TEST(type_prop, reduce_deduce)
TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
TEST(type_prop, reduce_deduce_correct)
{
auto param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto f_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto f_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f = make_shared<Function>(f_param_0 + f_param_1, rt, op::Parameters{f_param_0, f_param_1});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
ASSERT_EQ(*(r0->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{4}));
}
TEST(type_prop, reduce_nonscalar)
{
auto param_0 = make_shared<op::Parameter>(
......@@ -1073,14 +1014,6 @@ TEST(type_prop, reshape_deduce_t2v_120)
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
TEST(type_prop, reshape_deduce_correct_t2v_120)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{60});
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
TEST(type_prop, reshape_deduce_not_enough_axes)
{
auto param = make_shared<op::Parameter>(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment