Unverified Commit 6a88a4f2 authored by Yixing Lao's avatar Yixing Lao Committed by GitHub

Merge branch 'master' into yixing/argon-codegen-v3

parents eb7a8681 e7588efa
......@@ -13,6 +13,7 @@
// ----------------------------------------------------------------------------
#include "ngraph/ops/constant.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
......@@ -21,10 +22,30 @@ namespace
template <typename ET>
void check_value_strings(const std::vector<std::string>& value_strings)
{
auto result = ET::read(value_strings);
auto result = parse_string<typename ET::type>(value_strings);
}
}
//
// Utility macro for dispatching an element type-templated function at runtime.
//
// clang-format off
// Sorry, but you really don't want to see what clang-format does to this thing. :)
#define FUNCTION_ON_ELEMENT_TYPE(et, err_msg, f, ...) \
( \
((et) == element::Bool::element_type()) ? (f<element::Bool>(__VA_ARGS__)) : \
((et) == element::Float32::element_type()) ? (f<element::Float32>(__VA_ARGS__)) : \
((et) == element::Int8::element_type()) ? (f<element::Int8>(__VA_ARGS__)) : \
((et) == element::Int32::element_type()) ? (f<element::Int32>(__VA_ARGS__)) : \
((et) == element::Int64::element_type()) ? (f<element::Int64>(__VA_ARGS__)) : \
((et) == element::UInt8::element_type()) ? (f<element::UInt8>(__VA_ARGS__)) : \
((et) == element::UInt32::element_type()) ? (f<element::UInt32>(__VA_ARGS__)) : \
((et) == element::UInt64::element_type()) ? (f<element::UInt64>(__VA_ARGS__)) : \
(throw ngraph_error(err_msg)) \
)
// clang-format on
op::Constant::Constant(const element::Type& et,
const Shape& shape,
const std::vector<std::string>& value_strings)
......
......@@ -116,6 +116,7 @@
#include "ngraph/runtime/ngvm/eigen/vector_slice.hpp"
#include "ngraph/runtime/ngvm/external_function.hpp"
#include "ngraph/runtime/utils.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph::runtime::ngvm;
......@@ -379,8 +380,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto c_value_strings = c->get_value_strings();
#define M_REGISTER_POLYMORPHIC_CONSTANT(ET) \
ef->get_instructions()->push_back( \
make_shared<eigen::ConstantInstruction<ET>>(ET::read(c_value_strings), out[0]));
ef->get_instructions()->push_back(make_shared<eigen::ConstantInstruction<ET>>( \
parse_string<typename ET::type>(c_value_strings), out[0]));
DO_ON_ELEMENT_TYPE(c_element_type,
"Constant has unhandled element type",
......
......@@ -117,37 +117,6 @@ namespace ngraph
{
return std::make_shared<runtime::ParameterizedTensorView<TraitedType<T>>>(shape);
}
/// Parses a string containing a literal of the underlying type.
static T read(const std::string& s)
{
T result;
std::stringstream ss;
ss << s;
ss >> result;
// Check that (1) parsing succeeded and (2) the entire string was used.
if (ss.fail() || ss.rdbuf()->in_avail() != 0)
{
throw ngraph_error("Could not parse literal");
}
return result;
}
/// Parses a list of strings containing literals of the underlying type.
static std::vector<T> read(const std::vector<std::string>& ss)
{
std::vector<T> result;
for (auto s : ss)
{
result.push_back(read(s));
}
return result;
}
};
NGRAPH_DEFINE_TRAITED_TYPE_NAME(char)
......@@ -178,23 +147,3 @@ namespace ngraph
using UInt64 = TraitedType<uint64_t>;
}
}
//
// Utility macro for dispatching an element type-templated function at runtime.
//
// clang-format off
// Sorry, but you really don't want to see what clang-format does to this thing. :)
#define FUNCTION_ON_ELEMENT_TYPE(et, err_msg, f, ...) \
( \
((et) == element::Bool::element_type()) ? (f<element::Bool>(__VA_ARGS__)) : \
((et) == element::Float32::element_type()) ? (f<element::Float32>(__VA_ARGS__)) : \
((et) == element::Int8::element_type()) ? (f<element::Int8>(__VA_ARGS__)) : \
((et) == element::Int32::element_type()) ? (f<element::Int32>(__VA_ARGS__)) : \
((et) == element::Int64::element_type()) ? (f<element::Int64>(__VA_ARGS__)) : \
((et) == element::UInt8::element_type()) ? (f<element::UInt8>(__VA_ARGS__)) : \
((et) == element::UInt32::element_type()) ? (f<element::UInt32>(__VA_ARGS__)) : \
((et) == element::UInt64::element_type()) ? (f<element::UInt64>(__VA_ARGS__)) : \
(throw ngraph_error(err_msg)) \
)
// clang-format on
......@@ -162,6 +162,39 @@ namespace ngraph
std::string m_name;
};
/// Parses a string containing a literal of the underlying type.
template <typename T>
T parse_string(const std::string& s)
{
T result;
std::stringstream ss;
ss << s;
ss >> result;
// Check that (1) parsing succeeded and (2) the entire string was used.
if (ss.fail() || ss.rdbuf()->in_avail() != 0)
{
throw std::runtime_error("Could not parse literal '" + s + "'");
}
return result;
}
/// Parses a list of strings containing literals of the underlying type.
template <typename T>
std::vector<T> parse_string(const std::vector<std::string>& ss)
{
std::vector<T> result;
for (auto s : ss)
{
result.push_back(parse_string<T>(s));
}
return result;
}
template <class InputIt, class BinaryOp>
typename std::iterator_traits<InputIt>::value_type
reduce(InputIt first, InputIt last, BinaryOp op)
......
......@@ -1472,9 +1472,9 @@ TEST(type_prop, tensor_constant_bad_parse)
// Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected";
}
catch (const ngraph_error& error)
catch (const runtime_error& error)
{
EXPECT_EQ(error.what(), std::string("Could not parse literal"));
EXPECT_TRUE(string(error.what()).find("Could not parse literal") != string::npos);
}
catch (...)
{
......@@ -1492,9 +1492,9 @@ TEST(type_prop, tensor_constant_bad_parse_float_for_int)
// Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected";
}
catch (const ngraph_error& error)
catch (const runtime_error& error)
{
EXPECT_EQ(error.what(), std::string("Could not parse literal"));
EXPECT_TRUE(string(error.what()).find("Could not parse literal") != string::npos);
}
catch (...)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment