Commit bdd16da3 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Add support string as AutoBroadcastSpec (#3909)

* Support string casting to AutoBroadcastSpec

* Make string values consistent
parent 883d2efd
......@@ -13,9 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <map>
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/check.hpp"
#include "ngraph/enum_names.hpp"
#include "ngraph/op/util/attr_types.hpp"
using namespace ngraph;
......@@ -126,4 +128,17 @@ namespace ngraph
{
return s << as_string(type);
}
op::AutoBroadcastType op::AutoBroadcastSpec::type_from_string(const std::string& type) const
{
static const std::map<std::string, AutoBroadcastType> allowed_values = {
{"NONE", AutoBroadcastType::NONE},
{"NUMPY", AutoBroadcastType::NUMPY},
{"PDPD", AutoBroadcastType::PDPD},
{"EXPLICIT", AutoBroadcastType::EXPLICIT}};
NGRAPH_CHECK(allowed_values.count(type) > 0, "Invalid 'type' value passed in.");
return allowed_values.at(type);
}
}
......@@ -257,6 +257,10 @@ namespace ngraph
, m_axis(0)
{
}
AutoBroadcastSpec(const char* type)
: AutoBroadcastSpec(type_from_string(type))
{
}
AutoBroadcastSpec(AutoBroadcastType type, int64_t axis)
: m_type(type)
, m_axis(axis)
......@@ -275,6 +279,9 @@ namespace ngraph
static const AutoBroadcastSpec NUMPY;
NGRAPH_API
static const AutoBroadcastSpec NONE;
private:
AutoBroadcastType type_from_string(const std::string& type) const;
};
}
}
......@@ -21,6 +21,8 @@
#include <random>
#include <string>
#include "util/type_prop.hpp"
// clang-format off
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
......@@ -191,3 +193,35 @@ NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_binary_elementwise_pdpd_dynamic)
ex->call_with_validate({t_r}, {t_a, t_b});
ASSERT_EQ(t_r->get_shape(), (Shape{2, 3, 4, 5}));
}
NGRAPH_TEST(${BACKEND_NAME}, auto_bcast_string_cast)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{1});
auto b = make_shared<op::Parameter>(element::f32, Shape{1});
auto add = make_shared<op::Add>(a, b, "NUMPY");
ASSERT_EQ(add->get_autob(), op::AutoBroadcastType::NUMPY);
add = make_shared<op::Add>(a, b, "NONE");
ASSERT_EQ(add->get_autob(), op::AutoBroadcastType::NONE);
add = make_shared<op::Add>(a, b, "PDPD");
ASSERT_EQ(add->get_autob(), op::AutoBroadcastType::PDPD);
add = make_shared<op::Add>(a, b, "EXPLICIT");
ASSERT_EQ(add->get_autob(), op::AutoBroadcastType::EXPLICIT);
try
{
add = make_shared<op::Add>(a, b, "UNKNOWN");
FAIL() << "Unknown AutoBroadcastType not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Invalid 'type' value passed in."));
}
catch (...)
{
FAIL() << "AutoBroadcastType checking failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment