Commit 2b498a6c authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Add BroadcastLike and ScalarConstantLike to op_tbl (#2213)

* add missing ops to op_tbl

* add two missing ops
parent 4514faf9
...@@ -37,7 +37,7 @@ using namespace ngraph; ...@@ -37,7 +37,7 @@ using namespace ngraph;
std::shared_ptr<Node> make_zero(const std::shared_ptr<Node>& node) std::shared_ptr<Node> make_zero(const std::shared_ptr<Node>& node)
{ {
std::shared_ptr<Node> zero = std::make_shared<op::ScalarConstantLike<double>>(node, 0.0); std::shared_ptr<Node> zero = std::make_shared<op::ScalarConstantLike>(node, 0.0);
std::shared_ptr<Node> bzero = std::make_shared<op::BroadcastLike>(zero, node, AxisSet{}); std::shared_ptr<Node> bzero = std::make_shared<op::BroadcastLike>(zero, node, AxisSet{});
return bzero; return bzero;
} }
......
...@@ -50,7 +50,7 @@ void op::Acos::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -50,7 +50,7 @@ void op::Acos::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
auto x = get_inputs().at(0).get_output().get_node(); auto x = get_inputs().at(0).get_output().get_node();
auto one = make_shared<op::ScalarConstantLike<double>>(x, 1.0); auto one = make_shared<op::ScalarConstantLike>(x, 1.0);
auto ones = make_shared<op::BroadcastLike>(one, x, AxisSet()); auto ones = make_shared<op::BroadcastLike>(one, x, AxisSet());
adjoints.add_delta(x, -delta / make_shared<op::Sqrt>(ones - x * x)); adjoints.add_delta(x, -delta / make_shared<op::Sqrt>(ones - x * x));
......
...@@ -84,9 +84,9 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe ...@@ -84,9 +84,9 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg, op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg, const std::shared_ptr<Node>& like_arg,
const AxisSet& broadcast_axes) const AxisSet& initial_broadcast_axes)
: Broadcast("BroadcastLike", {arg, like_arg}, {}, {}) : Broadcast("BroadcastLike", {arg, like_arg}, {}, {})
, m_initial_broadcast_axes(broadcast_axes) , m_initial_broadcast_axes(initial_broadcast_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -70,17 +70,17 @@ namespace ngraph ...@@ -70,17 +70,17 @@ namespace ngraph
/// ///
/// \param arg The argument to be broadcast. /// \param arg The argument to be broadcast.
/// \param like_arg Provides the shape for the result. /// \param like_arg Provides the shape for the result.
/// \param broadcast_axes indicates which axes will be broadcast. If empty, /// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
/// arg must be scalar and all axes are broadcast. /// arg must be scalar and all axes are broadcast.
BroadcastLike(const std::shared_ptr<Node>& arg, BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg, const std::shared_ptr<Node>& like_arg,
const AxisSet& broadcast_axes); const AxisSet& initial_broadcast_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void infer_shape() override; void infer_shape() override;
const AxisSet& get_initial_broadcast_axes() const { return m_initial_broadcast_axes; }
protected: protected:
AxisSet m_initial_broadcast_axes; AxisSet m_initial_broadcast_axes;
}; };
......
...@@ -169,6 +169,21 @@ shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const ...@@ -169,6 +169,21 @@ shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const
return std::make_shared<op::Constant>(m_element_type, m_shape, m_data); return std::make_shared<op::Constant>(m_element_type, m_shape, m_data);
} }
std::shared_ptr<Node> op::ScalarConstantLike::copy_with_new_args(const NodeVector& new_args) const
{
return std::make_shared<ScalarConstantLike>(new_args.at(0), m_value);
}
void op::ScalarConstantLike::infer_element_type()
{
m_element_type = get_input_element_type(0);
if (nullptr == m_data)
{
m_data = ngraph::aligned_alloc(m_element_type.size(), m_element_type.size());
write_values(std::vector<double>(1, m_value));
}
}
// //
// We have to open up namespace blocks here to work around a problem with gcc: // We have to open up namespace blocks here to work around a problem with gcc:
// //
......
...@@ -281,7 +281,6 @@ namespace ngraph ...@@ -281,7 +281,6 @@ namespace ngraph
}; };
/// \brief A scalar constant whose element type is the same as like. /// \brief A scalar constant whose element type is the same as like.
template <typename T>
class ScalarConstantLike : public ScalarConstantLikeBase class ScalarConstantLike : public ScalarConstantLikeBase
{ {
public: public:
...@@ -292,30 +291,20 @@ namespace ngraph ...@@ -292,30 +291,20 @@ namespace ngraph
/// ///
/// \param like A tensor that will supply the element type. /// \param like A tensor that will supply the element type.
/// \param value The value of the scalar. /// \param value The value of the scalar.
template <typename T>
ScalarConstantLike(const std::shared_ptr<Node>& like, T value) ScalarConstantLike(const std::shared_ptr<Node>& like, T value)
: ScalarConstantLikeBase("ScalarConstantLike", {like}) : ScalarConstantLikeBase("ScalarConstantLike", {like})
, m_value(value) , m_value(static_cast<double>(value))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
{
return std::make_shared<ScalarConstantLike<T>>(new_args.at(0), m_value);
}
protected: protected:
void infer_element_type() override void infer_element_type() override;
{
m_element_type = get_input_element_type(0);
if (nullptr == m_data)
{
m_data = ngraph::aligned_alloc(m_element_type.size(), m_element_type.size());
write_values(std::vector<T>(1, m_value));
}
}
T m_value; double m_value;
}; };
} }
} }
...@@ -62,6 +62,7 @@ NGRAPH_OP(BatchNormInference, ngraph::op) ...@@ -62,6 +62,7 @@ NGRAPH_OP(BatchNormInference, ngraph::op)
NGRAPH_OP(BatchNormTraining, ngraph::op) NGRAPH_OP(BatchNormTraining, ngraph::op)
NGRAPH_OP(BatchNormTrainingBackprop, ngraph::op) NGRAPH_OP(BatchNormTrainingBackprop, ngraph::op)
NGRAPH_OP(Broadcast, ngraph::op) NGRAPH_OP(Broadcast, ngraph::op)
NGRAPH_OP(BroadcastLike, ngraph::op)
NGRAPH_OP(Ceiling, ngraph::op) NGRAPH_OP(Ceiling, ngraph::op)
NGRAPH_OP(Concat, ngraph::op) NGRAPH_OP(Concat, ngraph::op)
NGRAPH_OP(Constant, ngraph::op) NGRAPH_OP(Constant, ngraph::op)
...@@ -112,6 +113,7 @@ NGRAPH_OP(Reshape, ngraph::op) ...@@ -112,6 +113,7 @@ NGRAPH_OP(Reshape, ngraph::op)
NGRAPH_OP(Result, ngraph::op) NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op) NGRAPH_OP(Reverse, ngraph::op)
NGRAPH_OP(ReverseSequence, ngraph::op) NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(ScalarConstantLike, ngraph::op)
NGRAPH_OP(Select, ngraph::op) NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(SelectAndScatter, ngraph::op) NGRAPH_OP(SelectAndScatter, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op) NGRAPH_OP(ShapeOf, ngraph::op)
......
...@@ -460,6 +460,11 @@ void runtime::gpu::GPU_Emitter::emit_Broadcast(EMIT_ARGS) ...@@ -460,6 +460,11 @@ void runtime::gpu::GPU_Emitter::emit_Broadcast(EMIT_ARGS)
writer.block_end(); writer.block_end();
} }
void runtime::gpu::GPU_Emitter::emit_BroadcastLike(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Ceiling(EMIT_ARGS) void runtime::gpu::GPU_Emitter::emit_Ceiling(EMIT_ARGS)
{ {
emit_elementwise<ngraph::op::Ceiling>(external_function, writer, node, args, out); emit_elementwise<ngraph::op::Ceiling>(external_function, writer, node, args, out);
...@@ -1415,6 +1420,11 @@ void runtime::gpu::GPU_Emitter::emit_Rnn(EMIT_ARGS) ...@@ -1415,6 +1420,11 @@ void runtime::gpu::GPU_Emitter::emit_Rnn(EMIT_ARGS)
} }
#endif #endif
void runtime::gpu::GPU_Emitter::emit_ScalarConstantLike(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS) void runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
{ {
emit_elementwise<ngraph::op::Select>(external_function, writer, node, args, out); emit_elementwise<ngraph::op::Select>(external_function, writer, node, args, out);
......
...@@ -1716,6 +1716,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> ...@@ -1716,6 +1716,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
case OP_TYPEID::All: case OP_TYPEID::All:
case OP_TYPEID::AllReduce: case OP_TYPEID::AllReduce:
case OP_TYPEID::Any: case OP_TYPEID::Any:
case OP_TYPEID::BroadcastLike:
case OP_TYPEID::FunctionCall: case OP_TYPEID::FunctionCall:
case OP_TYPEID::Dequantize: case OP_TYPEID::Dequantize:
case OP_TYPEID::Quantize: case OP_TYPEID::Quantize:
...@@ -1723,6 +1724,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> ...@@ -1723,6 +1724,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
case OP_TYPEID::ReplaceSlice: case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
case OP_TYPEID::ReverseSequence: case OP_TYPEID::ReverseSequence:
case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::SelectAndScatter: case OP_TYPEID::SelectAndScatter:
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
case OP_TYPEID::StopGradient: case OP_TYPEID::StopGradient:
......
...@@ -459,6 +459,7 @@ private: ...@@ -459,6 +459,7 @@ private:
broadcast_axes); broadcast_axes);
break; break;
} }
case OP_TYPEID::BroadcastLike: break;
case OP_TYPEID::Ceiling: case OP_TYPEID::Ceiling:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
...@@ -488,6 +489,7 @@ private: ...@@ -488,6 +489,7 @@ private:
// Constant is handled in the main loop // Constant is handled in the main loop
break; break;
} }
case OP_TYPEID::ScalarConstantLike: break;
case OP_TYPEID::Convert: case OP_TYPEID::Convert:
{ {
// const op::Convert* c = static_cast<const op::Convert*>(&node); // const op::Convert* c = static_cast<const op::Convert*>(&node);
......
...@@ -573,6 +573,12 @@ static shared_ptr<ngraph::Function> ...@@ -573,6 +573,12 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Broadcast>(args[0], shape, axes); node = make_shared<op::Broadcast>(args[0], shape, axes);
break; break;
} }
case OP_TYPEID::BroadcastLike:
{
auto initial_axes = node_js.at("initial_axes").get<set<size_t>>();
node = make_shared<op::BroadcastLike>(args[0], args[1], initial_axes);
break;
}
case OP_TYPEID::Ceiling: case OP_TYPEID::Ceiling:
{ {
node = make_shared<op::Ceiling>(args[0]); node = make_shared<op::Ceiling>(args[0]);
...@@ -1026,6 +1032,12 @@ static shared_ptr<ngraph::Function> ...@@ -1026,6 +1032,12 @@ static shared_ptr<ngraph::Function>
make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis); make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break; break;
} }
case OP_TYPEID::ScalarConstantLike:
{
double value = node_js.at("value").get<double>();
node = make_shared<op::ScalarConstantLike>(args[0], value);
break;
}
case OP_TYPEID::Select: case OP_TYPEID::Select:
{ {
node = make_shared<op::Select>(args[0], args[1], args[2]); node = make_shared<op::Select>(args[0], args[1], args[2]);
...@@ -1338,6 +1350,12 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1338,6 +1350,12 @@ static json write(const Node& n, bool binary_constant_data)
node["shape"] = tmp->get_broadcast_shape(); node["shape"] = tmp->get_broadcast_shape();
break; break;
} }
case OP_TYPEID::BroadcastLike:
{
auto tmp = dynamic_cast<const op::BroadcastLike*>(&n);
node["initial_axes"] = tmp->get_initial_broadcast_axes();
break;
}
case OP_TYPEID::Ceiling: { break; case OP_TYPEID::Ceiling: { break;
} }
case OP_TYPEID::Concat: case OP_TYPEID::Concat:
...@@ -1593,6 +1611,14 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1593,6 +1611,14 @@ static json write(const Node& n, bool binary_constant_data)
node["sequence_axis"] = tmp->get_sequence_axis(); node["sequence_axis"] = tmp->get_sequence_axis();
break; break;
} }
case OP_TYPEID::ScalarConstantLike:
{
auto tmp = dynamic_cast<const op::ScalarConstantLikeBase*>(&n);
auto constant = tmp->as_constant();
node["value"] = constant->get_value_strings()[0];
node["element_type"] = write_element_type(constant->get_element_type());
break;
}
case OP_TYPEID::Select: { break; case OP_TYPEID::Select: { break;
} }
case OP_TYPEID::SelectAndScatter: case OP_TYPEID::SelectAndScatter:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment