Commit 2b498a6c authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Add BroadcastLike and ScalarConstantLike to op_tbl (#2213)

* add missing ops to op_tbl

* add two missing ops
parent 4514faf9
......@@ -37,7 +37,7 @@ using namespace ngraph;
std::shared_ptr<Node> make_zero(const std::shared_ptr<Node>& node)
{
std::shared_ptr<Node> zero = std::make_shared<op::ScalarConstantLike<double>>(node, 0.0);
std::shared_ptr<Node> zero = std::make_shared<op::ScalarConstantLike>(node, 0.0);
std::shared_ptr<Node> bzero = std::make_shared<op::BroadcastLike>(zero, node, AxisSet{});
return bzero;
}
......
......@@ -50,7 +50,7 @@ void op::Acos::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
auto x = get_inputs().at(0).get_output().get_node();
auto one = make_shared<op::ScalarConstantLike<double>>(x, 1.0);
auto one = make_shared<op::ScalarConstantLike>(x, 1.0);
auto ones = make_shared<op::BroadcastLike>(one, x, AxisSet());
adjoints.add_delta(x, -delta / make_shared<op::Sqrt>(ones - x * x));
......
......@@ -84,9 +84,9 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg,
const AxisSet& broadcast_axes)
const AxisSet& initial_broadcast_axes)
: Broadcast("BroadcastLike", {arg, like_arg}, {}, {})
, m_initial_broadcast_axes(broadcast_axes)
, m_initial_broadcast_axes(initial_broadcast_axes)
{
constructor_validate_and_infer_types();
}
......
......@@ -70,17 +70,17 @@ namespace ngraph
///
/// \param arg The argument to be broadcast.
/// \param like_arg Provides the shape for the result.
/// \param broadcast_axes indicates which axes will be broadcast. If empty,
/// arg must be scalar and all axes are broadcast.
/// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
/// arg must be scalar and all axes are broadcast.
BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg,
const AxisSet& broadcast_axes);
const AxisSet& initial_broadcast_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void infer_shape() override;
const AxisSet& get_initial_broadcast_axes() const { return m_initial_broadcast_axes; }
protected:
AxisSet m_initial_broadcast_axes;
};
......
......@@ -169,6 +169,21 @@ shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const
return std::make_shared<op::Constant>(m_element_type, m_shape, m_data);
}
std::shared_ptr<Node> op::ScalarConstantLike::copy_with_new_args(const NodeVector& new_args) const
{
return std::make_shared<ScalarConstantLike>(new_args.at(0), m_value);
}
void op::ScalarConstantLike::infer_element_type()
{
m_element_type = get_input_element_type(0);
if (nullptr == m_data)
{
m_data = ngraph::aligned_alloc(m_element_type.size(), m_element_type.size());
write_values(std::vector<double>(1, m_value));
}
}
//
// We have to open up namespace blocks here to work around a problem with gcc:
//
......
......@@ -281,7 +281,6 @@ namespace ngraph
};
/// \brief A scalar constant whose element type is the same as like.
template <typename T>
class ScalarConstantLike : public ScalarConstantLikeBase
{
public:
......@@ -292,30 +291,20 @@ namespace ngraph
///
/// \param like A tensor that will supply the element type.
/// \param value The value of the scalar.
template <typename T>
ScalarConstantLike(const std::shared_ptr<Node>& like, T value)
: ScalarConstantLikeBase("ScalarConstantLike", {like})
, m_value(value)
, m_value(static_cast<double>(value))
{
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override
{
return std::make_shared<ScalarConstantLike<T>>(new_args.at(0), m_value);
}
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected:
void infer_element_type() override
{
m_element_type = get_input_element_type(0);
if (nullptr == m_data)
{
m_data = ngraph::aligned_alloc(m_element_type.size(), m_element_type.size());
write_values(std::vector<T>(1, m_value));
}
}
void infer_element_type() override;
T m_value;
double m_value;
};
}
}
......@@ -62,6 +62,7 @@ NGRAPH_OP(BatchNormInference, ngraph::op)
NGRAPH_OP(BatchNormTraining, ngraph::op)
NGRAPH_OP(BatchNormTrainingBackprop, ngraph::op)
NGRAPH_OP(Broadcast, ngraph::op)
NGRAPH_OP(BroadcastLike, ngraph::op)
NGRAPH_OP(Ceiling, ngraph::op)
NGRAPH_OP(Concat, ngraph::op)
NGRAPH_OP(Constant, ngraph::op)
......@@ -112,6 +113,7 @@ NGRAPH_OP(Reshape, ngraph::op)
NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op)
NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(ScalarConstantLike, ngraph::op)
NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(SelectAndScatter, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op)
......
......@@ -460,6 +460,11 @@ void runtime::gpu::GPU_Emitter::emit_Broadcast(EMIT_ARGS)
writer.block_end();
}
void runtime::gpu::GPU_Emitter::emit_BroadcastLike(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Ceiling(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Ceiling>(external_function, writer, node, args, out);
......@@ -1415,6 +1420,11 @@ void runtime::gpu::GPU_Emitter::emit_Rnn(EMIT_ARGS)
}
#endif
void runtime::gpu::GPU_Emitter::emit_ScalarConstantLike(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Select>(external_function, writer, node, args, out);
......
......@@ -1716,6 +1716,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
case OP_TYPEID::All:
case OP_TYPEID::AllReduce:
case OP_TYPEID::Any:
case OP_TYPEID::BroadcastLike:
case OP_TYPEID::FunctionCall:
case OP_TYPEID::Dequantize:
case OP_TYPEID::Quantize:
......@@ -1723,6 +1724,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::GenerateMask:
case OP_TYPEID::ReverseSequence:
case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::SelectAndScatter:
case OP_TYPEID::ShapeOf:
case OP_TYPEID::StopGradient:
......
......@@ -459,6 +459,7 @@ private:
broadcast_axes);
break;
}
case OP_TYPEID::BroadcastLike: break;
case OP_TYPEID::Ceiling:
{
size_t element_count = shape_size(node.get_output_shape(0));
......@@ -488,6 +489,7 @@ private:
// Constant is handled in the main loop
break;
}
case OP_TYPEID::ScalarConstantLike: break;
case OP_TYPEID::Convert:
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
......
......@@ -573,6 +573,12 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Broadcast>(args[0], shape, axes);
break;
}
case OP_TYPEID::BroadcastLike:
{
auto initial_axes = node_js.at("initial_axes").get<set<size_t>>();
node = make_shared<op::BroadcastLike>(args[0], args[1], initial_axes);
break;
}
case OP_TYPEID::Ceiling:
{
node = make_shared<op::Ceiling>(args[0]);
......@@ -1026,6 +1032,12 @@ static shared_ptr<ngraph::Function>
make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break;
}
case OP_TYPEID::ScalarConstantLike:
{
double value = node_js.at("value").get<double>();
node = make_shared<op::ScalarConstantLike>(args[0], value);
break;
}
case OP_TYPEID::Select:
{
node = make_shared<op::Select>(args[0], args[1], args[2]);
......@@ -1338,6 +1350,12 @@ static json write(const Node& n, bool binary_constant_data)
node["shape"] = tmp->get_broadcast_shape();
break;
}
case OP_TYPEID::BroadcastLike:
{
auto tmp = dynamic_cast<const op::BroadcastLike*>(&n);
node["initial_axes"] = tmp->get_initial_broadcast_axes();
break;
}
case OP_TYPEID::Ceiling: { break;
}
case OP_TYPEID::Concat:
......@@ -1593,6 +1611,14 @@ static json write(const Node& n, bool binary_constant_data)
node["sequence_axis"] = tmp->get_sequence_axis();
break;
}
case OP_TYPEID::ScalarConstantLike:
{
auto tmp = dynamic_cast<const op::ScalarConstantLikeBase*>(&n);
auto constant = tmp->as_constant();
node["value"] = constant->get_value_strings()[0];
node["element_type"] = write_element_type(constant->get_element_type());
break;
}
case OP_TYPEID::Select: { break;
}
case OP_TYPEID::SelectAndScatter:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment