Commit 070b958e authored by Scott Cyphers

style

parent 4aa19e31
@@ -24,9 +24,9 @@
 #include "ngraph/node.hpp"
 #include "ngraph/op.hpp"
 #include "ngraph/ops/broadcast.hpp"
-#include "ngraph/ops/dot.hpp"
 #include "ngraph/ops/concatenate.hpp"
 #include "ngraph/ops/constant.hpp"
+#include "ngraph/ops/dot.hpp"
 #include "ngraph/ops/parameter.hpp"
 #include "ngraph/ops/tuple.hpp"
 #include "ngraph/shape.hpp"
...
@@ -45,6 +45,5 @@ namespace ngraph
         Node::ptr broadcast(const Node::ptr& tensor,
                             const Shape& shape,
                             const std::vector<size_t>& broadcast_axes);
     }
 }
@@ -16,7 +16,6 @@
 namespace ngraph
 {
     class DotOp : public BuiltinOp
     {
     public:
...
@@ -23,36 +23,35 @@ using namespace ngraph;
 ** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
 **        the remaining axes in shape must be the same as the shape of arg.
 **/
 Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
                                 const Shape& shape,
                                 const vector<size_t>& broadcast_axes)
 {
     return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
 }

 void BroadcastOp::propagate_types()
 {
     auto arg_type = m_arguments.at(0)->type();
     if (nullptr == arg_type)
     {
         throw ngraph_error("Argument to broadcast is missing type.");
     }
     auto arg_tensor_view_type = dynamic_pointer_cast<TensorViewType>(arg_type);
     if (nullptr == arg_tensor_view_type)
     {
         throw ngraph_error("Argument to broadcast is not a tensor view");
     }
     vector<size_t> target_shape = m_shape;
     for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
     {
         target_shape.erase(target_shape.begin() + *i);
     }
     if (Shape{target_shape} != arg_tensor_view_type->shape())
     {
         throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
     }
     // TODO If m_type is already set (by framework), this should verify that the type
     // we expect is consistent with the type the framework expects.
     m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
 }
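
For readers skimming the diff, the type-propagation rule above boils down to: erase the broadcast axes from the requested result shape, and whatever remains must equal the argument's shape. The following standalone sketch mirrors that check using only the standard library; it is illustrative only, the helper name strip_broadcast_axes is made up here, and shapes are modeled as plain std::vector<size_t> rather than ngraph's Shape and Node types.

// Standalone sketch (not part of the commit): mirrors the shape check that
// BroadcastOp::propagate_types performs above, using only the standard library.
#include <cassert>
#include <cstddef>
#include <vector>

// Remove the broadcast axes from the result shape; what remains must equal the
// argument's shape, exactly as the check in propagate_types requires.
std::vector<std::size_t> strip_broadcast_axes(std::vector<std::size_t> result_shape,
                                              const std::vector<std::size_t>& broadcast_axes)
{
    // Erase from the highest axis down so the lower indices stay valid,
    // matching the reverse iteration over m_broadcast_axes in the diff.
    for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
    {
        result_shape.erase(result_shape.begin() + *i);
    }
    return result_shape;
}

int main()
{
    // Broadcasting an argument of shape {3} to a result of shape {2, 3} along
    // axis 0: erasing axis 0 from {2, 3} leaves {3}, which matches the argument,
    // so the equivalent of op::broadcast(arg, Shape{2, 3}, {0}) would type-check.
    std::vector<std::size_t> arg_shape{3};
    assert(strip_broadcast_axes({2, 3}, {0}) == arg_shape);

    // An incompatible request, e.g. result {2, 4} with axes {0}, leaves {4} != {3};
    // propagate_types would throw "Broadcast arg, shape, and axes are incompatible".
    assert(strip_broadcast_axes({2, 4}, {0}) != arg_shape);
    return 0;
}

Compiled with any C++11 compiler (e.g. g++ -std=c++11 sketch.cpp), both assertions hold. Erasing from the highest broadcast axis downward keeps the lower indices valid, which is presumably why the original loop walks m_broadcast_axes with rbegin()/rend().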