Commit e92e5e5b authored by Adam Procter's avatar Adam Procter Committed by GitHub

Merge pull request #97 from NervanaSystems/aprocter/dot-handle-0d

Add support for 0D tensors to dot, clean up comments
parents 4528f86d ab4b984e
......@@ -21,7 +21,25 @@ namespace ngraph
class Dot : public Builtin
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
/// Computes the dot product of two tensors.
///
/// There are three possible cases:
/// (1) arg0 or arg1 is 0-dimensional. Then, we treat the 0-dimensional
/// argument(s) as scalars and compute a scalar-tensor or
/// scalar-scalar product.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {}; then
/// the result will have shape {1,2,3}.)
///
/// (2) arg1 is 1-dimensional. Then, we compute a dot product reducing
/// on the innermost (rightmost) dimensions of arg0 and arg1.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {3}; then
/// the result will have shape {1,2}.)
///
/// (3) arg1 is more than 1-dimensional. Then, we compute a dot product
/// reducing on the innermost (rightmost) dimension of arg0, and the
/// next-to-innermost dimension of arg1.
/// (Example: arg0 has shape {3,4} and arg1 has shape {4,3}; then
/// the result will have shape {3,3}.)
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
......
......@@ -34,12 +34,12 @@ void Dot::propagate_types()
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else the only axis.
vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
const bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0;
if (arg1_shape.size() > 1)
{
arg1_reduction = arg1_shape.size() - 2;
......@@ -48,21 +48,29 @@ void Dot::propagate_types()
{
arg1_reduction = arg1_shape.size() - 1;
}
if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
if (!is_scalar_mult && (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction)))
{
throw ngraph_error("Dot reduction axes not compatible");
}
vector<size_t> result_shape;
result_shape.reserve(arg0_shape.size() + arg1_shape.size() - 2);
result_shape.reserve(arg0_shape.size() + arg1_shape.size() - (is_scalar_mult ? 0 : 2));
for(auto i = 0; i < arg0_shape.size(); i++)
if(i != arg0_reduction)
{
if(is_scalar_mult || i != arg0_reduction)
{
result_shape.push_back(arg0_shape[i]);
}
}
for(auto i = 0; i < arg1_shape.size(); i++)
if(i != arg1_reduction)
{
if(is_scalar_mult || i != arg1_reduction)
{
result_shape.push_back(arg1_shape[i]);
}
}
auto result_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(result_type);
......
......@@ -97,6 +97,50 @@ TEST(type_prop, broadcast_bad_arguments)
//
// Tests for dot product.
//
TEST(type_prop, dot_deduce_scalar_2d)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,5});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,5}));
}
TEST(type_prop, dot_deduce_2d_scalar)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,5});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,5}));
}
TEST(type_prop, dot_deduce_scalar_scalar)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, dot_deduce_scalar_1d)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{6});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{6}));
}
TEST(type_prop, dot_deduce_1d)
{
// Deduce type for 1D arguments
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment