Unverified Commit 6a88a4f2 authored by Yixing Lao's avatar Yixing Lao Committed by GitHub

Merge branch 'master' into yixing/argon-codegen-v3

parents eb7a8681 e7588efa
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ops/constant.hpp" #include "ngraph/ops/constant.hpp"
#include "ngraph/util.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -21,10 +22,30 @@ namespace ...@@ -21,10 +22,30 @@ namespace
template <typename ET> template <typename ET>
void check_value_strings(const std::vector<std::string>& value_strings) void check_value_strings(const std::vector<std::string>& value_strings)
{ {
auto result = ET::read(value_strings); auto result = parse_string<typename ET::type>(value_strings);
} }
} }
//
// Utility macro for dispatching an element type-templated function at runtime.
//
// clang-format off
// Sorry, but you really don't want to see what clang-format does to this thing. :)
#define FUNCTION_ON_ELEMENT_TYPE(et, err_msg, f, ...) \
( \
((et) == element::Bool::element_type()) ? (f<element::Bool>(__VA_ARGS__)) : \
((et) == element::Float32::element_type()) ? (f<element::Float32>(__VA_ARGS__)) : \
((et) == element::Int8::element_type()) ? (f<element::Int8>(__VA_ARGS__)) : \
((et) == element::Int32::element_type()) ? (f<element::Int32>(__VA_ARGS__)) : \
((et) == element::Int64::element_type()) ? (f<element::Int64>(__VA_ARGS__)) : \
((et) == element::UInt8::element_type()) ? (f<element::UInt8>(__VA_ARGS__)) : \
((et) == element::UInt32::element_type()) ? (f<element::UInt32>(__VA_ARGS__)) : \
((et) == element::UInt64::element_type()) ? (f<element::UInt64>(__VA_ARGS__)) : \
(throw ngraph_error(err_msg)) \
)
// clang-format on
op::Constant::Constant(const element::Type& et, op::Constant::Constant(const element::Type& et,
const Shape& shape, const Shape& shape,
const std::vector<std::string>& value_strings) const std::vector<std::string>& value_strings)
......
...@@ -116,6 +116,7 @@ ...@@ -116,6 +116,7 @@
#include "ngraph/runtime/ngvm/eigen/vector_slice.hpp" #include "ngraph/runtime/ngvm/eigen/vector_slice.hpp"
#include "ngraph/runtime/ngvm/external_function.hpp" #include "ngraph/runtime/ngvm/external_function.hpp"
#include "ngraph/runtime/utils.hpp" #include "ngraph/runtime/utils.hpp"
#include "ngraph/util.hpp"
using namespace std; using namespace std;
using namespace ngraph::runtime::ngvm; using namespace ngraph::runtime::ngvm;
...@@ -379,8 +380,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -379,8 +380,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto c_value_strings = c->get_value_strings(); auto c_value_strings = c->get_value_strings();
#define M_REGISTER_POLYMORPHIC_CONSTANT(ET) \ #define M_REGISTER_POLYMORPHIC_CONSTANT(ET) \
ef->get_instructions()->push_back( \ ef->get_instructions()->push_back(make_shared<eigen::ConstantInstruction<ET>>( \
make_shared<eigen::ConstantInstruction<ET>>(ET::read(c_value_strings), out[0])); parse_string<typename ET::type>(c_value_strings), out[0]));
DO_ON_ELEMENT_TYPE(c_element_type, DO_ON_ELEMENT_TYPE(c_element_type,
"Constant has unhandled element type", "Constant has unhandled element type",
......
...@@ -117,37 +117,6 @@ namespace ngraph ...@@ -117,37 +117,6 @@ namespace ngraph
{ {
return std::make_shared<runtime::ParameterizedTensorView<TraitedType<T>>>(shape); return std::make_shared<runtime::ParameterizedTensorView<TraitedType<T>>>(shape);
} }
/// Parses a string containing a literal of the underlying type.
static T read(const std::string& s)
{
T result;
std::stringstream ss;
ss << s;
ss >> result;
// Check that (1) parsing succeeded and (2) the entire string was used.
if (ss.fail() || ss.rdbuf()->in_avail() != 0)
{
throw ngraph_error("Could not parse literal");
}
return result;
}
/// Parses a list of strings containing literals of the underlying type.
static std::vector<T> read(const std::vector<std::string>& ss)
{
std::vector<T> result;
for (auto s : ss)
{
result.push_back(read(s));
}
return result;
}
}; };
NGRAPH_DEFINE_TRAITED_TYPE_NAME(char) NGRAPH_DEFINE_TRAITED_TYPE_NAME(char)
...@@ -178,23 +147,3 @@ namespace ngraph ...@@ -178,23 +147,3 @@ namespace ngraph
using UInt64 = TraitedType<uint64_t>; using UInt64 = TraitedType<uint64_t>;
} }
} }
//
// Utility macro for dispatching an element type-templated function at runtime.
//
// clang-format off
// Sorry, but you really don't want to see what clang-format does to this thing. :)
#define FUNCTION_ON_ELEMENT_TYPE(et, err_msg, f, ...) \
( \
((et) == element::Bool::element_type()) ? (f<element::Bool>(__VA_ARGS__)) : \
((et) == element::Float32::element_type()) ? (f<element::Float32>(__VA_ARGS__)) : \
((et) == element::Int8::element_type()) ? (f<element::Int8>(__VA_ARGS__)) : \
((et) == element::Int32::element_type()) ? (f<element::Int32>(__VA_ARGS__)) : \
((et) == element::Int64::element_type()) ? (f<element::Int64>(__VA_ARGS__)) : \
((et) == element::UInt8::element_type()) ? (f<element::UInt8>(__VA_ARGS__)) : \
((et) == element::UInt32::element_type()) ? (f<element::UInt32>(__VA_ARGS__)) : \
((et) == element::UInt64::element_type()) ? (f<element::UInt64>(__VA_ARGS__)) : \
(throw ngraph_error(err_msg)) \
)
// clang-format on
...@@ -162,6 +162,39 @@ namespace ngraph ...@@ -162,6 +162,39 @@ namespace ngraph
std::string m_name; std::string m_name;
}; };
/// Parses a string containing a literal of the underlying type.
template <typename T>
T parse_string(const std::string& s)
{
T result;
std::stringstream ss;
ss << s;
ss >> result;
// Check that (1) parsing succeeded and (2) the entire string was used.
if (ss.fail() || ss.rdbuf()->in_avail() != 0)
{
throw std::runtime_error("Could not parse literal '" + s + "'");
}
return result;
}
/// Parses a list of strings containing literals of the underlying type.
template <typename T>
std::vector<T> parse_string(const std::vector<std::string>& ss)
{
std::vector<T> result;
for (auto s : ss)
{
result.push_back(parse_string<T>(s));
}
return result;
}
template <class InputIt, class BinaryOp> template <class InputIt, class BinaryOp>
typename std::iterator_traits<InputIt>::value_type typename std::iterator_traits<InputIt>::value_type
reduce(InputIt first, InputIt last, BinaryOp op) reduce(InputIt first, InputIt last, BinaryOp op)
......
...@@ -1472,9 +1472,9 @@ TEST(type_prop, tensor_constant_bad_parse) ...@@ -1472,9 +1472,9 @@ TEST(type_prop, tensor_constant_bad_parse)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected"; FAIL() << "Bad literal parse not detected";
} }
catch (const ngraph_error& error) catch (const runtime_error& error)
{ {
EXPECT_EQ(error.what(), std::string("Could not parse literal")); EXPECT_TRUE(string(error.what()).find("Could not parse literal") != string::npos);
} }
catch (...) catch (...)
{ {
...@@ -1492,9 +1492,9 @@ TEST(type_prop, tensor_constant_bad_parse_float_for_int) ...@@ -1492,9 +1492,9 @@ TEST(type_prop, tensor_constant_bad_parse_float_for_int)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected"; FAIL() << "Bad literal parse not detected";
} }
catch (const ngraph_error& error) catch (const runtime_error& error)
{ {
EXPECT_EQ(error.what(), std::string("Could not parse literal")); EXPECT_TRUE(string(error.what()).find("Could not parse literal") != string::npos);
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment