Commit ab4b984e authored by Adam Procter's avatar Adam Procter

Updates to comments; style fixes in dot.cpp

parent 2b2354e3
...@@ -29,12 +29,14 @@ namespace ngraph ...@@ -29,12 +29,14 @@ namespace ngraph
/// scalar-scalar product. /// scalar-scalar product.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {}; then /// (Example: arg0 has shape {1,2,3} and arg1 has shape {}; then
/// the result will have shape {1,2,3}.) /// the result will have shape {1,2,3}.)
///
/// (2) arg1 is 1-dimensional. Then, we compute a dot product reducing /// (2) arg1 is 1-dimensional. Then, we compute a dot product reducing
/// on the innermost dimensions of arg0 and arg1. /// on the innermost (rightmost) dimensions of arg0 and arg1.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {3}; then /// (Example: arg0 has shape {1,2,3} and arg1 has shape {3}; then
/// the result will have shape {1,2}.) /// the result will have shape {1,2}.)
///
/// (3) arg1 is more than 1-dimensional. Then, we compute a dot product /// (3) arg1 is more than 1-dimensional. Then, we compute a dot product
/// reducing on the innermost dimension of arg0, and the /// reducing on the innermost (rightmost) dimension of arg0, and the
/// next-to-innermost dimension of arg1. /// next-to-innermost dimension of arg1.
/// (Example: arg0 has shape {3,4} and arg1 has shape {4,3}; then /// (Example: arg0 has shape {3,4} and arg1 has shape {4,3}; then
/// the result will have shape {3,3}.) /// the result will have shape {3,3}.)
......
...@@ -38,7 +38,7 @@ void Dot::propagate_types() ...@@ -38,7 +38,7 @@ void Dot::propagate_types()
vector<size_t> arg1_shape = arg1_tensor_type->get_shape(); vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1; size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction; size_t arg1_reduction;
bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0; const bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0;
if (arg1_shape.size() > 1) if (arg1_shape.size() > 1)
{ {
...@@ -48,7 +48,7 @@ void Dot::propagate_types() ...@@ -48,7 +48,7 @@ void Dot::propagate_types()
{ {
arg1_reduction = arg1_shape.size() - 1; arg1_reduction = arg1_shape.size() - 1;
} }
if (!is_scalar_mult && arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction)) if (!is_scalar_mult && (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction)))
{ {
throw ngraph_error("Dot reduction axes not compatible"); throw ngraph_error("Dot reduction axes not compatible");
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment