Commit f64b0e0c authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] Fix in overriding ops (#2477)

* [ONNX] Overriding custom ops

* Add UT

* Style Check

* Review & style fix
parent 2b54d810
...@@ -140,9 +140,14 @@ namespace ngraph ...@@ -140,9 +140,14 @@ namespace ngraph
const std::string& domain, const std::string& domain,
Operator fn) Operator fn)
{ {
auto result = m_map[domain][name].emplace(version, std::move(fn)); auto it = m_map[domain][name].find(version);
if (result.second) if (it == std::end(m_map[domain][name]))
{ {
m_map[domain][name].emplace(version, std::move(fn));
}
else
{
it->second = std::move(fn);
NGRAPH_WARN << "Overwriting existing operator: " NGRAPH_WARN << "Overwriting existing operator: "
<< domain + "." + name + ":" + std::to_string(version); << domain + "." + name + ":" + std::to_string(version);
} }
......
ONNXnGraphImporter:c

A
BC"FalseAdd compute_graphZ
A


Z
B


b
C


B
\ No newline at end of file
...@@ -2031,6 +2031,33 @@ TEST(onnx_${BACKEND_NAME}, model_where) ...@@ -2031,6 +2031,33 @@ TEST(onnx_${BACKEND_NAME}, model_where)
EXPECT_EQ(expected_outputs.front(), outputs.front()); EXPECT_EQ(expected_outputs.front(), outputs.front());
} }
TEST(onnx_${BACKEND_NAME}, model_override_op)
{
onnx_import::register_operator(
"FalseAdd", 1, "", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
onnx_import::register_operator(
"FalseAdd", 1, "", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
});
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/override_op.onnx"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{0.f, 1.f, 2.f, 3.f});
inputs.emplace_back(std::vector<float>{3.f, 2.f, 1.f, 0.f});
Outputs expected_output{std::vector<float>{-3.f, -1.f, 1.f, 3.f}};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, import_non_existing_file) TEST(onnx_${BACKEND_NAME}, import_non_existing_file)
{ {
try try
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment