Unverified Commit cba4e54e authored by Adam Procter, committed by GitHub

Generalized dot (#291)

* WIP generalized dot

* Add some multi-axis 3D, 4D, and 5D tests

* Add test on some 'pretty big' tensors

* Reworked dot to have less flexible axis-pairing behavior (the new convention is sketched just after this message)

* Backprop for dot... and a fix for a dumb bug in CoordinateTransform

* Forgot to commit some stuff in merge

* Disable tests that currently don't work on CPU

* Fix temporarily disabled test that should pass on NGVM and INTERPRETER but wasn't passing, due to the new axis-selection convention for dot

* Remove obsolete ScalarTensorProduct kernel/instruction

* Review comment

* s/n_dot_axes/dot_axis_count/

* s/dot_axis_count/reduction_axes_count/

* Adapt CPU emitter dot fallback to new kernel
parent a960f07e
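
The key semantic change, mentioned in the bullets above: op::Dot no longer guesses an axis pairing from the argument ranks. It takes a single reduction_axes_count k and always pairs the last k axes of arg0 with the first k axes of arg1, in order. A minimal sketch of the resulting shape rule, assuming only that paired axes must have equal lengths (dot_output_shape is a hypothetical helper, not part of this commit):

// Hypothetical illustration only -- dot_output_shape does not exist in this
// commit. With reduction_axes_count = k, op::Dot pairs the last k axes of
// arg0 with the first k axes of arg1; paired axes must have equal lengths.
#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>;

Shape dot_output_shape(const Shape& arg0_shape, const Shape& arg1_shape, size_t k)
{
    for (size_t i = 0; i < k; i++)
    {
        // "Dot axes do not have same length" is the corresponding op::Dot error.
        assert(arg0_shape[arg0_shape.size() - k + i] == arg1_shape[i]);
    }

    // The result keeps arg0's unpaired axes, then arg1's unpaired axes.
    Shape out_shape(arg0_shape.begin(), arg0_shape.end() - k);
    out_shape.insert(out_shape.end(), arg1_shape.begin() + k, arg1_shape.end());
    return out_shape;
}

// e.g. dot_output_shape({2, 8, 4, 2}, {2, 1, 3}, 1) yields {2, 8, 4, 1, 3},
// matching the updated dot_deduce_different_rank test further down.

The scalar case falls out as k = 0 (each output element is a single product), which is presumably why the separate ScalarTensorProduct kernel and instruction removed below became obsolete.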
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // ----------------------------------------------------------------------------
 
+#include <algorithm>
+
 #include "ngraph/common.hpp"
 
 using namespace ngraph;
@@ -37,31 +39,50 @@ Shape ngraph::project_shape(const Shape& shape, const AxisSet& deleted_axes)
     return project_coordinate(shape, deleted_axes);
 }
 
-// TODO: for the moment, just one axis at a time, please. Later could pass in an std::map from axis positions to axis lengths.
-// TODO: check validity, i.e. that the new axis is < coord_size+1.
-Coordinate
-    ngraph::inject_coordinate(const Coordinate& coord, size_t new_axis_pos, size_t new_axis_val)
+// TODO: check validity, i.e. that the new axis indices are all < coord_size+num_new_axes.
+Coordinate ngraph::inject_coordinate(const Coordinate& coord,
+                                     std::vector<std::pair<size_t, size_t>> new_axis_pos_val_pairs)
 {
     Coordinate result;
     size_t original_pos = 0;
 
-    for (size_t result_pos = 0; result_pos < coord.size() + 1; result_pos++)
+    for (size_t result_pos = 0; result_pos < coord.size() + new_axis_pos_val_pairs.size();
+         result_pos++)
     {
-        if (result_pos == new_axis_pos)
+        auto search_it = std::find_if(
+            new_axis_pos_val_pairs.begin(),
+            new_axis_pos_val_pairs.end(),
+            [result_pos](std::pair<size_t, size_t> p) { return p.first == result_pos; });
+
+        if (search_it == new_axis_pos_val_pairs.end())
         {
-            result.push_back(new_axis_val);
+            result.push_back(coord[original_pos++]);
         }
         else
         {
-            result.push_back(coord[original_pos++]);
+            result.push_back(search_it->second);
         }
     }
 
     return result;
 }
 
+Coordinate
+    ngraph::inject_coordinate(const Coordinate& coord, size_t new_axis_pos, size_t new_axis_val)
+{
+    return inject_coordinate(coord,
+                             std::vector<std::pair<size_t, size_t>>{
+                                 std::pair<size_t, size_t>(new_axis_pos, new_axis_val)});
+}
+
 Shape ngraph::inject_shape(const Shape& shape, size_t new_axis_pos, size_t new_axis_length)
 {
     return inject_coordinate(shape, new_axis_pos, new_axis_length);
 }
+
+Shape ngraph::inject_shape(const Shape& shape,
+                           std::vector<std::pair<size_t, size_t>> new_axis_pos_length_pairs)
+{
+    return inject_coordinate(shape, new_axis_pos_length_pairs);
+}
@@ -16,6 +16,7 @@
 
 #include <memory>
 #include <set>
+#include <utility>
 #include <vector>
 
 // Names for types that aren't worth giving their own classes
@@ -56,5 +57,9 @@ namespace ngraph
     Shape project_shape(const Shape& shape, const AxisSet& deleted_axes);
 
     Coordinate inject_coordinate(const Coordinate& coord, size_t new_axis_pos, size_t new_axis_val);
+    Coordinate inject_coordinate(const Coordinate& coord,
+                                 std::vector<std::pair<size_t, size_t>> new_axis_pos_val_pairs);
     Shape inject_shape(const Shape& shape, size_t new_axis_pos, size_t new_axis_length);
+    Shape inject_shape(const Shape& shape,
+                       std::vector<std::pair<size_t, size_t>> new_axis_pos_length_pairs);
 }
@@ -209,8 +209,8 @@ Coordinate CoordinateTransform::to_source_coordinate(const Coordinate& c) const
 
     for (size_t axis = 0; axis < m_n_axes; axis++)
     {
-        result[axis] = c[m_source_axis_order[axis]] * m_source_strides[m_source_axis_order[axis]] +
-                       m_source_start_corner[m_source_axis_order[axis]];
+        result[m_source_axis_order[axis]] =
+            c[axis] * m_source_strides[axis] + m_source_start_corner[axis];
     }
 
     return result;
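
The two-line change above appears to be the "dumb bug in CoordinateTransform" from the commit message: the axis-order permutation must scatter the target coordinate into the source coordinate, not gather from it. A standalone sketch of the difference, with assumed values (and strides and start corner omitted for brevity):

// Sketch (assumed values): applying an axis-order permutation when mapping a
// target coordinate back to its source. For a non-involutive order such as
// {1, 2, 0}, gathering and scattering give different results, which was the bug.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<size_t> order{1, 2, 0}; // source axis order
    std::vector<size_t> c{10, 20, 30};  // target coordinate
    std::vector<size_t> result(3);

    for (size_t axis = 0; axis < 3; axis++)
    {
        result[order[axis]] = c[axis]; // corrected: scatter through the permutation
        // the buggy version was: result[axis] = c[order[axis]];
    }

    // Prints 30,10,20 (the scattered result); the gathered, buggy result
    // would have been 20,30,10.
    std::cout << result[0] << "," << result[1] << "," << result[2] << "\n";
}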
@@ -24,6 +24,7 @@
 #include "ngraph/ops/broadcast.hpp"
 #include "ngraph/ops/concatenate.hpp"
 #include "ngraph/ops/constant.hpp"
+#include "ngraph/ops/dot.hpp"
 #include "ngraph/ops/function_call.hpp"
 #include "ngraph/ops/get_tuple_element.hpp"
 #include "ngraph/ops/one_hot.hpp"
@@ -137,29 +138,7 @@ void runtime::cpu::CPU_Emitter::EmitDot(const ngraph::Node* n,
     }
     else
     {
-        size_t arg0_dot_axis;
-        size_t arg1_dot_axis;
-
-        if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
-        {
-            arg0_dot_axis = 0;
-            arg1_dot_axis = 0;
-        }
-        // If arg0 is a matrix and arg1 is a vector, dot on axes 1 and 0 respectively.
-        else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
-        {
-            arg0_dot_axis = 1;
-            arg1_dot_axis = 0;
-        }
-        // If arg0 is rank n and arg1 is rank m, dot on axes n-1 and m-2, respectively.
-        //
-        // Note that this happens to handle the vector-matrix and matrix-matrix cases.
-        else
-        {
-            arg0_dot_axis = arg0_shape.size() - 1;
-            arg1_dot_axis = arg1_shape.size() - 2;
-        }
+        const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(n);
 
         m_out << "kernel::dot(" << args[0].get_name() << ",\n";
         m_out << "            " << args[1].get_name() << ",\n";
@@ -167,8 +146,7 @@ void runtime::cpu::CPU_Emitter::EmitDot(const ngraph::Node* n,
         m_out << "            {" << join(args[0].get_shape()) << "},\n";
         m_out << "            {" << join(args[1].get_shape()) << "},\n";
         m_out << "            {" << join(out[0].get_shape()) << "},\n";
-        m_out << "            " << arg0_dot_axis << ",\n";
-        m_out << "            " << arg1_dot_axis << ");\n";
+        m_out << "            " << dot->get_reduction_axes_count() << ");\n";
     }
 }
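
For instance, for argument shapes {2, 3} and {3, 4} with one reduction axis, the emitter above would generate a call along these lines (the tensor names are illustrative; the real ones come from args[...].get_name() and out[0].get_name()):

kernel::dot(arg0,
            arg1,
            out,
            {2, 3},
            {3, 4},
            {2, 4},
            1);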
@@ -23,6 +23,7 @@
 #include "ngraph/ops/broadcast.hpp"
 #include "ngraph/ops/concatenate.hpp"
 #include "ngraph/ops/constant.hpp"
+#include "ngraph/ops/dot.hpp"
 #include "ngraph/ops/get_tuple_element.hpp"
 #include "ngraph/ops/one_hot.hpp"
 #include "ngraph/ops/reduce.hpp"
@@ -66,7 +67,6 @@
 #include "ngraph/runtime/kernel/reduce.hpp"
 #include "ngraph/runtime/kernel/replace_slice.hpp"
 #include "ngraph/runtime/kernel/reshape.hpp"
-#include "ngraph/runtime/kernel/scalar_tensor_product.hpp"
 #include "ngraph/runtime/kernel/select.hpp"
 #include "ngraph/runtime/kernel/sign.hpp"
 #include "ngraph/runtime/kernel/sin.hpp"
@@ -290,54 +290,15 @@ private:
         }
         else if (node_op == "Dot")
         {
-            if (args[0]->get_shape().size() == 0)
-            {
-                kernel::scalar_tensor_product(reinterpret_cast<T*>(args[0]->get_data_ptr()),
-                                              reinterpret_cast<T*>(args[1]->get_data_ptr()),
-                                              reinterpret_cast<T*>(out[0]->get_data_ptr()),
-                                              out[0]->get_element_count());
-            }
-            else if (args[1]->get_shape().size() == 0)
-            {
-                kernel::scalar_tensor_product(reinterpret_cast<T*>(args[1]->get_data_ptr()),
-                                              reinterpret_cast<T*>(args[0]->get_data_ptr()),
-                                              reinterpret_cast<T*>(out[0]->get_data_ptr()),
-                                              out[0]->get_element_count());
-            }
-            else
-            {
-                size_t arg0_dot_axis;
-                size_t arg1_dot_axis;
-
-                if (args[0]->get_shape().size() == 1 && args[1]->get_shape().size() == 1)
-                {
-                    arg0_dot_axis = 0;
-                    arg1_dot_axis = 0;
-                }
-                // If arg0 is a matrix and arg1 is a vector, dot on axes 1 and 0 respectively.
-                else if (args[0]->get_shape().size() == 2 && args[1]->get_shape().size() == 1)
-                {
-                    arg0_dot_axis = 1;
-                    arg1_dot_axis = 0;
-                }
-                // If arg0 is rank n and arg1 is rank m, dot on axes n-1 and m-2, respectively.
-                //
-                // Note that this happens to handle the vector-matrix and matrix-matrix cases.
-                else
-                {
-                    arg0_dot_axis = args[0]->get_shape().size() - 1;
-                    arg1_dot_axis = args[1]->get_shape().size() - 2;
-                }
-                kernel::dot(reinterpret_cast<T*>(args[0]->get_data_ptr()),
-                            reinterpret_cast<T*>(args[1]->get_data_ptr()),
-                            reinterpret_cast<T*>(out[0]->get_data_ptr()),
-                            args[0]->get_shape(),
-                            args[1]->get_shape(),
-                            out[0]->get_shape(),
-                            arg0_dot_axis,
-                            arg1_dot_axis);
-            }
+            ngraph::op::Dot* dot = dynamic_cast<ngraph::op::Dot*>(&node);
+
+            kernel::dot(reinterpret_cast<T*>(args[0]->get_data_ptr()),
+                        reinterpret_cast<T*>(args[1]->get_data_ptr()),
+                        reinterpret_cast<T*>(out[0]->get_data_ptr()),
+                        args[0]->get_shape(),
+                        args[1]->get_shape(),
+                        out[0]->get_shape(),
+                        dot->get_reduction_axes_count());
         }
         else if (node_op == "Equal")
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <cmath>
+#include <utility>
 
 #include "ngraph/common.hpp"
 #include "ngraph/coordinate_transform.hpp"
@@ -32,49 +33,86 @@ namespace ngraph
              const Shape& arg0_shape,
              const Shape& arg1_shape,
              const Shape& out_shape,
-             size_t arg0_dot_axis,
-             size_t arg1_dot_axis)
+             size_t reduction_axes_count)
 {
-    CoordinateTransform output_transform(out_shape);
-
-    for (Coordinate out_coord : output_transform)
-    {
-        out[output_transform.index(out_coord)] = 0;
-    }
+    // Get the sizes of the dot axes. It's easiest to pull them from arg1 because they're
+    // right up front.
+    Shape dot_axis_sizes(reduction_axes_count);
+    std::copy(arg1_shape.begin(),
+              arg1_shape.begin() + reduction_axes_count,
+              dot_axis_sizes.begin());
 
     CoordinateTransform arg0_transform(arg0_shape);
     CoordinateTransform arg1_transform(arg1_shape);
+    CoordinateTransform output_transform(out_shape);
 
-    CoordinateTransform arg0_projected_transform(
-        project_shape(arg0_shape, AxisSet{arg0_dot_axis}));
-    CoordinateTransform arg1_projected_transform(
-        project_shape(arg1_shape, AxisSet{arg1_dot_axis}));
+    // Create coordinate transforms for arg0 and arg1 that throw away the dotted axes.
+    size_t arg0_projected_rank = arg0_shape.size() - reduction_axes_count;
+    size_t arg1_projected_rank = arg1_shape.size() - reduction_axes_count;
+
+    Shape arg0_projected_shape(arg0_projected_rank);
+    std::copy(arg0_shape.begin(),
+              arg0_shape.begin() + arg0_projected_rank,
+              arg0_projected_shape.begin());
+
+    Shape arg1_projected_shape(arg1_projected_rank);
+    std::copy(arg1_shape.begin() + reduction_axes_count,
+              arg1_shape.end(),
+              arg1_projected_shape.begin());
+
+    CoordinateTransform arg0_projected_transform(arg0_projected_shape);
+    CoordinateTransform arg1_projected_transform(arg1_projected_shape);
+
+    // Create a coordinate transform that allows us to iterate over all possible values
+    // for the dotted axes.
+    CoordinateTransform dot_axes_transform(dot_axis_sizes);
 
     for (Coordinate arg0_projected_coord : arg0_projected_transform)
    {
        for (Coordinate arg1_projected_coord : arg1_projected_transform)
        {
-            for (size_t i = 0; i < arg0_shape[arg0_dot_axis]; i++)
-            {
-                Coordinate arg0_coord =
-                    inject_coordinate(arg0_projected_coord, arg0_dot_axis, i);
-                Coordinate arg1_coord =
-                    inject_coordinate(arg1_projected_coord, arg1_dot_axis, i);
-
-                Coordinate out_coord(arg0_projected_coord.size() +
-                                     arg1_projected_coord.size());
-                std::copy(arg0_projected_coord.begin(),
-                          arg0_projected_coord.end(),
-                          out_coord.begin());
-                std::copy(arg1_projected_coord.begin(),
-                          arg1_projected_coord.end(),
-                          out_coord.begin() + arg0_projected_coord.size());
-
-                out[output_transform.index(out_coord)] +=
-                    arg0[arg0_transform.index(arg0_coord)] *
-                    arg1[arg1_transform.index(arg1_coord)];
-            }
+            // The output coordinate is just the concatenation of the projected coordinates.
+            Coordinate out_coord(arg0_projected_coord.size() +
+                                 arg1_projected_coord.size());
+
+            auto out_coord_it = std::copy(arg0_projected_coord.begin(),
+                                          arg0_projected_coord.end(),
+                                          out_coord.begin());
+            std::copy(
+                arg1_projected_coord.begin(), arg1_projected_coord.end(), out_coord_it);
+
+            // Zero out to start the sum.
+            T sum = 0;
+
+            size_t out_index = output_transform.index(out_coord);
+
+            // Walk along the dotted axes.
+            for (Coordinate dot_axis_positions : dot_axes_transform)
+            {
+                // In order to find the points to multiply together, we need to inject our current
+                // positions along the dotted axes back into the projected arg0 and arg1 coordinates.
+                Coordinate arg0_coord(arg0_shape.size());
+                Coordinate arg1_coord(arg1_shape.size());
+
+                auto arg0_it = std::copy(arg0_projected_coord.begin(),
+                                         arg0_projected_coord.end(),
+                                         arg0_coord.begin());
+                std::copy(
+                    dot_axis_positions.begin(), dot_axis_positions.end(), arg0_it);
+
+                auto arg1_it = std::copy(dot_axis_positions.begin(),
+                                         dot_axis_positions.end(),
+                                         arg1_coord.begin());
+                std::copy(
+                    arg1_projected_coord.begin(), arg1_projected_coord.end(), arg1_it);
+
+                // Multiply and add to the sum.
+                sum += arg0[arg0_transform.index(arg0_coord)] *
                        arg1[arg1_transform.index(arg1_coord)];
             }
+
+            // Write the sum back.
+            out[out_index] = sum;
        }
    }
 }
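
A minimal usage sketch of the generalized kernel (example data assumed, and the header path assumed to be ngraph/runtime/kernel/dot.hpp): an ordinary matrix product is just the reduction_axes_count == 1 case.

#include "ngraph/runtime/kernel/dot.hpp"

using ngraph::Shape;

void example()
{
    float a[] = {1, 2, 3, 4, 5, 6}; // shape {2, 3}, row-major
    float b[] = {1, 0, 0, 1, 1, 1}; // shape {3, 2}, row-major
    float c[4];                     // shape {2, 2}

    // Pair a's last axis (length 3) with b's first axis (length 3).
    ngraph::runtime::kernel::dot(a, b, c, Shape{2, 3}, Shape{3, 2}, Shape{2, 2}, 1);

    // c == {4, 5, 10, 11}:
    //   c[0][0] = 1*1 + 2*0 + 3*1 = 4     c[0][1] = 1*0 + 2*1 + 3*1 = 5
    //   c[1][0] = 4*1 + 5*0 + 6*1 = 10    c[1][1] = 4*0 + 5*1 + 6*1 = 11
}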
-// ----------------------------------------------------------------------------
-// Copyright 2017 Nervana Systems Inc.
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// ----------------------------------------------------------------------------
-
-#pragma once
-
-#include <cmath>
-
-#include "ngraph/common.hpp"
-#include "ngraph/coordinate_transform.hpp"
-
-namespace ngraph
-{
-    namespace runtime
-    {
-        namespace kernel
-        {
-            template <typename T>
-            void scalar_tensor_product(T* arg0, // the scalar (TODO: just pass as T?)
-                                       T* arg1, // the tensor
-                                       T* out,
-                                       size_t count)
-            {
-                for (size_t i = 0; i < count; i++)
-                {
-                    out[i] = (*arg0) * arg1[i];
-                }
-            }
-        }
-    }
-}
@@ -109,7 +109,6 @@
 #include "ngraph/runtime/ngvm/instruction/replace_slice.hpp"
 #include "ngraph/runtime/ngvm/instruction/reshape.hpp"
 #include "ngraph/runtime/ngvm/instruction/return.hpp"
-#include "ngraph/runtime/ngvm/instruction/scalar_tensor_product.hpp"
 #include "ngraph/runtime/ngvm/instruction/select.hpp"
 #include "ngraph/runtime/ngvm/instruction/sign.hpp"
 #include "ngraph/runtime/ngvm/instruction/sin.hpp"
@@ -352,8 +351,6 @@ std::vector<typename ET::type>
 }
 
 #define PUSH_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...)                                      \
     DO_ON_ELEMENT_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
-#define PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...)                              \
-    DO_ON_NUMERIC_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
 
 // Turn off complaint suppression (see above)
 #pragma clang diagnostic pop
@@ -550,6 +547,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
     REGISTER_TO_OP_MAP(op::Dot)
     {
+        auto dot = static_cast<const op::Dot*>(n);
+
         auto& arg_nodes = n->get_arguments();
 
         assert(arg_nodes.size() == 2);
@@ -566,70 +565,15 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
         auto arg1_shape = arg1_tensor_type->get_shape();
         auto& arg0_element_type = arg0_tensor_type->get_element_type();
+        auto reduction_axes_count = dot->get_reduction_axes_count();
 
         auto result_tensor_type =
             dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
         assert(nullptr != result_tensor_type);
 
         auto result_shape = result_tensor_type->get_shape();
 
-        // If arg0 or arg1 is a scalar, emit a scalar-tensor product.
-        if (arg0_shape.size() == 0)
-        {
-            PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
-                                                 "Dot has unhandled element type",
-                                                 instruction::ScalarTensorProductInstruction,
-                                                 in[0],
-                                                 in[1],
-                                                 out[0]);
-        }
-        else if (arg1_shape.size() == 0)
-        {
-            PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
-                                                 "Dot has unhandled element type",
-                                                 instruction::ScalarTensorProductInstruction,
-                                                 in[1],
-                                                 in[0],
-                                                 out[0]);
-        }
-        // If arg0 and arg1 are both vectors, dot both on axis 0.
-        else if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
-        {
-            PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
-                                                 "Dot has unhandled element type",
-                                                 instruction::DotInstruction,
-                                                 in[0],
-                                                 in[1],
-                                                 out[0],
-                                                 arg0_shape,
-                                                 arg1_shape,
-                                                 result_shape,
-                                                 0,
-                                                 0);
-        }
-        // If arg0 is a matrix and arg1 is a vector, dot on axes 1 and 0 respectively.
-        else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
-        {
-            PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
-                                                 "Dot has unhandled element type",
-                                                 instruction::DotInstruction,
-                                                 in[0],
-                                                 in[1],
-                                                 out[0],
-                                                 arg0_shape,
-                                                 arg1_shape,
-                                                 result_shape,
-                                                 1,
-                                                 0);
-        }
-        // If arg0 is rank n and arg1 is rank m, dot on axes n-1 and m-2, respectively.
-        //
-        // Note that this happens to handle the vector-matrix and matrix-matrix cases.
-        else
-        {
-            PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
-                                                 "Dot has unhandled element type",
-                                                 instruction::DotInstruction,
-                                                 in[0],
+        PUSH_POLYMORPHIC_INSTRUCTION(arg0_element_type,
+                                     "Dot has unhandled element type",
+                                     instruction::DotInstruction,
+                                     in[0],
@@ -638,9 +582,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
                                      arg0_shape,
                                      arg1_shape,
                                      result_shape,
-                                     arg0_shape.size() - 1,
-                                     arg1_shape.size() - 2);
-        }
+                                     reduction_axes_count);
     };
 
     // Parameter is a "runtime no-op" because the output tensor has already been filled.
@@ -38,16 +38,14 @@ namespace ngraph
                                const Shape& arg0_shape,
                                const Shape& arg1_shape,
                                const Shape& out_shape,
-                               size_t arg0_dot_axis,
-                               size_t arg1_dot_axis)
+                               size_t reduction_axes_count)
                     : m_arg0(arg0)
                     , m_arg1(arg1)
                     , m_out(out)
                     , m_arg0_shape(arg0_shape)
                     , m_arg1_shape(arg1_shape)
                     , m_out_shape(out_shape)
-                    , m_arg0_dot_axis(arg0_dot_axis)
-                    , m_arg1_dot_axis(arg1_dot_axis)
+                    , m_reduction_axes_count(reduction_axes_count)
                 {
                 }
@@ -63,8 +61,7 @@ namespace ngraph
                                 m_arg0_shape,
                                 m_arg1_shape,
                                 m_out_shape,
-                                m_arg0_dot_axis,
-                                m_arg1_dot_axis);
+                                m_reduction_axes_count);
                 }
 
             protected:
@@ -74,8 +71,7 @@ namespace ngraph
                 Shape m_arg0_shape;
                 Shape m_arg1_shape;
                 Shape m_out_shape;
-                size_t m_arg0_dot_axis;
-                size_t m_arg1_dot_axis;
+                size_t m_reduction_axes_count;
             };
         }
     }
 }
-// ----------------------------------------------------------------------------
-// Copyright 2017 Nervana Systems Inc.
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// ----------------------------------------------------------------------------
-
-#pragma once
-
-#include "ngraph/runtime/kernel/scalar_tensor_product.hpp"
-#include "ngraph/runtime/ngvm/call_frame.hpp"
-#include "ngraph/runtime/ngvm/instruction.hpp"
-#include "ngraph/runtime/ngvm/utils.hpp"
-#include "ngraph/runtime/tensor_view.hpp"
-
-namespace ngraph
-{
-    namespace runtime
-    {
-        namespace ngvm
-        {
-            namespace instruction
-            {
-                template <typename ET>
-                class ScalarTensorProductInstruction : public Instruction
-                {
-                public:
-                    ScalarTensorProductInstruction(const TensorViewInfo& arg0,
-                                                   const TensorViewInfo& arg1,
-                                                   const TensorViewInfo& out)
-                        : m_arg0(arg0)
-                        , m_arg1(arg1)
-                        , m_out(out)
-                    {
-                    }
-
-                    virtual void execute(CallFrame& call_frame) const override
-                    {
-                        typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
-                        typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
-                        typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
-                        size_t count = get_tensor_element_count(call_frame, m_arg1);
-
-                        kernel::scalar_tensor_product<typename ET::type>(arg0, arg1, out, count);
-                    }
-
-                protected:
-                    TensorViewInfo m_arg0;
-                    TensorViewInfo m_arg1;
-                    TensorViewInfo m_out;
-                };
-            }
-        }
-    }
-}
@@ -521,6 +521,28 @@ TEST(${BACKEND_NAME}, backwards_dot_tensor2_tensor2)
         autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1}, .01f, .01f));
 }
 
+TEST(${BACKEND_NAME}, backwards_dot_tensor3_tensor3)
+{
+    auto manager = runtime::Manager::get("NGVM");
+    auto backend = manager->allocate_backend();
+
+    test::Uniform<float> rng(-1.0f, 1.0f);
+    auto shape0 = Shape{2, 4, 3};
+    auto shape1 = Shape{4, 3, 3};
+
+    auto x0 = rng.initialize(backend->make_primary_tensor_view<float>(shape0));
+    auto x1 = rng.initialize(backend->make_primary_tensor_view<float>(shape1));
+
+    auto make_graph = [shape0, shape1]() {
+        auto X0 = make_shared<op::Parameter>(element::Float32::element_type(), shape0);
+        auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape1);
+        return make_shared<Function>(make_shared<op::Dot>(X0, X1, 2),
+                                     nullptr,
+                                     std::vector<std::shared_ptr<op::Parameter>>{X0, X1});
+    };
+    EXPECT_TRUE(
+        autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1}, .01f, .01f));
+}
+
 TEST(${BACKEND_NAME}, backwards_exp)
 {
     auto manager = runtime::Manager::get("${BACKEND_NAME}");
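
In the new test, op::Dot(X0, X1, 2) pairs the last two axes of shape0 (lengths {4, 3}) with the first two axes of shape1 (also {4, 3}), giving a forward result of shape {2, 3}; the autodiff_numeric_compare check then exercises backprop through a genuinely multi-axis dot. Note the manager is pinned to "NGVM" rather than "${BACKEND_NAME}", consistent with the "Disable tests that currently don't work on CPU" bullet in the commit message.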
@@ -32,15 +32,6 @@ TEST(type_prop, broadcast_deduce)
     ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
 }
 
-TEST(type_prop, broadcast_deduce_correct)
-{
-    // Check deduced type against correctly specified type
-    auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
-    auto bc = make_shared<op::Broadcast>(param, Shape{2, 3, 4}, AxisSet{1});
-    auto bc_vt = bc->get_value_type();
-    ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
-}
-
 TEST(type_prop, broadcast_deduce_incorrect)
 {
     // Check deduced type against incorrectly specified type
@@ -228,15 +219,6 @@ TEST(type_prop, convert_deduce)
     ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
 }
 
-TEST(type_prop, convert_deduce_correct)
-{
-    // Check deduced type against incorrectly specified type
-    auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
-    auto c = make_shared<op::Convert>(param, element::Int32::element_type());
-    auto c_vt = c->get_value_type();
-    ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
-}
-
 TEST(type_prop, convert_deduce_incorrect)
 {
     // Check deduced type against incorrectly specified type
@@ -322,17 +304,7 @@ TEST(type_prop, dot_deduce_different_rank)
 {
     // Deduce type for different-rank tensor arguments
     auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 8, 4, 2});
-    auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1, 2, 3});
-    auto bc = make_shared<op::Dot>(param1, param2);
-    auto bc_vt = bc->get_value_type();
-    ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 8, 4, 1, 3}));
-}
-
-TEST(type_prop, dot_deduce_different_rank_correct)
-{
-    // Deduced type matches explicitly set type
-    auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 8, 4, 2});
-    auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1, 2, 3});
+    auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 1, 3});
     auto bc = make_shared<op::Dot>(param1, param2);
     auto bc_vt = bc->get_value_type();
     ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 8, 4, 1, 3}));
@@ -372,7 +344,7 @@ TEST(type_prop, dot_deduce_reduction_axes_size_mismatch)
     }
     catch (const ngraph_error& error)
     {
-        EXPECT_EQ(error.what(), std::string("Dot reduction axes not compatible"));
+        EXPECT_EQ(error.what(), std::string("Dot axes do not have same length"));
     }
     catch (...)
     {
@@ -571,19 +543,6 @@ TEST(type_prop, select_deduce)
     ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
 }
 
-TEST(type_prop, select_deduce_correct)
-{
-    auto tv0_2_4_param_0 = make_shared<op::Parameter>(
-        make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
-    auto tv0_2_4_param_1 = make_shared<op::Parameter>(
-        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
-    auto tv0_2_4_param_2 = make_shared<op::Parameter>(
-        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
-    auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
-    auto bc_vt = bc->get_value_type();
-    ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
-}
-
 TEST(type_prop, select_shape_mismatch_a)
 {
     auto tv0_2_4_param_0 = make_shared<op::Parameter>(
@@ -735,24 +694,6 @@ TEST(type_prop, reduce_deduce)
                   TensorViewType(element::Float32::element_type(), Shape{2, 4}));
 }
 
-TEST(type_prop, reduce_deduce_correct)
-{
-    auto param_0 = make_shared<op::Parameter>(
-        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
-    auto param_1 = make_shared<op::Parameter>(
-        make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
-    auto f_param_0 = make_shared<op::Parameter>(
-        make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
-    auto f_param_1 = make_shared<op::Parameter>(
-        make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
-    auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
-    auto f = make_shared<Function>(f_param_0 + f_param_1, rt, op::Parameters{f_param_0, f_param_1});
-
-    auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
-    ASSERT_EQ(*(r0->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{4}));
-}
-
 TEST(type_prop, reduce_nonscalar)
 {
     auto param_0 = make_shared<op::Parameter>(
@@ -1073,14 +1014,6 @@ TEST(type_prop, reshape_deduce_t2v_120)
     ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
 }
 
-TEST(type_prop, reshape_deduce_correct_t2v_120)
-{
-    auto param = make_shared<op::Parameter>(
-        make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
-    auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{60});
-    ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
-}
-
 TEST(type_prop, reshape_deduce_not_enough_axes)
 {
     auto param = make_shared<op::Parameter>(
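
Two things are going on in this test diff. First, dot_deduce_different_rank tracks the new pairing convention: with the default single reduction axis, arg0's trailing axis (length 2) must now match arg1's leading axis, so param2's shape changes from {1, 2, 3} to {2, 1, 3} while the deduced result shape {2, 8, 4, 1, 3} stays the same. Second, the various *_correct tests (which checked deduction against an explicitly specified type) are dropped; they were near-duplicates of their plain deduce counterparts.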