Commit ab4b984e authored by Adam Procter's avatar Adam Procter

Updates to comments; style fixes in dot.cpp

parent 2b2354e3
......@@ -29,12 +29,14 @@ namespace ngraph
/// scalar-scalar product.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {}; then
/// the result will have shape {1,2,3}.)
///
/// (2) arg1 is 1-dimensional. Then, we compute a dot product reducing
/// on the innermost dimensions of arg0 and arg1.
/// on the innermost (rightmost) dimensions of arg0 and arg1.
/// (Example: arg0 has shape {1,2,3} and arg1 has shape {3}; then
/// the result will have shape {1,2}.)
///
/// (3) arg1 is more than 1-dimensional. Then, we compute a dot product
/// reducing on the innermost dimension of arg0, and the
/// reducing on the innermost (rightmost) dimension of arg0, and the
/// next-to-innermost dimension of arg1.
/// (Example: arg0 has shape {3,4} and arg1 has shape {4,3}; then
/// the result will have shape {3,3}.)
......
......@@ -38,7 +38,7 @@ void Dot::propagate_types()
vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0;
const bool is_scalar_mult = arg0_shape.size() == 0 || arg1_shape.size() == 0;
if (arg1_shape.size() > 1)
{
......@@ -48,7 +48,7 @@ void Dot::propagate_types()
{
arg1_reduction = arg1_shape.size() - 1;
}
if (!is_scalar_mult && arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
if (!is_scalar_mult && (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction)))
{
throw ngraph_error("Dot reduction axes not compatible");
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment