Commit 070b958e authored by Scott Cyphers's avatar Scott Cyphers

style

parent 4aa19e31
......@@ -24,9 +24,9 @@
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/shape.hpp"
......
......@@ -43,8 +43,7 @@ namespace ngraph
namespace op
{
Node::ptr broadcast(const Node::ptr& tensor,
const Shape& shape,
const std::vector<size_t>& broadcast_axes);
const Shape& shape,
const std::vector<size_t>& broadcast_axes);
}
}
......@@ -16,7 +16,6 @@
namespace ngraph
{
class DotOp : public BuiltinOp
{
public:
......
......@@ -23,36 +23,35 @@ using namespace ngraph;
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const vector<size_t>& broadcast_axes)
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const vector<size_t>& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
target_shape.erase(target_shape.begin() + *i);
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment