Commit ffe85534 authored by Adam Procter

Merge branch 'master' into aprocter/doxygen

parents ed7e7b78 e92e5e5b
@@ -21,7 +21,25 @@ namespace ngraph
         class Dot : public Builtin
         {
         public:
-            /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
+            /// Computes the dot product of two tensors.
+            ///
+            /// There are three possible cases:
+            /// (1) arg0 or arg1 is 0-dimensional. Then, we treat the 0-dimensional
+            ///     argument(s) as scalars and compute a scalar-tensor or
+            ///     scalar-scalar product.
+            ///     (Example: arg0 has shape {1,2,3} and arg1 has shape {}; then
+            ///     the result will have shape {1,2,3}.)
+            ///
+            /// (2) arg1 is 1-dimensional. Then, we compute a dot product reducing
+            ///     on the innermost (rightmost) dimensions of arg0 and arg1.
+            ///     (Example: arg0 has shape {1,2,3} and arg1 has shape {3}; then
+            ///     the result will have shape {1,2}.)
+            ///
+            /// (3) arg1 is more than 1-dimensional. Then, we compute a dot product
+            ///     reducing on the innermost (rightmost) dimension of arg0, and the
+            ///     next-to-innermost dimension of arg1.
+            ///     (Example: arg0 has shape {3,4} and arg1 has shape {4,3}; then
+            ///     the result will have shape {3,3}.)
             Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
                 : Builtin({arg0, arg1})
             {
...
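For reference, the three documented cases deduce shapes as follows. This is a sketch only, reusing the Parameter/Dot construction pattern from the type_prop tests added in this commit; headers and using-declarations are assumed to match those tests.

    // Sketch only; assumes the same API surface as the type_prop tests below.
    void dot_shape_examples()
    {
        auto scalar = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
        auto t123 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1, 2, 3});
        auto v3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
        auto m34 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3, 4});
        auto m43 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 3});

        auto case1 = make_shared<op::Dot>(t123, scalar); // (1) scalar product: result shape {1,2,3}
        auto case2 = make_shared<op::Dot>(t123, v3);     // (2) arg1 is 1-D:    result shape {1,2}
        auto case3 = make_shared<op::Dot>(m34, m43);     // (3) arg1 is 2-D:    result shape {3,3}

        // After propagate_types(), each node's value type carries the shape noted above.
        case1->propagate_types();
        case2->propagate_types();
        case3->propagate_types();
    }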
@@ -34,12 +34,12 @@ void Dot::propagate_types()
         throw ngraph_error("Arguments to dot must have the same element type");
     }
 
-    // Use NumPy semantics for now
-    // Last axis of first arg reduces against second to last of second arg if more than one axis, else the only axis.
     vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
     vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
     size_t arg0_reduction = arg0_shape.size() - 1;
     size_t arg1_reduction;
+    const bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0;
+
     if (arg1_shape.size() > 1)
     {
         arg1_reduction = arg1_shape.size() - 2;
@@ -48,21 +48,29 @@ void Dot::propagate_types()
     {
         arg1_reduction = arg1_shape.size() - 1;
     }
 
-    if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
+    if (!is_scalar_mult && (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction)))
     {
         throw ngraph_error("Dot reduction axes not compatible");
     }
 
     vector<size_t> result_shape;
-    result_shape.reserve(arg0_shape.size() + arg1_shape.size() - 2);
+    result_shape.reserve(arg0_shape.size() + arg1_shape.size() - (is_scalar_mult ? 0 : 2));
     for(auto i = 0; i < arg0_shape.size(); i++)
-        if(i != arg0_reduction)
+    {
+        if(is_scalar_mult || i != arg0_reduction)
+        {
             result_shape.push_back(arg0_shape[i]);
+        }
+    }
 
     for(auto i = 0; i < arg1_shape.size(); i++)
-        if(i != arg1_reduction)
+    {
+        if(is_scalar_mult || i != arg1_reduction)
+        {
             result_shape.push_back(arg1_shape[i]);
+        }
+    }
 
     auto result_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
     set_value_type_checked(result_type);
...
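The deduction rule above can be restated as a self-contained function over plain shape vectors. This is a sketch in standard C++ that mirrors the logic of propagate_types, not a reproduction of the ngraph sources; like the original, it leaves the reduction indices uninitialized-in-spirit (underflowed) for 0-dimensional inputs, where they are never read.

    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // Sketch of the shape-deduction rule implemented in Dot::propagate_types.
    std::vector<std::size_t> dot_result_shape(const std::vector<std::size_t>& a,
                                              const std::vector<std::size_t>& b)
    {
        const bool is_scalar_mult = a.empty() || b.empty();
        const std::size_t a_red = a.size() - 1;              // innermost axis of arg0 (unused for scalars)
        const std::size_t b_red = b.size() > 1 ? b.size() - 2 // next-to-innermost axis of arg1,
                                               : b.size() - 1; // or its only axis if 1-D
        if (!is_scalar_mult && a.at(a_red) != b.at(b_red))
        {
            throw std::runtime_error("Dot reduction axes not compatible");
        }
        std::vector<std::size_t> result;
        for (std::size_t i = 0; i < a.size(); i++)
            if (is_scalar_mult || i != a_red)
                result.push_back(a[i]);
        for (std::size_t i = 0; i < b.size(); i++)
            if (is_scalar_mult || i != b_red)
                result.push_back(b[i]);
        return result; // e.g. dot_result_shape({1,2,3}, {3}) == {1,2}
    }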
@@ -97,6 +97,50 @@ TEST(type_prop, broadcast_bad_arguments)
 //
 // Tests for dot product.
 //
+TEST(type_prop, dot_deduce_scalar_2d)
+{
+    // Deduce type for scalar/matrix arguments
+    auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
+    auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,5});
+    auto bc = make_shared<op::Dot>(param1, param2);
+    bc->propagate_types();
+    auto bc_vt = bc->get_value_type();
+    ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,5}));
+}
+
+TEST(type_prop, dot_deduce_2d_scalar)
+{
+    // Deduce type for matrix/scalar arguments
+    auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,5});
+    auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
+    auto bc = make_shared<op::Dot>(param1, param2);
+    bc->propagate_types();
+    auto bc_vt = bc->get_value_type();
+    ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,5}));
+}
+
+TEST(type_prop, dot_deduce_scalar_scalar)
+{
+    // Deduce type for scalar/scalar arguments
+    auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
+    auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
+    auto bc = make_shared<op::Dot>(param1, param2);
+    bc->propagate_types();
+    auto bc_vt = bc->get_value_type();
+    ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{}));
+}
+
+TEST(type_prop, dot_deduce_scalar_1d)
+{
+    // Deduce type for scalar/vector arguments
+    auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
+    auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{6});
+    auto bc = make_shared<op::Dot>(param1, param2);
+    bc->propagate_types();
+    auto bc_vt = bc->get_value_type();
+    ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{6}));
+}
+
 TEST(type_prop, dot_deduce_1d)
 {
     // Deduce type for 1D arguments
...
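A natural companion to the tests above is the incompatible-axes error path guarded by the new is_scalar_mult flag. Sketched below in the same style; this test is not part of this commit, and the test name is invented for illustration.

    TEST(type_prop, dot_deduce_axes_mismatch)
    {
        // Reduction axes disagree (3 vs. 4), so propagate_types should throw.
        auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,3});
        auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,5});
        auto bc = make_shared<op::Dot>(param1, param2);
        EXPECT_THROW(bc->propagate_types(), ngraph_error);
    }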