Commit 5498a997 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Remove ScalarConstantLikeBase (#4132)

* Remove ScalarConstantLikeBase

* Remove code no longer needed
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 62038b21
......@@ -393,7 +393,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
constexpr NodeTypeInfo op::ScalarConstantLike::type_info;
shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const
shared_ptr<op::Constant> op::ScalarConstantLike::as_constant() const
{
return std::make_shared<op::Constant>(m_element_type, m_shape, m_data->get_ptr());
}
......
......@@ -375,21 +375,8 @@ namespace ngraph
Constant operator=(const Constant&) = delete;
};
class NGRAPH_API ScalarConstantLikeBase : public Constant
{
public:
std::shared_ptr<op::Constant> as_constant() const;
ScalarConstantLikeBase() = default;
protected:
ScalarConstantLikeBase(const OutputVector& args)
: Constant(args)
{
}
};
/// \brief A scalar constant whose element type is the same as like.
class NGRAPH_API ScalarConstantLike : public ScalarConstantLikeBase
class NGRAPH_API ScalarConstantLike : public Constant
{
public:
static constexpr NodeTypeInfo type_info{"ScalarConstantLike", 0};
......@@ -403,7 +390,7 @@ namespace ngraph
/// \param value The value of the scalar.
template <typename T>
ScalarConstantLike(const Output<Node>& like, T value)
: ScalarConstantLikeBase({like})
: Constant({like})
, m_value(static_cast<double>(value))
{
constructor_validate_and_infer_types();
......@@ -412,6 +399,7 @@ namespace ngraph
ScalarConstantLike() = default;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<op::Constant> as_constant() const;
protected:
void infer_element_type() override;
......
......@@ -71,15 +71,6 @@ bool pass::LikeReplacement::run_on_function(shared_ptr<Function> function_ptr)
{
clobbered = handler->second(n) || clobbered;
}
// Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle.
auto sclb = as_type_ptr<op::ScalarConstantLikeBase>(n);
if (sclb != nullptr)
{
replace_node(sclb, sclb->as_constant());
clobbered = true;
}
}
return clobbered;
......
......@@ -79,18 +79,6 @@ static bool eliminate_slice(const std::shared_ptr<Node>& node)
return false;
}
static bool replace_broadcast_like(const std::shared_ptr<Node>& node)
{
// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like"
// argument
auto broadcast_like = std::static_pointer_cast<op::BroadcastLike>(node);
replace_node(node,
std::make_shared<op::Broadcast>(broadcast_like->get_argument(0),
broadcast_like->get_broadcast_shape(),
broadcast_like->get_broadcast_axes()));
return true;
}
static bool eliminate_broadcast(const std::shared_ptr<Node>& node)
{
auto broadcast = std::static_pointer_cast<op::Broadcast>(node);
......@@ -114,7 +102,6 @@ static const std::unordered_map<std::type_index, std::function<bool(const std::s
{TI(op::Convert), &eliminate_convert},
{TI(op::Slice), &eliminate_slice},
{TI(op::StopGradient), &eliminate_stop_gradient},
{TI(op::BroadcastLike), &replace_broadcast_like},
{TI(op::Broadcast), &eliminate_broadcast}};
bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function)
......@@ -130,15 +117,6 @@ bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function)
{
clobbered = handler->second(n) || clobbered;
}
// Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle.
auto sclb = as_type_ptr<op::ScalarConstantLikeBase>(n);
if (sclb != nullptr)
{
replace_node(sclb, sclb->as_constant());
clobbered = true;
}
}
return clobbered;
......
......@@ -4357,7 +4357,7 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::ScalarConstantLike:
{
auto tmp = static_cast<const op::ScalarConstantLikeBase*>(&n);
auto tmp = static_cast<const op::ScalarConstantLike*>(&n);
auto constant = tmp->as_constant();
node["value"] = constant->get_value_strings()[0];
node["element_type"] = write_element_type(constant->get_element_type());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment