Unverified Commit cba4e54e authored by Adam Procter, committed by GitHub

Generalized dot (#291)

* WIP generalized dot

* Add some multi-axis 3D, 4D, and 5D tests

* Add test on some 'pretty big' tensors

* Reworked dot to have less flexible axis-pairing behavior

* Backprop for dot... and a fix for a dumb bug in CoordinateTransform

* Forgot to commit some stuff in merge

* Disable tests that currently don't work on CPU

* Fix temporarily disabled test that should pass on NGVM and INTERPRETER but wasn't due to new axis-selection convention for dot

* Remove obsolete ScalarTensorProduct kernel/instruction

* Review comment

* s/n_dot_axes/dot_axis_count/

* s/dot_axis_count/reduction_axes_count/

* Adapt CPU emitter dot fallback to new kernel
parent a960f07e
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "ngraph/common.hpp"
using namespace ngraph;
......@@ -37,31 +39,50 @@ Shape ngraph::project_shape(const Shape& shape, const AxisSet& deleted_axes)
return project_coordinate(shape, deleted_axes);
}
// TODO: for the moment, just one axis at a time, please. Later could pass in an std::map from axis positions to axis lengths.
// TODO: check validity, i.e. that the new axis is < coord_size+1.
Coordinate
ngraph::inject_coordinate(const Coordinate& coord, size_t new_axis_pos, size_t new_axis_val)
// TODO: check validity, i.e. that the new axis indices are all < coord_size+num_new_axes.
Coordinate ngraph::inject_coordinate(const Coordinate& coord,
std::vector<std::pair<size_t, size_t>> new_axis_pos_val_pairs)
{
Coordinate result;
size_t original_pos = 0;
for (size_t result_pos = 0; result_pos < coord.size() + 1; result_pos++)
for (size_t result_pos = 0; result_pos < coord.size() + new_axis_pos_val_pairs.size();
result_pos++)
{
if (result_pos == new_axis_pos)
auto search_it = std::find_if(
new_axis_pos_val_pairs.begin(),
new_axis_pos_val_pairs.end(),
[result_pos](std::pair<size_t, size_t> p) { return p.first == result_pos; });
if (search_it == new_axis_pos_val_pairs.end())
{
result.push_back(new_axis_val);
result.push_back(coord[original_pos++]);
}
else
{
result.push_back(coord[original_pos++]);
result.push_back(search_it->second);
}
}
return result;
}
Coordinate
ngraph::inject_coordinate(const Coordinate& coord, size_t new_axis_pos, size_t new_axis_val)
{
return inject_coordinate(coord,
std::vector<std::pair<size_t, size_t>>{
std::pair<size_t, size_t>(new_axis_pos, new_axis_val)});
}
Shape ngraph::inject_shape(const Shape& shape, size_t new_axis_pos, size_t new_axis_length)
{
return inject_coordinate(shape, new_axis_pos, new_axis_length);
}
Shape ngraph::inject_shape(const Shape& shape,
std::vector<std::pair<size_t, size_t>> new_axis_pos_length_pairs)
{
return inject_coordinate(shape, new_axis_pos_length_pairs);
}
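For intuition, the multi-pair overload walks the result positions in order and, at each position, emits either the injected value (when the position matches one of the requested new axes) or the next original coordinate element. A standalone sketch of that behavior, using plain std::vector<size_t> in place of ngraph's Coordinate type (an assumption made only for illustration):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Mirror of the multi-pair inject_coordinate logic above, written against
// std::vector<size_t> purely for illustration.
std::vector<size_t> inject(const std::vector<size_t>& coord,
                           const std::vector<std::pair<size_t, size_t>>& new_axes)
{
    std::vector<size_t> result;
    size_t original_pos = 0;
    for (size_t result_pos = 0; result_pos < coord.size() + new_axes.size(); result_pos++)
    {
        auto it = std::find_if(new_axes.begin(),
                               new_axes.end(),
                               [result_pos](const std::pair<size_t, size_t>& p) {
                                   return p.first == result_pos;
                               });
        if (it == new_axes.end())
        {
            result.push_back(coord[original_pos++]); // keep the next original element
        }
        else
        {
            result.push_back(it->second); // emit the injected axis value
        }
    }
    return result;
}

int main()
{
    // Inject value 7 at position 0 and value 9 at position 2 into {2, 3}:
    // expected result is {7, 2, 9, 3}.
    for (size_t v : inject({2, 3}, {{0, 7}, {2, 9}}))
    {
        std::cout << v << " ";
    }
    std::cout << "\n";
}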
......@@ -16,6 +16,7 @@
#include <memory>
#include <set>
#include <utility>
#include <vector>
// Names for types that aren't worth giving their own classes
......@@ -56,5 +57,9 @@ namespace ngraph
Shape project_shape(const Shape& shape, const AxisSet& deleted_axes);
Coordinate inject_coordinate(const Coordinate& coord, size_t new_axis_pos, size_t new_axis_val);
Coordinate inject_coordinate(const Coordinate& coord,
std::vector<std::pair<size_t, size_t>> new_axis_pos_val_pairs);
Shape inject_shape(const Shape& shape, size_t new_axis_pos, size_t new_axis_length);
Shape inject_shape(const Shape& shape,
std::vector<std::pair<size_t, size_t>> new_axis_pos_length_pairs);
}
......@@ -209,8 +209,8 @@ Coordinate CoordinateTransform::to_source_coordinate(const Coordinate& c) const
for (size_t axis = 0; axis < m_n_axes; axis++)
{
result[axis] = c[m_source_axis_order[axis]] * m_source_strides[m_source_axis_order[axis]] +
m_source_start_corner[m_source_axis_order[axis]];
result[m_source_axis_order[axis]] =
c[axis] * m_source_strides[axis] + m_source_start_corner[axis];
}
return result;
......
......@@ -14,6 +14,7 @@
#include <functional>
#include <memory>
#include <utility>
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/dot.hpp"
......@@ -25,8 +26,49 @@
using namespace std;
using namespace ngraph;
//
// Helper function to compute the number of dot axes according to default behavior when
// they are not specified.
//
size_t default_reduction_axes_count(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
auto arg0_value_type = arg0->get_value_type();
auto arg0_tensor_view_type = std::dynamic_pointer_cast<const TensorViewType>(arg0_value_type);
if (nullptr == arg0_tensor_view_type)
{
throw ngraph_error("Dot arg0 does not have tensor view type");
}
auto arg0_shape = arg0_tensor_view_type->get_shape();
auto arg1_value_type = arg1->get_value_type();
auto arg1_tensor_view_type = std::dynamic_pointer_cast<const TensorViewType>(arg1_value_type);
if (nullptr == arg1_tensor_view_type)
{
throw ngraph_error("Dot arg1 does not have tensor view type");
}
auto arg1_shape = arg1_tensor_view_type->get_shape();
if (arg0_shape.size() == 0 || arg1_shape.size() == 0)
{
return 0;
}
else
{
return 1;
}
}
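Put simply, the default rule is: zero reduction axes when either argument is a scalar (so the dot degenerates to a scalar-tensor product), and exactly one reduction axis otherwise. A minimal standalone restatement of that rule against plain std::vector<size_t> shapes (an illustration, not the ngraph types):

#include <cstddef>
#include <iostream>
#include <vector>

// Restatement of the default rule above: no reduction axes when either
// argument has rank 0, otherwise one reduction axis.
size_t default_reduction_axes_count(const std::vector<size_t>& arg0_shape,
                                    const std::vector<size_t>& arg1_shape)
{
    return (arg0_shape.empty() || arg1_shape.empty()) ? 0 : 1;
}

int main()
{
    std::cout << default_reduction_axes_count({}, {2, 3}) << "\n";     // 0: scalar-tensor product
    std::cout << default_reduction_axes_count({2, 3}, {3}) << "\n";    // 1: matrix-vector product
    std::cout << default_reduction_axes_count({2, 3}, {3, 4}) << "\n"; // 1: matrix multiplication
}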
op::Dot::Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Dot(arg0, arg1, default_reduction_axes_count(arg0, arg1))
{
}
op::Dot::Dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
size_t reduction_axes_count)
: RequiresTensorViewArgs("Dot", {arg0, arg1})
, m_reduction_axes_count(reduction_axes_count)
{
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
......@@ -38,205 +80,82 @@ op::Dot::Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg
vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
const bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0;
if (arg1_shape.size() > 1)
if (reduction_axes_count > arg0_shape.size())
{
arg1_reduction = arg1_shape.size() - 2;
throw ngraph_error("Dot has too many axes for arg0");
}
else
{
arg1_reduction = arg1_shape.size() - 1;
}
if (!is_scalar_mult && (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction)))
{
throw ngraph_error("Dot reduction axes not compatible");
}
vector<size_t> result_shape;
result_shape.reserve(arg0_shape.size() + arg1_shape.size() - (is_scalar_mult ? 0 : 2));
for (auto i = 0; i < arg0_shape.size(); i++)
if (reduction_axes_count > arg1_shape.size())
{
if (is_scalar_mult || i != arg0_reduction)
{
result_shape.push_back(arg0_shape[i]);
}
throw ngraph_error("Dot has too many axes for arg1");
}
for (auto i = 0; i < arg1_shape.size(); i++)
for (size_t i = 0; i < reduction_axes_count; i++)
{
if (is_scalar_mult || i != arg1_reduction)
if (arg0_shape[arg0_shape.size() - reduction_axes_count + i] != arg1_shape[i])
{
result_shape.push_back(arg1_shape[i]);
throw ngraph_error("Dot axes do not have same length");
}
}
vector<size_t> result_shape(arg0_shape.size() + arg1_shape.size() - 2 * reduction_axes_count);
std::copy(arg0_shape.begin(), arg0_shape.end() - reduction_axes_count, result_shape.begin());
std::copy(arg1_shape.begin() + reduction_axes_count,
arg1_shape.end(),
result_shape.begin() + (arg0_shape.size() - reduction_axes_count));
auto result_type =
make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(result_type);
}
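In other words, the result shape is arg0's shape with its last reduction_axes_count axes dropped, followed by arg1's shape with its first reduction_axes_count axes dropped, after checking that the dropped axes agree. A small standalone sketch of that rule, again using std::vector<size_t> shapes for illustration:

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

// Sketch of the result-shape rule implemented by the constructor above.
std::vector<size_t> dot_result_shape(const std::vector<size_t>& arg0_shape,
                                     const std::vector<size_t>& arg1_shape,
                                     size_t reduction_axes_count)
{
    if (reduction_axes_count > arg0_shape.size() || reduction_axes_count > arg1_shape.size())
    {
        throw std::invalid_argument("too many reduction axes");
    }
    for (size_t i = 0; i < reduction_axes_count; i++)
    {
        if (arg0_shape[arg0_shape.size() - reduction_axes_count + i] != arg1_shape[i])
        {
            throw std::invalid_argument("reduction axes do not have the same length");
        }
    }
    std::vector<size_t> result(arg0_shape.begin(), arg0_shape.end() - reduction_axes_count);
    result.insert(result.end(), arg1_shape.begin() + reduction_axes_count, arg1_shape.end());
    return result;
}

int main()
{
    // {2,3,4} dot {3,4,5} over two axes -> {2,5}
    for (size_t d : dot_result_shape({2, 3, 4}, {3, 4, 5}, 2))
    {
        std::cout << d << " ";
    }
    std::cout << "\n";
}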
template <typename T>
T range(size_t n);
template <>
ngraph::AxisSet range<ngraph::AxisSet>(size_t n)
std::shared_ptr<op::Reshape> make_reshape_axes_to_front(const std::shared_ptr<Node>& n,
const Shape& front_shape,
const Shape& back_shape)
{
ngraph::AxisSet result;
for (size_t i = 0; i < n; i++)
AxisVector input_order;
Shape output_shape;
for (size_t i = 0; i < back_shape.size(); i++)
{
result.insert(i);
input_order.push_back(front_shape.size() + i);
output_shape.push_back(back_shape[i]);
}
return result;
}
template <>
ngraph::AxisVector range<ngraph::AxisVector>(size_t n)
{
ngraph::AxisVector result;
for (size_t i = 0; i < n; i++)
for (size_t i = 0; i < front_shape.size(); i++)
{
result.push_back(i);
input_order.push_back(i);
output_shape.push_back(front_shape[i]);
}
return result;
return make_shared<op::Reshape>(n, input_order, output_shape);
}
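The helper builds a Reshape whose input_order lists the back_shape axes first and the front_shape axes second, i.e. it views a front-by-back tensor as back-by-front. A standalone trace of just the index bookkeeping (illustration only):

#include <cstddef>
#include <iostream>
#include <vector>

// Trace of the axis bookkeeping in make_reshape_axes_to_front above:
// the back axes are listed first in the permutation, then the front axes.
int main()
{
    std::vector<size_t> front_shape{2, 3};
    std::vector<size_t> back_shape{4, 5};

    std::vector<size_t> input_order;
    std::vector<size_t> output_shape;
    for (size_t i = 0; i < back_shape.size(); i++)
    {
        input_order.push_back(front_shape.size() + i);
        output_shape.push_back(back_shape[i]);
    }
    for (size_t i = 0; i < front_shape.size(); i++)
    {
        input_order.push_back(i);
        output_shape.push_back(front_shape[i]);
    }

    // Expected: input_order = {2, 3, 0, 1}, output_shape = {4, 5, 2, 3}.
    for (size_t i : input_order) std::cout << i << " ";
    std::cout << "| ";
    for (size_t d : output_shape) std::cout << d << " ";
    std::cout << "\n";
}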
void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
auto y = m_arguments[1];
auto x_shape = x->get_shape();
auto y_shape = y->get_shape();
auto delta_shape = delta->get_shape();
if (is_scalar(x_shape))
{
adjoints.add_delta(y, make_shared<Dot>(delta, x));
if (is_scalar(y_shape))
{
// Just multiplication
adjoints.add_delta(x, delta * y);
return;
}
// scale dot tensor
adjoints.add_delta(x, make_shared<Sum>(delta * y, range<AxisSet>(y_shape.size())));
return;
}
if (is_scalar(y_shape))
{
// tensor dot scalar
adjoints.add_delta(x, make_shared<Dot>(delta, y));
adjoints.add_delta(y, make_shared<Sum>(delta * x, range<AxisSet>(x_shape.size())));
return;
}
if (is_vector(y_shape))
{
if (is_vector(x_shape))
{
adjoints.add_delta(x, make_shared<Dot>(delta, y));
}
else
{
// X has shape IJ, Y has shape J, delta has shape I
// delta -> (I, 1)
// Y -> (1, J)
// delta . Y is (I, J)
Shape shape_delta_1 = delta->get_shape();
shape_delta_1.push_back(1);
auto delta_1 =
make_shared<Broadcast>(delta, shape_delta_1, AxisSet{delta->get_shape().size()});
Shape shape_1_y{1};
shape_1_y.insert(shape_1_y.end(), y_shape.begin(), y_shape.end());
auto y_1 = make_shared<Broadcast>(y, shape_1_y, AxisSet{0});
adjoints.add_delta(x, make_shared<Dot>(delta_1, y_1));
}
// X has shape IJ
// Y has shape J
// delta has shape I
// Need to move J to front of X and multiply by Y
Shape shape_xt(x_shape.size());
AxisVector x_axes(x_shape.size());
shape_xt[0] = x_shape.at(x_shape.size() - 1);
x_axes[0] = x_shape.size() - 1;
for (size_t i = 1; i < x_shape.size(); ++i)
{
shape_xt[i] = x_shape[i - 1];
x_axes[i] = i - 1;
}
auto x_reshape = make_shared<Reshape>(x, x_axes, shape_xt);
adjoints.add_delta(y, make_shared<Dot>(x_reshape, delta));
return;
}
// Tensor tensor case
// X is Ij
// Y = Kjl
// X.Y, delta is IKl
//
// delta -> I(Kl)
// Y -> (Kl)j
// delta.Y -> Ij
Shape s_I;
s_I.insert(s_I.begin(), x_shape.begin(), x_shape.end() - 1);
size_t s_j = x_shape[x_shape.size() - 1];
Shape s_K;
s_K.insert(s_K.begin(), y_shape.begin(), y_shape.end() - 2);
size_t s_l = y_shape[y_shape.size() - 1];
size_t s_Kl = shape_size(s_K) * s_l;
Shape shape_delta_I_Kl;
shape_delta_I_Kl.insert(shape_delta_I_Kl.end(), s_I.begin(), s_I.end());
shape_delta_I_Kl.push_back(s_Kl);
AxisVector idx_delta_I_Kl = range<AxisVector>(delta_shape.size());
auto delta_I_Kl = make_shared<Reshape>(delta, idx_delta_I_Kl, shape_delta_I_Kl);
Shape shape_y_Kl_j{s_Kl, s_j};
AxisVector idx_y_Kl_j = range<AxisVector>(y_shape.size() - 2);
idx_y_Kl_j.push_back(y_shape.size() - 1);
idx_y_Kl_j.push_back(y_shape.size() - 2);
auto y_Kl_j = make_shared<Reshape>(y, idx_y_Kl_j, shape_y_Kl_j);
adjoints.add_delta(x, make_shared<Dot>(delta_I_Kl, y_Kl_j));
// delta -> K(I)l
// X -> j(I)
// X.delta -> jKl -> Kjl
Shape shape_delta_K_I_l;
shape_delta_K_I_l.insert(shape_delta_K_I_l.begin(), s_K.begin(), s_K.end());
shape_delta_K_I_l.push_back(shape_size(s_I));
shape_delta_K_I_l.push_back(s_l);
AxisVector idx_delta = range<AxisVector>(delta_shape.size());
AxisVector idx_delta_K_I_l;
idx_delta_K_I_l.insert(idx_delta_K_I_l.end(),
idx_delta.begin() + s_I.size(),
idx_delta.begin() + s_I.size() + s_K.size());
idx_delta_K_I_l.insert(
idx_delta_K_I_l.end(), idx_delta.begin(), idx_delta.begin() + s_I.size());
idx_delta_K_I_l.push_back(delta_shape.size() - 1);
auto delta_K_I_l = make_shared<Reshape>(delta, idx_delta_K_I_l, shape_delta_K_I_l);
Shape shape_x_j_I;
shape_x_j_I.push_back(s_j);
shape_x_j_I.push_back(shape_size(s_I));
AxisVector idx_x = range<AxisVector>(x_shape.size());
AxisVector idx_x_j_I;
idx_x_j_I.push_back(idx_x[idx_x.size() - 1]);
idx_x_j_I.insert(idx_x_j_I.end(), idx_x.begin(), idx_x.begin() + idx_x.size() - 1);
auto x_j_I = make_shared<Reshape>(x, idx_x_j_I, shape_x_j_I);
auto jKl = make_shared<Dot>(x_j_I, delta_K_I_l);
Shape shape_Kjl;
shape_Kjl.insert(shape_Kjl.end(), s_K.begin(), s_K.end());
shape_Kjl.push_back(s_j);
shape_Kjl.push_back(s_l);
AxisVector idx_Kjl;
for (size_t i = 1; i < s_K.size() + 1; ++i)
{
idx_Kjl.push_back(i);
}
idx_Kjl.push_back(0);
idx_Kjl.push_back(y_shape.size() - 1);
auto Klj = make_shared<Reshape>(jKl, idx_Kjl, shape_Kjl);
adjoints.add_delta(y, Klj);
auto x = get_inputs().at(0).get_output().get_node();
auto y = get_inputs().at(1).get_output().get_node();
auto x_shape = x->get_shape(); // shape IJ
auto y_shape = y->get_shape(); // shape JK
auto delta_shape = delta->get_shape(); // shape IK
Shape I_shape;
Shape J_shape;
Shape K_shape;
I_shape.insert(I_shape.begin(), x_shape.begin(), x_shape.end() - m_reduction_axes_count);
J_shape.insert(J_shape.begin(), y_shape.begin(), y_shape.begin() + m_reduction_axes_count);
K_shape.insert(K_shape.begin(), y_shape.begin() + J_shape.size(), y_shape.end());
auto delta_reshaped = make_reshape_axes_to_front(delta, I_shape, K_shape); // KI
auto delta_reshaped_dot_y = make_shared<Dot>(y, delta_reshaped, K_shape.size()); // JI
auto delta_reshaped_dot_y_reshaped =
make_reshape_axes_to_front(delta_reshaped_dot_y, J_shape, I_shape); // IJ
adjoints.add_delta(x, delta_reshaped_dot_y_reshaped);
auto x_reshaped = make_reshape_axes_to_front(x, I_shape, J_shape); // JI
auto x_reshaped_dot_delta = make_shared<Dot>(x_reshaped, delta, I_shape.size()); // JK
adjoints.add_delta(y, x_reshaped_dot_delta);
}
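As a concrete check of the shape bookkeeping above, take m_reduction_axes_count = 1 with x of shape {2, 3} (so I_shape = {2}, J_shape = {3}) and y of shape {3, 4} (K_shape = {4}); delta then has shape {2, 4}. delta is reshaped to {4, 2} (KI), Dot(y, delta_reshaped, 1) yields {3, 2} (JI), and reshaping that back gives {2, 3}, matching x's shape. Likewise, x reshaped to {3, 2} (JI) dotted with delta over one axis yields {3, 4}, matching y's shape.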
......@@ -14,76 +14,44 @@
#pragma once
#include <utility>
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Inner product/dot product/matrix product/tensor contraction operation.
///
/// Takes two arguments `arg0` and `arg1`. There are three possible cases:
/// \brief Generalized dot product operation, including scalar-tensor product, matrix-vector product, and matrix multiplication.
///
/// 1. `arg0` or `arg1` is 0-dimensional. Then, treats that 0-dimensional argument as a scalar and computes a scalar-tensor product.
/// (Example: `arg0` has shape `{1,2,3}` and arg1 has shape `{}`; then the result will have shape `{1,2,3}`.)
/// Takes two arguments `arg0` and `arg1`, with shapes \f$(i_1,\dots,i_n,j_1,\dots,j_m)\f$ and \f$(j_1,\dots,j_m,k_1,\dots,k_p)\f$ respectively,
/// and produces an output tensor with shape \f$(i_1,\dots,i_n,k_1,\dots,k_p)\f$ by summing products along the \f$j\f$ dimensions.
///
/// 2. `arg1` is a vector (1-dimensional tensor). Then, computes a dot product reducing on the innermost (rightmost) dimensions of `arg0` and `arg1`.
/// (Example: arg0 has shape `{1,2,3}` and arg1 has shape `{3}`; then the result will have shape `{1,2}`.)
/// A few common cases are as follows:
///
/// 3. `arg1` is more than 1-dimensional. Then, computes a dot product reducing on the innermost (rightmost) dimension of arg0, and the next-to-innermost dimension of arg1.
/// (Example: arg0 has shape {3,4} and arg1 has shape {4,3}; then the result will have shape {3,3}.)
/// * If \f$m = 0\f$ and \f$n = 0\f$ or \f$p = 0\f$, the operation is a scalar-tensor product.
/// * If \f$m = 1\f$, \f$n = 1\f$, and \f$p = 0\f$, the operation is a matrix-vector product.
/// * If \f$m = 1\f$ and \f$n = p = 1\f$, the operation is a matrix multiplication.
///
/// ## Parameters
///
/// # Case 1: Scalar-tensor product
/// | | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------ |
/// | `reduction_axes_count` | The number of axes to reduce through dot-product (corresponds to \f$m\f$ in the formulas above). |
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------------ |
/// | `arg0` | \f$E[]\f$ | A scalar of any element type. |
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape, with the same element type as `arg0`. |
///
/// <i>(Note: the order of inputs may be reversed in this case, i.e., `arg1` can be the scalar and `arg0` the tensor.)</i>
/// | | Type | Description |
/// | ------ | ----------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | `arg0` | \f$E[d_1,\dots,d_n,d'_1,\dots,d'_m]~(n,m \geq 0)\f$ | A tensor of any shape and element type. |
/// | `arg1` | \f$E[d'_1,\dots,d'_m,d''_1,\dots,d''_p]~(p \geq 0)\f$ | A tensor of any shape with the same element type as `arg0` and rank at least \f$m\f$, whose first \f$m\f$ dimensions match the last \f$m\f$ dimensions of `arg0`, in order. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ---------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathtt{arg0} \cdot \mathtt{arg1}[i_1,\dots,i_n]\f$. |
///
/// # Case 2: Vector-tensor product
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | ----------------------------------- | ------------------------------------------------------------------------------------------------------------ |
/// | `arg0` | \f$E[d]\f$ | A vector of any element type. |
/// | `arg1` | \f$E[d_1,\dots,d_n,d]~(n \geq 0)\f$ | A tensor of any shape whose innermost dimension matches `arg0`'s size, with the same element type as `arg0`. |
/// | Type | Description |
/// | ---------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n,d''_1,\dots,d''_p]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n,k_1,\dots,k_p] = \Sigma_{0 \le j_1 < d'_1, \dots, 0 \le j_m < d'_m}(\mathtt{arg0}[i_1,\dots,i_n,j_1,\dots,j_m] \cdot \mathtt{arg1}[j_1,\dots,j_m,k_1,\dots,k_p])\f$ or, if \f$m = 0\f$, \f$T[i_1,\dots,i_n,k_1,\dots,k_p] = \mathtt{arg0}[i_1,\dots,i_n] \cdot \mathtt{arg1}[k_1,\dots,k_p]\f$. |
///
/// <i>(Note: in the particular case where \f$n = 0\f$, this is a vector dot product; when \f$n = 1\f$, this is a vector-matrix product.)</i>
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \Sigma_{0 \le k < d}(\mathtt{arg0}[k] \cdot \mathtt{arg1}[i_1,\dots,i_n,k])\f$. |
///
/// # Case 3: Tensor-tensor product
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | ----------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 1)\f$ | A tensor of any shape with rank of at least 1, and any element type. |
/// | `arg1` | \f$E[d'_1,\dots,d'_m]~(m \geq 2\text{ and }d'_{m-1}=d_n)\f$ | A tensor with the same element type as `arg0`, and any shape with rank of at least 2 whose next-to-innermost dimension matches `arg0`'s innermost dimension. |
///
/// <i>(Note: in the particular case where \f$n = m = 2\f$, this is a matrix product.)</i>
///
/// ## Output
///
/// | Type | Description |
/// | ----------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$E[d_1,\dots,d_{n-1},d'_1,\dots,d'_{m-2},d'_{m}]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_{n-1},j_1,\dots,j_{m-2},j_m] = \Sigma_{0 \le k < d_n}(\texttt{arg0}[i_1,\dots,i_{n-1},k] \cdot \texttt{arg1}[j_1,\dots,j_{m-2},k,j_m])\f$ |
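As a small numeric instance of the generalized formula above: with arg0 = [[1, 2], [3, 4]] (shape {2, 2}), arg1 = [[5, 6], [7, 8]] (shape {2, 2}), and one reduction axis, T[0,0] = 1*5 + 2*7 = 19, T[0,1] = 1*6 + 2*8 = 22, T[1,0] = 3*5 + 4*7 = 43, and T[1,1] = 3*6 + 4*8 = 50, i.e. ordinary matrix multiplication.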
class Dot : public RequiresTensorViewArgs
{
public:
......@@ -91,17 +59,36 @@ namespace ngraph
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
/// \param reduction_axes_count The number of axes to dot.
Dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
size_t reduction_axes_count);
/// \brief Constructs a dot product operation with default dot-axis selection depending on the inputs.
///
/// If `arg0` or `arg1` is a scalar, there are no dot-axes. Else, there is one dot-axis.
///
/// (Note that in particular, this results in a scalar-tensor product when one or the other argument is
/// a scalar, a matrix-vector product when `arg0` is a matrix and `arg1` is a vector, and a
/// matrix multiplication when `arg0` and `arg1` are both matrices.)
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Dot>(new_args.at(0), new_args.at(1));
return std::make_shared<Dot>(
new_args.at(0), new_args.at(1), m_reduction_axes_count);
}
protected:
size_t m_reduction_axes_count;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
};
......
......@@ -24,6 +24,7 @@
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/one_hot.hpp"
......@@ -137,29 +138,7 @@ void runtime::cpu::CPU_Emitter::EmitDot(const ngraph::Node* n,
}
else
{
size_t arg0_dot_axis;
size_t arg1_dot_axis;
if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
{
arg0_dot_axis = 0;
arg1_dot_axis = 0;
}
// If arg0 is a matrix and arg1 is a vector, dot on axes 1 and 0 respectively.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
{
arg0_dot_axis = 1;
arg1_dot_axis = 0;
}
// If arg0 is rank n and arg1 is rank m, dot on axes n-1 and m-2, respectively.
//
// Note that this happens to handle the vector-matrix and matrix-matrix cases.
else
{
arg0_dot_axis = arg0_shape.size() - 1;
arg1_dot_axis = arg1_shape.size() - 2;
}
const ngraph::op::Dot* dot = static_cast<const ngraph::op::Dot*>(n);
m_out << "kernel::dot(" << args[0].get_name() << ",\n";
m_out << " " << args[1].get_name() << ",\n";
......@@ -167,8 +146,7 @@ void runtime::cpu::CPU_Emitter::EmitDot(const ngraph::Node* n,
m_out << " {" << join(args[0].get_shape()) << "},\n";
m_out << " {" << join(args[1].get_shape()) << "},\n";
m_out << " {" << join(out[0].get_shape()) << "},\n";
m_out << " " << arg0_dot_axis << ",\n";
m_out << " " << arg1_dot_axis << ");\n";
m_out << " " << dot->get_reduction_axes_count() << ");\n";
}
}
......
......@@ -23,6 +23,7 @@
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/reduce.hpp"
......@@ -66,7 +67,6 @@
#include "ngraph/runtime/kernel/reduce.hpp"
#include "ngraph/runtime/kernel/replace_slice.hpp"
#include "ngraph/runtime/kernel/reshape.hpp"
#include "ngraph/runtime/kernel/scalar_tensor_product.hpp"
#include "ngraph/runtime/kernel/select.hpp"
#include "ngraph/runtime/kernel/sign.hpp"
#include "ngraph/runtime/kernel/sin.hpp"
......@@ -290,54 +290,15 @@ private:
}
else if (node_op == "Dot")
{
if (args[0]->get_shape().size() == 0)
{
kernel::scalar_tensor_product(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (args[1]->get_shape().size() == 0)
{
kernel::scalar_tensor_product(reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else
{
size_t arg0_dot_axis;
size_t arg1_dot_axis;
if (args[0]->get_shape().size() == 1 && args[1]->get_shape().size() == 1)
{
arg0_dot_axis = 0;
arg1_dot_axis = 0;
}
ngraph::op::Dot* dot = dynamic_cast<ngraph::op::Dot*>(&node);
// If arg0 is a matrix and arg1 is a vector, dot on axes 1 and 0 respectively.
else if (args[0]->get_shape().size() == 2 && args[1]->get_shape().size() == 1)
{
arg0_dot_axis = 1;
arg1_dot_axis = 0;
}
// If arg0 is rank n and arg1 is rank m, dot on axes n-1 and m-2, respectively.
//
// Note that this happens to handle the vector-matrix and matrix-matrix cases.
else
{
arg0_dot_axis = args[0]->get_shape().size() - 1;
arg1_dot_axis = args[1]->get_shape().size() - 2;
}
kernel::dot(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
arg0_dot_axis,
arg1_dot_axis);
}
kernel::dot(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
dot->get_reduction_axes_count());
}
else if (node_op == "Equal")
......
......@@ -15,6 +15,7 @@
#pragma once
#include <cmath>
#include <utility>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_transform.hpp"
......@@ -32,49 +33,86 @@ namespace ngraph
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& out_shape,
size_t arg0_dot_axis,
size_t arg1_dot_axis)
size_t reduction_axes_count)
{
CoordinateTransform output_transform(out_shape);
for (Coordinate out_coord : output_transform)
{
out[output_transform.index(out_coord)] = 0;
}
// Get the sizes of the dot axes. It's easiest to pull them from arg1 because they're
// right up front.
Shape dot_axis_sizes(reduction_axes_count);
std::copy(arg1_shape.begin(),
arg1_shape.begin() + reduction_axes_count,
dot_axis_sizes.begin());
CoordinateTransform arg0_transform(arg0_shape);
CoordinateTransform arg1_transform(arg1_shape);
CoordinateTransform output_transform(out_shape);
// Create coordinate transforms for arg0 and arg1 that throw away the dotted axes.
size_t arg0_projected_rank = arg0_shape.size() - reduction_axes_count;
size_t arg1_projected_rank = arg1_shape.size() - reduction_axes_count;
Shape arg0_projected_shape(arg0_projected_rank);
std::copy(arg0_shape.begin(),
arg0_shape.begin() + arg0_projected_rank,
arg0_projected_shape.begin());
Shape arg1_projected_shape(arg1_projected_rank);
std::copy(arg1_shape.begin() + reduction_axes_count,
arg1_shape.end(),
arg1_projected_shape.begin());
CoordinateTransform arg0_projected_transform(arg0_projected_shape);
CoordinateTransform arg1_projected_transform(arg1_projected_shape);
CoordinateTransform arg0_projected_transform(
project_shape(arg0_shape, AxisSet{arg0_dot_axis}));
CoordinateTransform arg1_projected_transform(
project_shape(arg1_shape, AxisSet{arg1_dot_axis}));
// Create a coordinate transform that allows us to iterate over all possible values
// for the dotted axes.
CoordinateTransform dot_axes_transform(dot_axis_sizes);
for (Coordinate arg0_projected_coord : arg0_projected_transform)
{
for (Coordinate arg1_projected_coord : arg1_projected_transform)
{
for (size_t i = 0; i < arg0_shape[arg0_dot_axis]; i++)
// The output coordinate is just the concatenation of the projected coordinates.
Coordinate out_coord(arg0_projected_coord.size() +
arg1_projected_coord.size());
auto out_coord_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
out_coord.begin());
std::copy(
arg1_projected_coord.begin(), arg1_projected_coord.end(), out_coord_it);
// Zero out to start the sum.
T sum = 0;
size_t out_index = output_transform.index(out_coord);
// Walk along the dotted axes.
for (Coordinate dot_axis_positions : dot_axes_transform)
{
Coordinate arg0_coord =
inject_coordinate(arg0_projected_coord, arg0_dot_axis, i);
Coordinate arg1_coord =
inject_coordinate(arg1_projected_coord, arg1_dot_axis, i);
Coordinate out_coord(arg0_projected_coord.size() +
arg1_projected_coord.size());
std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
out_coord.begin());
std::copy(arg1_projected_coord.begin(),
arg1_projected_coord.end(),
out_coord.begin() + arg0_projected_coord.size());
out[output_transform.index(out_coord)] +=
arg0[arg0_transform.index(arg0_coord)] *
arg1[arg1_transform.index(arg1_coord)];
// In order to find the points to multiply together, we need to inject our current
// positions along the dotted axes back into the projected arg0 and arg1 coordinates.
Coordinate arg0_coord(arg0_shape.size());
Coordinate arg1_coord(arg1_shape.size());
auto arg0_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
arg0_coord.begin());
std::copy(
dot_axis_positions.begin(), dot_axis_positions.end(), arg0_it);
auto arg1_it = std::copy(dot_axis_positions.begin(),
dot_axis_positions.end(),
arg1_coord.begin());
std::copy(
arg1_projected_coord.begin(), arg1_projected_coord.end(), arg1_it);
// Multiply and add to the sum.
sum += arg0[arg0_transform.index(arg0_coord)] *
arg1[arg1_transform.index(arg1_coord)];
}
// Write the sum back.
out[out_index] = sum;
}
}
}
......
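A minimal usage sketch of the generalized kernel: contracting a 2x3 matrix with a 3x2 matrix over one reduction axis reproduces ordinary matrix multiplication. The header path and the ngraph::runtime::kernel namespace are inferred from the includes and namespaces visible elsewhere in this diff, so treat them as assumptions.

#include <iostream>
#include <vector>

#include "ngraph/runtime/kernel/dot.hpp" // header path assumed from the includes in this diff

using namespace ngraph;

int main()
{
    // 2x3 times 3x2 with one reduction axis: ordinary matrix multiplication.
    std::vector<float> a{1, 2, 3,
                         4, 5, 6};  // shape {2, 3}, row-major
    std::vector<float> b{7, 8,
                         9, 10,
                         11, 12};   // shape {3, 2}, row-major
    std::vector<float> out(4);      // shape {2, 2}

    runtime::kernel::dot(a.data(),
                         b.data(),
                         out.data(),
                         Shape{2, 3},
                         Shape{3, 2},
                         Shape{2, 2},
                         1 /* reduction_axes_count */);

    // Expected output: 58 64 139 154
    for (float v : out)
    {
        std::cout << v << " ";
    }
    std::cout << "\n";
}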
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void scalar_tensor_product(T* arg0, // the scalar (TODO: just pass as T?)
T* arg1, // the tensor
T* out,
size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = (*arg0) * arg1[i];
}
}
}
}
}
......@@ -109,7 +109,6 @@
#include "ngraph/runtime/ngvm/instruction/replace_slice.hpp"
#include "ngraph/runtime/ngvm/instruction/reshape.hpp"
#include "ngraph/runtime/ngvm/instruction/return.hpp"
#include "ngraph/runtime/ngvm/instruction/scalar_tensor_product.hpp"
#include "ngraph/runtime/ngvm/instruction/select.hpp"
#include "ngraph/runtime/ngvm/instruction/sign.hpp"
#include "ngraph/runtime/ngvm/instruction/sin.hpp"
......@@ -352,8 +351,6 @@ std::vector<typename ET::type>
}
#define PUSH_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_ELEMENT_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
#define PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(et, err_msg, instr, ...) \
DO_ON_NUMERIC_TYPE(et, err_msg, PUSH_INSTRUCTION, instr, __VA_ARGS__)
// Turn off complaint suppression (see above)
#pragma clang diagnostic pop
......@@ -550,6 +547,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_TO_OP_MAP(op::Dot)
{
auto dot = static_cast<const op::Dot*>(n);
auto& arg_nodes = n->get_arguments();
assert(arg_nodes.size() == 2);
......@@ -566,81 +565,24 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto arg1_shape = arg1_tensor_type->get_shape();
auto& arg0_element_type = arg0_tensor_type->get_element_type();
auto reduction_axes_count = dot->get_reduction_axes_count();
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto result_shape = result_tensor_type->get_shape();
// If arg0 or arg1 is a scalar, emit a scalar-tensor product.
if (arg0_shape.size() == 0)
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::ScalarTensorProductInstruction,
in[0],
in[1],
out[0]);
}
else if (arg1_shape.size() == 0)
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::ScalarTensorProductInstruction,
in[1],
in[0],
out[0]);
}
// If arg0 and arg1 are both vectors, dot both on axis 0.
else if (arg0_shape.size() == 1 && arg1_shape.size() == 1)
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::DotInstruction,
in[0],
in[1],
out[0],
arg0_shape,
arg1_shape,
result_shape,
0,
0);
}
// If arg0 is a matrix and arg1 is a vector, dot on axes 1 and 0 respectively.
else if (arg0_shape.size() == 2 && arg1_shape.size() == 1)
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::DotInstruction,
in[0],
in[1],
out[0],
arg0_shape,
arg1_shape,
result_shape,
1,
0);
}
// If arg0 is rank n and arg1 is rank m, dot on axes n-1 and m-2, respectively.
//
// Note that this happens to handle the vector-matrix and matrix-matrix cases.
else
{
PUSH_NUMERIC_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::DotInstruction,
in[0],
in[1],
out[0],
arg0_shape,
arg1_shape,
result_shape,
arg0_shape.size() - 1,
arg1_shape.size() - 2);
}
PUSH_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Dot has unhandled element type",
instruction::DotInstruction,
in[0],
in[1],
out[0],
arg0_shape,
arg1_shape,
result_shape,
reduction_axes_count);
};
// Parameter is a "runtime no-op" because the output tensor has already been filled.
......
......@@ -38,16 +38,14 @@ namespace ngraph
const Shape& arg0_shape,
const Shape& arg1_shape,
const Shape& out_shape,
size_t arg0_dot_axis,
size_t arg1_dot_axis)
size_t reduction_axes_count)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
, m_arg0_shape(arg0_shape)
, m_arg1_shape(arg1_shape)
, m_out_shape(out_shape)
, m_arg0_dot_axis(arg0_dot_axis)
, m_arg1_dot_axis(arg1_dot_axis)
, m_reduction_axes_count(reduction_axes_count)
{
}
......@@ -63,8 +61,7 @@ namespace ngraph
m_arg0_shape,
m_arg1_shape,
m_out_shape,
m_arg0_dot_axis,
m_arg1_dot_axis);
m_reduction_axes_count);
}
protected:
......@@ -74,8 +71,7 @@ namespace ngraph
Shape m_arg0_shape;
Shape m_arg1_shape;
Shape m_out_shape;
size_t m_arg0_dot_axis;
size_t m_arg1_dot_axis;
size_t m_reduction_axes_count;
};
}
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/scalar_tensor_product.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ET>
class ScalarTensorProductInstruction : public Instruction
{
public:
ScalarTensorProductInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg1);
kernel::scalar_tensor_product<typename ET::type>(arg0, arg1, out, count);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -521,6 +521,28 @@ TEST(${BACKEND_NAME}, backwards_dot_tensor2_tensor2)
autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1}, .01f, .01f));
}
TEST(${BACKEND_NAME}, backwards_dot_tensor3_tensor3)
{
auto manager = runtime::Manager::get("NGVM");
auto backend = manager->allocate_backend();
test::Uniform<float> rng(-1.0f, 1.0f);
auto shape0 = Shape{2, 4, 3};
auto shape1 = Shape{4, 3, 3};
auto x0 = rng.initialize(backend->make_primary_tensor_view<float>(shape0));
auto x1 = rng.initialize(backend->make_primary_tensor_view<float>(shape1));
auto make_graph = [shape0, shape1]() {
auto X0 = make_shared<op::Parameter>(element::Float32::element_type(), shape0);
auto X1 = make_shared<op::Parameter>(element::Float32::element_type(), shape1);
return make_shared<Function>(make_shared<op::Dot>(X0, X1, 2),
nullptr,
std::vector<std::shared_ptr<op::Parameter>>{X0, X1});
};
EXPECT_TRUE(
autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1}, .01f, .01f));
}
TEST(${BACKEND_NAME}, backwards_exp)
{
auto manager = runtime::Manager::get("${BACKEND_NAME}");
......
......@@ -808,19 +808,19 @@ TEST(${BACKEND_NAME}, dot2d)
// >>> a.shape=(2,2,2)
// >>> b.shape=(2,2,2)
//
// >>> tensordot(a,b,axes=([2],[1]))
// array([[[[ 7., 10.],
// [ 19., 22.]],
// >>> tensordot(a,b,axes=([2],[0]))
// array([[[[ 11., 14.],
// [ 17., 20.]],
//
// [[ 15., 22.],
// [ 43., 50.]]],
// [[ 23., 30.],
// [ 37., 44.]]],
//
//
// [[[ 23., 34.],
// [ 67., 78.]],
// [[[ 35., 46.],
// [ 57., 68.]],
//
// [[ 31., 46.],
// [ 91., 106.]]]])
// [[ 47., 62.],
// [ 77., 92.]]]])
//
TEST(${BACKEND_NAME}, dot3d_3d)
{
......@@ -844,7 +844,7 @@ TEST(${BACKEND_NAME}, dot3d_3d)
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a, b}, {result});
EXPECT_EQ((vector<float>{7, 10, 19, 22, 15, 22, 43, 50, 23, 34, 67, 78, 31, 46, 91, 106}),
EXPECT_EQ((vector<float>{11, 14, 17, 20, 23, 30, 37, 44, 35, 46, 57, 68, 47, 62, 77, 92}),
result->get_vector<float>());
}
......@@ -2522,6 +2522,104 @@ TEST(${BACKEND_NAME}, reshape_m2m_dim_change_transpose)
EXPECT_EQ((vector<float>{1, 3, 5, 2, 4, 6}), result->get_vector<float>());
}
//
// Numpy:
//
// >>> x = linspace(1,2*2*3*3*2*4,2*2*3*3*2*4)
// >>> x.shape=(2,2,3,3,2,4)
// >>> y = ascontiguousarray(transpose(x,(2,4,0,5,3,1)))
// >>> y.shape=2*2*3*3*2*4
// >>> y
// array([ 1., 73., 9., 81., 17., 89., 2., 74., 10.,
// 82., 18., 90., 3., 75., 11., 83., 19., 91.,
// 4., 76., 12., 84., 20., 92., 145., 217., 153.,
// 225., 161., 233., 146., 218., 154., 226., 162., 234.,
// 147., 219., 155., 227., 163., 235., 148., 220., 156.,
// 228., 164., 236., 5., 77., 13., 85., 21., 93.,
// 6., 78., 14., 86., 22., 94., 7., 79., 15.,
// 87., 23., 95., 8., 80., 16., 88., 24., 96.,
// 149., 221., 157., 229., 165., 237., 150., 222., 158.,
// 230., 166., 238., 151., 223., 159., 231., 167., 239.,
// 152., 224., 160., 232., 168., 240., 25., 97., 33.,
// 105., 41., 113., 26., 98., 34., 106., 42., 114.,
// 27., 99., 35., 107., 43., 115., 28., 100., 36.,
// 108., 44., 116., 169., 241., 177., 249., 185., 257.,
// 170., 242., 178., 250., 186., 258., 171., 243., 179.,
// 251., 187., 259., 172., 244., 180., 252., 188., 260.,
// 29., 101., 37., 109., 45., 117., 30., 102., 38.,
// 110., 46., 118., 31., 103., 39., 111., 47., 119.,
// 32., 104., 40., 112., 48., 120., 173., 245., 181.,
// 253., 189., 261., 174., 246., 182., 254., 190., 262.,
// 175., 247., 183., 255., 191., 263., 176., 248., 184.,
// 256., 192., 264., 49., 121., 57., 129., 65., 137.,
// 50., 122., 58., 130., 66., 138., 51., 123., 59.,
// 131., 67., 139., 52., 124., 60., 132., 68., 140.,
// 193., 265., 201., 273., 209., 281., 194., 266., 202.,
// 274., 210., 282., 195., 267., 203., 275., 211., 283.,
// 196., 268., 204., 276., 212., 284., 53., 125., 61.,
// 133., 69., 141., 54., 126., 62., 134., 70., 142.,
// 55., 127., 63., 135., 71., 143., 56., 128., 64.,
// 136., 72., 144., 197., 269., 205., 277., 213., 285.,
// 198., 270., 206., 278., 214., 286., 199., 271., 207.,
// 279., 215., 287., 200., 272., 208., 280., 216., 288.])
//
// Disabled because it doesn't work on CPU yet.
//
TEST(DISABLED_${BACKEND_NAME}, reshape_6d)
{
vector<float> a_data(2 * 2 * 3 * 3 * 2 * 4);
for (int i = 0; i < 2 * 2 * 3 * 3 * 2 * 4; i++)
{
a_data[i] = float(i + 1);
}
auto shape_a = Shape{2, 2, 3, 3, 2, 4};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{3, 2, 2, 4, 3, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Reshape>(A, AxisVector{2, 4, 0, 5, 3, 1}, shape_r);
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, a_data);
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a}, {result});
ASSERT_EQ(
(vector<float>{
1., 73., 9., 81., 17., 89., 2., 74., 10., 82., 18., 90., 3., 75.,
11., 83., 19., 91., 4., 76., 12., 84., 20., 92., 145., 217., 153., 225.,
161., 233., 146., 218., 154., 226., 162., 234., 147., 219., 155., 227., 163., 235.,
148., 220., 156., 228., 164., 236., 5., 77., 13., 85., 21., 93., 6., 78.,
14., 86., 22., 94., 7., 79., 15., 87., 23., 95., 8., 80., 16., 88.,
24., 96., 149., 221., 157., 229., 165., 237., 150., 222., 158., 230., 166., 238.,
151., 223., 159., 231., 167., 239., 152., 224., 160., 232., 168., 240., 25., 97.,
33., 105., 41., 113., 26., 98., 34., 106., 42., 114., 27., 99., 35., 107.,
43., 115., 28., 100., 36., 108., 44., 116., 169., 241., 177., 249., 185., 257.,
170., 242., 178., 250., 186., 258., 171., 243., 179., 251., 187., 259., 172., 244.,
180., 252., 188., 260., 29., 101., 37., 109., 45., 117., 30., 102., 38., 110.,
46., 118., 31., 103., 39., 111., 47., 119., 32., 104., 40., 112., 48., 120.,
173., 245., 181., 253., 189., 261., 174., 246., 182., 254., 190., 262., 175., 247.,
183., 255., 191., 263., 176., 248., 184., 256., 192., 264., 49., 121., 57., 129.,
65., 137., 50., 122., 58., 130., 66., 138., 51., 123., 59., 131., 67., 139.,
52., 124., 60., 132., 68., 140., 193., 265., 201., 273., 209., 281., 194., 266.,
202., 274., 210., 282., 195., 267., 203., 275., 211., 283., 196., 268., 204., 276.,
212., 284., 53., 125., 61., 133., 69., 141., 54., 126., 62., 134., 70., 142.,
55., 127., 63., 135., 71., 143., 56., 128., 64., 136., 72., 144., 197., 269.,
205., 277., 213., 285., 198., 270., 206., 278., 214., 286., 199., 271., 207., 279.,
215., 287., 200., 272., 208., 280., 216., 288.}),
result->get_vector<float>());
}
TEST(${BACKEND_NAME}, sin)
{
auto shape = Shape{6};
......@@ -4035,3 +4133,337 @@ TEST(${BACKEND_NAME}, replace_slice_3d_strided_different_strides)
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}),
result->get_vector<float>());
}
//
// Numpy test:
//
// > from numpy import *
// > x = linspace(1,2*3*4,2*3*4)
// > y = linspace(1,3*4*5,3*4*5)
// > x.shape=(2,3,4)
// > y.shape=(3,4,5)
// > z = tensordot(x,y,([1,2],[0,1]))
// > z.shape = 2*5
// > z
// array([ 2938., 3016., 3094., 3172., 3250., 7042., 7264., 7486.,
// 7708., 7930.])
//
// Disabled because it doesn't work on CPU yet.
//
TEST(DISABLED_${BACKEND_NAME}, dot_3d_multi_axis)
{
vector<float> a_data(2 * 3 * 4);
for (int i = 0; i < 2 * 3 * 4; i++)
{
a_data[i] = float(i + 1);
}
vector<float> b_data(3 * 4 * 5);
for (int i = 0; i < 3 * 4 * 5; i++)
{
b_data[i] = float(i + 1);
}
auto shape_a = Shape{2, 3, 4};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_b = Shape{3, 4, 5};
auto B = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_b));
auto shape_r = Shape{2, 5};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Dot>(A, B, 2);
auto f = make_shared<Function>(r, rt, op::Parameters{A, B});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, a_data);
auto b = backend->make_primary_tensor_view(element::Float32::element_type(), shape_b);
copy_data(b, b_data);
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a, b}, {result});
ASSERT_EQ((vector<float>{2938., 3016., 3094., 3172., 3250., 7042., 7264., 7486., 7708., 7930.}),
result->get_vector<float>());
}
//
// Numpy test:
//
// > from numpy import *
// > x = array([6,61,2,3,5,21,75,23,23,0,23,2,35,67,1,2,9,16,2,3,6,1,8,0])
// > y = array([9,1,4,6,3,5,1,36,7,3,5,0,1,20,35,2,1,0,1,25,3,6,7,8])
// > x.shape=(2,4,3)
// > y.shape=(3,4,2)
// > z = tensordot(x,y,([2],[0]))
// > z.shape = 2*4*4*2
// > z
// array([ 483, 189, 331, 86, 85, 1262, 2155, 354, 83, 18, 58,
// 543, 77, 241, 325, 286, 859, 144, 438, 1025, 317, 973,
// 1041, 2930, 163, 69, 117, 50, 29, 472, 819, 62, 785,
// 236, 476, 235, 175, 1521, 2387, 1402, 97, 29, 69, 412,
// 63, 286, 429, 218, 45, 11, 29, 162, 27, 106, 149,
// 126, 65, 25, 44, 6, 11, 165, 281, 52])
//
// Disabled because it doesn't work on CPU yet.
//
TEST(DISABLED_${BACKEND_NAME}, dot_3d_one_axis_arbitrary)
{
vector<float> a_data{6, 61, 2, 3, 5, 21, 75, 23, 23, 0, 23, 2,
35, 67, 1, 2, 9, 16, 2, 3, 6, 1, 8, 0};
vector<float> b_data{9, 1, 4, 6, 3, 5, 1, 36, 7, 3, 5, 0,
1, 20, 35, 2, 1, 0, 1, 25, 3, 6, 7, 8};
auto shape_a = Shape{2, 4, 3};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_b = Shape{3, 4, 2};
auto B = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_b));
auto shape_r = Shape{2, 4, 4, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Dot>(A, B);
auto f = make_shared<Function>(r, rt, op::Parameters{A, B});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, a_data);
auto b = backend->make_primary_tensor_view(element::Float32::element_type(), shape_b);
copy_data(b, b_data);
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a, b}, {result});
ASSERT_EQ((vector<float>{483, 189, 331, 86, 85, 1262, 2155, 354, 83, 18, 58, 543, 77,
241, 325, 286, 859, 144, 438, 1025, 317, 973, 1041, 2930, 163, 69,
117, 50, 29, 472, 819, 62, 785, 236, 476, 235, 175, 1521, 2387,
1402, 97, 29, 69, 412, 63, 286, 429, 218, 45, 11, 29, 162,
27, 106, 149, 126, 65, 25, 44, 6, 11, 165, 281, 52}),
result->get_vector<float>());
}
//
// Numpy test:
//
// from numpy import *
// x = linspace(1,2*3*3*4,2*3*3*4)
// y = linspace(1,3*4*2*3*2,3*4*2*3*2)
// x.shape=(2,3,3,4)
// y.shape=(3,4,2,3,2)
// z = tensordot(x,y,([2,3],[0,1]))
// z.shape = 2*3*2*3*2
// z
//
// array([ 6942., 7020., 7098., 7176., 7254., 7332., 7410.,
// 7488., 7566., 7644., 7722., 7800., 16590., 16812.,
// 17034., 17256., 17478., 17700., 17922., 18144., 18366.,
// 18588., 18810., 19032., 26238., 26604., 26970., 27336.,
// 27702., 28068., 28434., 28800., 29166., 29532., 29898.,
// 30264., 35886., 36396., 36906., 37416., 37926., 38436.,
// 38946., 39456., 39966., 40476., 40986., 41496., 45534.,
// 46188., 46842., 47496., 48150., 48804., 49458., 50112.,
// 50766., 51420., 52074., 52728., 55182., 55980., 56778.,
// 57576., 58374., 59172., 59970., 60768., 61566., 62364.,
// 63162., 63960.])
//
// Disabled because it doesn't work on CPU yet.
//
TEST(DISABLED_${BACKEND_NAME}, dot_4d_5d_multi_axis)
{
vector<float> a_data(2 * 3 * 3 * 4);
for (int i = 0; i < 2 * 3 * 3 * 4; i++)
{
a_data[i] = float(i + 1);
}
vector<float> b_data(3 * 4 * 2 * 2 * 3);
for (int i = 0; i < 3 * 4 * 2 * 2 * 3; i++)
{
b_data[i] = float(i + 1);
}
auto shape_a = Shape{2, 3, 3, 4};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_b = Shape{3, 4, 2, 3, 2};
auto B = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_b));
auto shape_r = Shape{2, 3, 2, 3, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Dot>(A, B, 2);
auto f = make_shared<Function>(r, rt, op::Parameters{A, B});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, a_data);
auto b = backend->make_primary_tensor_view(element::Float32::element_type(), shape_b);
copy_data(b, b_data);
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a, b}, {result});
ASSERT_EQ(
(vector<float>{6942., 7020., 7098., 7176., 7254., 7332., 7410., 7488., 7566.,
7644., 7722., 7800., 16590., 16812., 17034., 17256., 17478., 17700.,
17922., 18144., 18366., 18588., 18810., 19032., 26238., 26604., 26970.,
27336., 27702., 28068., 28434., 28800., 29166., 29532., 29898., 30264.,
35886., 36396., 36906., 37416., 37926., 38436., 38946., 39456., 39966.,
40476., 40986., 41496., 45534., 46188., 46842., 47496., 48150., 48804.,
49458., 50112., 50766., 51420., 52074., 52728., 55182., 55980., 56778.,
57576., 58374., 59172., 59970., 60768., 61566., 62364., 63162., 63960.}),
result->get_vector<float>());
}
//
// Numpy test:
//
// from numpy import *
// x = linspace(1,2*3*3*4,2*3*3*4)
// y = linspace(1,2*3*3*4*2,2*3*3*4*2)
// x.shape=(2,3,3,4)
// y.shape=(2,3,3,4,2)
// z = tensordot(x,y,([0,1,2,3],[0,1,2,3]))
// z
//
// array([ 251412., 254040.])
//
// Disabled because it doesn't work on CPU yet.
//
TEST(DISABLED_${BACKEND_NAME}, dot_4d_5d_multi_axis_more)
{
vector<float> a_data(2 * 3 * 3 * 4);
for (int i = 0; i < 2 * 3 * 3 * 4; i++)
{
a_data[i] = float(i + 1);
}
vector<float> b_data(2 * 3 * 3 * 4 * 2);
for (int i = 0; i < 2 * 3 * 3 * 4 * 2; i++)
{
b_data[i] = float(i + 1);
}
auto shape_a = Shape{2, 3, 3, 4};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_b = Shape{2, 3, 3, 4, 2};
auto B = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_b));
auto shape_r = Shape{2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Dot>(A, B, 4);
auto f = make_shared<Function>(r, rt, op::Parameters{A, B});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, a_data);
auto b = backend->make_primary_tensor_view(element::Float32::element_type(), shape_b);
copy_data(b, b_data);
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a, b}, {result});
ASSERT_EQ((vector<float>{251412., 254040.}), result->get_vector<float>());
}
//
// Numpy test:
//
// from numpy import *
// x = linspace(1,20*30*30*40,20*30*30*40)
// y = linspace(1,20*30*30*40*20,20*30*30*40*20)
// x.shape=(20,30,30,40)
// y.shape=(20,30,30,40,20)
// z = tensordot(x,y,([0,1,2,3],[0,1,2,3]))
// set_printoptions(precision=20)
// z
//
// array([ 2.48832025919525478400e+18, 2.48832051839533977600e+18,
// 2.48832077759658444800e+18, 2.48832103679413504000e+18,
// 2.48832129599669350400e+18, 2.48832155519793971200e+18,
// 2.48832181439802265600e+18, 2.48832207359808000000e+18,
// 2.48832233279813580800e+18, 2.48832259199822028800e+18,
// 2.48832285119946496000e+18, 2.48832311040043008000e+18,
// 2.48832336959957401600e+18, 2.48832362880081817600e+18,
// 2.48832388800090368000e+18, 2.48832414720096000000e+18,
// 2.48832440640101478400e+18, 2.48832466560109772800e+18,
// 2.48832492480234188800e+18, 2.48832518400031897600e+18])
//
// Disabled because this test is very slow.
//
TEST(DISABLED_${BACKEND_NAME}, dot_4d_5d_multi_axis_big_fp64_VERY_SLOW)
{
vector<double> a_data(20 * 30 * 30 * 40);
for (int i = 0; i < 20 * 30 * 30 * 40; i++)
{
a_data[i] = double(i + 1);
}
vector<double> b_data(20 * 30 * 30 * 40 * 20);
for (int i = 0; i < 20 * 30 * 30 * 40 * 20; i++)
{
b_data[i] = double(i + 1);
}
auto shape_a = Shape{20, 30, 30, 40};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float64::element_type(), shape_a));
auto shape_b = Shape{20, 30, 30, 40, 20};
auto B = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float64::element_type(), shape_b));
auto shape_r = Shape{20};
auto rt = make_shared<TensorViewType>(element::Float64::element_type(), shape_r);
auto r = make_shared<op::Dot>(A, B, 4);
auto f = make_shared<Function>(r, rt, op::Parameters{A, B});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float64::element_type(), shape_a);
copy_data(a, a_data);
auto b = backend->make_primary_tensor_view(element::Float64::element_type(), shape_b);
copy_data(b, b_data);
auto result = backend->make_primary_tensor_view(element::Float64::element_type(), shape_r);
cf->call({a, b}, {result});
ASSERT_EQ(
(vector<double>{
2.48832025919525478400e+18, 2.48832051839533977600e+18, 2.48832077759658444800e+18,
2.48832103679413504000e+18, 2.48832129599669350400e+18, 2.48832155519793971200e+18,
2.48832181439802265600e+18, 2.48832207359808000000e+18, 2.48832233279813580800e+18,
2.48832259199822028800e+18, 2.48832285119946496000e+18, 2.48832311040043008000e+18,
2.48832336959957401600e+18, 2.48832362880081817600e+18, 2.48832388800090368000e+18,
2.48832414720096000000e+18, 2.48832440640101478400e+18, 2.48832466560109772800e+18,
2.48832492480234188800e+18, 2.48832518400031897600e+18}),
result->get_vector<double>());
}
......@@ -32,15 +32,6 @@ TEST(type_prop, broadcast_deduce)
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, broadcast_deduce_correct)
{
// Check deduced type against correctly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
auto bc = make_shared<op::Broadcast>(param, Shape{2, 3, 4}, AxisSet{1});
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, broadcast_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
......@@ -228,15 +219,6 @@ TEST(type_prop, convert_deduce)
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_correct)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
......@@ -322,17 +304,7 @@ TEST(type_prop, dot_deduce_different_rank)
{
// Deduce type for different-rank tensor arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 8, 4, 2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1, 2, 3});
auto bc = make_shared<op::Dot>(param1, param2);
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 8, 4, 1, 3}));
}
TEST(type_prop, dot_deduce_different_rank_correct)
{
// Deduced type matches explicitly set type
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 8, 4, 2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1, 2, 3});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 1, 3});
auto bc = make_shared<op::Dot>(param1, param2);
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 8, 4, 1, 3}));
......@@ -372,7 +344,7 @@ TEST(type_prop, dot_deduce_reduction_axes_size_mismatch)
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Dot reduction axes not compatible"));
EXPECT_EQ(error.what(), std::string("Dot axes do not have same length"));
}
catch (...)
{
......@@ -571,19 +543,6 @@ TEST(type_prop, select_deduce)
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
TEST(type_prop, select_deduce_correct)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
TEST(type_prop, select_shape_mismatch_a)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
......@@ -735,24 +694,6 @@ TEST(type_prop, reduce_deduce)
TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
TEST(type_prop, reduce_deduce_correct)
{
auto param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto f_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto f_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f = make_shared<Function>(f_param_0 + f_param_1, rt, op::Parameters{f_param_0, f_param_1});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
ASSERT_EQ(*(r0->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{4}));
}
TEST(type_prop, reduce_nonscalar)
{
auto param_0 = make_shared<op::Parameter>(
......@@ -1073,14 +1014,6 @@ TEST(type_prop, reshape_deduce_t2v_120)
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
TEST(type_prop, reshape_deduce_correct_t2v_120)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{60});
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
TEST(type_prop, reshape_deduce_not_enough_axes)
{
auto param = make_shared<op::Parameter>(
......