Commit cf33669b authored by Rob Earhart's avatar Rob Earhart Committed by Robert Kimball

Convert PlaidML Tile op to generic ngraph passthrough op (#2361)

* Add a direct-to-Tile op

* Disable dequantize_dynamic_offset

* Add missing Py op defn

* Generic passthrough op; serialization

* Appease Linux builds

* Add gpu handlers

* Disable floor_int32 for now
parent fb4db5f6
...@@ -30,7 +30,8 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") ...@@ -30,7 +30,8 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
endif() endif()
endif() endif()
set(GTEST_OUTPUT_DIR ${EXTERNAL_PROJECTS_ROOT}/gtest/build/googlemock/gtest) set(GMOCK_OUTPUT_DIR ${EXTERNAL_PROJECTS_ROOT}/gtest/build/googlemock)
set(GTEST_OUTPUT_DIR ${GMOCK_OUTPUT_DIR}/gtest)
if (APPLE OR LINUX) if (APPLE OR LINUX)
set(COMPILE_FLAGS -fPIC) set(COMPILE_FLAGS -fPIC)
...@@ -70,9 +71,13 @@ ExternalProject_Add( ...@@ -70,9 +71,13 @@ ExternalProject_Add(
ExternalProject_Get_Property(ext_gtest SOURCE_DIR BINARY_DIR) ExternalProject_Get_Property(ext_gtest SOURCE_DIR BINARY_DIR)
add_library(libgtest INTERFACE) add_library(libgtest INTERFACE)
add_dependencies(libgtest ext_gtest) add_dependencies(libgtest ext_gtest ext_gmock)
target_include_directories(libgtest SYSTEM INTERFACE ${SOURCE_DIR}/googletest/include) target_include_directories(libgtest SYSTEM INTERFACE
${SOURCE_DIR}/googletest/include
${SOURCE_DIR}/googlemock/include)
target_link_libraries(libgtest INTERFACE target_link_libraries(libgtest INTERFACE
debug ${GTEST_OUTPUT_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}gtestd${CMAKE_STATIC_LIBRARY_SUFFIX} debug ${GTEST_OUTPUT_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}gtestd${CMAKE_STATIC_LIBRARY_SUFFIX}
optimized ${GTEST_OUTPUT_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}) debug ${GMOCK_OUTPUT_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}gmockd${CMAKE_STATIC_LIBRARY_SUFFIX}
optimized ${GTEST_OUTPUT_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}
optimized ${GMOCK_OUTPUT_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX})
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/passthrough.hpp"
#include "pyngraph/ops/passthrough.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Passthrough(py::module m)
{
py::class_<ngraph::op::Passthrough, std::shared_ptr<ngraph::op::Passthrough>, ngraph::Node>
pass{m, "Passthrough"};
pass.doc() = "ngraph.impl.op.Passthrough wraps ngraph::op::Passthrough";
pass.def(py::init<const std::string&,
const std::string&,
const std::string&,
const ngraph::NodeVector&,
std::vector<std::tuple<ngraph::element::Type, ngraph::PartialShape>>>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Passthrough(py::module m);
...@@ -72,6 +72,7 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -72,6 +72,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Or(m_op); regclass_pyngraph_op_Or(m_op);
regclass_pyngraph_op_Pad(m_op); regclass_pyngraph_op_Pad(m_op);
regclass_pyngraph_op_Parameter(m_op); regclass_pyngraph_op_Parameter(m_op);
regclass_pyngraph_op_Passthrough(m_op);
regclass_pyngraph_op_Power(m_op); regclass_pyngraph_op_Power(m_op);
regclass_pyngraph_op_Product(m_op); regclass_pyngraph_op_Product(m_op);
regclass_pyngraph_op_Relu(m_op); regclass_pyngraph_op_Relu(m_op);
......
...@@ -61,6 +61,7 @@ ...@@ -61,6 +61,7 @@
#include "pyngraph/ops/or.hpp" #include "pyngraph/ops/or.hpp"
#include "pyngraph/ops/pad.hpp" #include "pyngraph/ops/pad.hpp"
#include "pyngraph/ops/parameter.hpp" #include "pyngraph/ops/parameter.hpp"
#include "pyngraph/ops/passthrough.hpp"
#include "pyngraph/ops/power.hpp" #include "pyngraph/ops/power.hpp"
#include "pyngraph/ops/product.hpp" #include "pyngraph/ops/product.hpp"
#include "pyngraph/ops/relu.hpp" #include "pyngraph/ops/relu.hpp"
......
...@@ -204,6 +204,7 @@ sources = [ ...@@ -204,6 +204,7 @@ sources = [
'pyngraph/ops/or.cpp', 'pyngraph/ops/or.cpp',
'pyngraph/ops/pad.cpp', 'pyngraph/ops/pad.cpp',
'pyngraph/ops/parameter.cpp', 'pyngraph/ops/parameter.cpp',
'pyngraph/ops/passthrough.cpp',
'pyngraph/ops/power.cpp', 'pyngraph/ops/power.cpp',
'pyngraph/ops/regmodule_pyngraph_op.cpp', 'pyngraph/ops/regmodule_pyngraph_op.cpp',
'pyngraph/ops/relu.cpp', 'pyngraph/ops/relu.cpp',
......
...@@ -94,6 +94,7 @@ set (SRC ...@@ -94,6 +94,7 @@ set (SRC
op/or.cpp op/or.cpp
op/pad.cpp op/pad.cpp
op/parameter.cpp op/parameter.cpp
op/passthrough.cpp
op/power.cpp op/power.cpp
op/product.cpp op/product.cpp
op/quantize.cpp op/quantize.cpp
......
...@@ -105,6 +105,7 @@ NGRAPH_OP(OneHot, ngraph::op) ...@@ -105,6 +105,7 @@ NGRAPH_OP(OneHot, ngraph::op)
NGRAPH_OP(Or, ngraph::op) NGRAPH_OP(Or, ngraph::op)
NGRAPH_OP(Pad, ngraph::op) NGRAPH_OP(Pad, ngraph::op)
NGRAPH_OP(Parameter, ngraph::op) NGRAPH_OP(Parameter, ngraph::op)
NGRAPH_OP(Passthrough, ngraph::op)
NGRAPH_OP(Power, ngraph::op) NGRAPH_OP(Power, ngraph::op)
NGRAPH_OP(Product, ngraph::op) NGRAPH_OP(Product, ngraph::op)
NGRAPH_OP(Quantize, ngraph::op) NGRAPH_OP(Quantize, ngraph::op)
......
//***************************************************************************** //*****************************************************************************
// Copyright 2017-2018 Intel Corporation // Copyright 2017-2019 Intel Corporation
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -16,53 +16,38 @@ ...@@ -16,53 +16,38 @@
#include <utility> #include <utility>
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/op/passthrough.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_tile.hpp"
namespace ngraph ngraph::op::Passthrough::Passthrough(const std::string& logical_type,
{ const std::string& language,
namespace runtime const std::string& function,
{
namespace plaidml
{
NGRAPH_PLAIDML_OP_CLASS(ImplTile, OpImpl<op::Tile>);
}
}
}
ngraph::runtime::plaidml::op::Tile::Tile(
const std::string& node_type,
vertexai::plaidml::function function,
const NodeVector& args, const NodeVector& args,
std::vector<std::tuple<element::Type, PartialShape>> outputs) std::vector<std::tuple<element::Type, PartialShape>> outputs)
: Node{node_type, args, outputs.size()} : Op{"Passthrough", args}
, m_function{std::move(function)} , m_logical_type{logical_type}
, m_language{language}
, m_function{function}
, m_output_shapes{std::move(outputs)} , m_output_shapes{std::move(outputs)}
{ {
set_output_size(m_output_shapes.size());
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void ngraph::runtime::plaidml::op::Tile::validate_and_infer_types() void ngraph::op::Passthrough::validate_and_infer_types()
{ {
// TODO: It would be useful to have PlaidML deduce the output // N.B. It would be useful to have the backend deduce the output
// shapes, instead of having them passed in via the // shapes, instead of having them passed in via the
// constructor. The primary barrier to doing so is that // constructor and trusting that they're correct.
// PlaidML placeholders always have a fixed number of //
// dimensions but arbitrary dimension sizes, and the only way // The primary barrier to doing so is that at the point where
// to pin them down to a concrete dimension size is to bind a // Passthrough ops are being constructed, we don't
// tensor to them, which requires actually allocating the // necessarily have the backend available.
// tensor. In principal, we could fix this pretty easily; //
// we'll need to know more about where the PlaidML API is // At some point, we may want to add higher-level
// going before doing so, though. // backend-specific APIs for constructing Passthrough
if (get_input_size() != m_function.num_inputs()) // operations; that would ensure that the backend can
{ // understand the language being used, and would allow the
throw ngraph_error{"Incorrect input count for Tile operation node"}; // backend to infer the output shapes as needed.
}
if (m_output_shapes.size() != m_function.num_outputs())
{
throw ngraph_error{"Incorrect output count for Tile operation node"};
}
std::size_t idx = 0; std::size_t idx = 0;
for (auto& output_shape : m_output_shapes) for (auto& output_shape : m_output_shapes)
...@@ -72,28 +57,13 @@ void ngraph::runtime::plaidml::op::Tile::validate_and_infer_types() ...@@ -72,28 +57,13 @@ void ngraph::runtime::plaidml::op::Tile::validate_and_infer_types()
} }
std::shared_ptr<ngraph::Node> std::shared_ptr<ngraph::Node>
ngraph::runtime::plaidml::op::Tile::copy_with_new_args(const NodeVector& new_args) const ngraph::op::Passthrough::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != get_input_size()) if (new_args.size() != get_input_size())
{ {
throw ngraph_error{"Tile node input counts cannot be changed for a given Tile function"}; throw ngraph_error{
} "Passthrough node input counts cannot be changed for a given Passthrough function"};
return std::make_shared<Tile>(description(), m_function, new_args, m_output_shapes);
}
void ngraph::runtime::plaidml::ImplTile::Apply()
{
vertexai::plaidml::function::positional_t inputs;
for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
{
inputs.emplace_back(op_input(idx));
}
auto app = op().func().apply(inputs);
for (std::size_t idx = 0; idx < op().get_output_size(); ++idx)
{
set_output(idx, app.get_output(idx));
} }
return std::make_shared<Passthrough>(
description(), m_language, m_function, new_args, m_output_shapes);
} }
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
#include <tuple>
#include <vector>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// An op directly representing backend-specific code.
///
/// N.B. Not all backends support all operation languages; a
/// given backend might only support a given passthrough
/// operation language in certain modes.
class Passthrough;
}
}
class ngraph::op::Passthrough final : public Op
{
public:
Passthrough(const std::string& logical_type, // aka "What this operation is doing"
const std::string& language, // The language the implementation is written in
const std::string& function, // The operation implementation
const NodeVector& args,
std::vector<std::tuple<element::Type, PartialShape>> outputs);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::string& logical_type() const { return m_logical_type; }
const std::string& language() const { return m_language; }
const std::string& function() const { return m_function; }
const std::vector<std::tuple<element::Type, PartialShape>>& output_shapes() const
{
return m_output_shapes;
}
private:
std::string m_logical_type;
std::string m_language;
std::string m_function;
std::vector<std::tuple<element::Type, PartialShape>> m_output_shapes;
};
...@@ -86,6 +86,7 @@ ...@@ -86,6 +86,7 @@
#include "ngraph/op/or.hpp" #include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp" #include "ngraph/op/parameter.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/power.hpp" #include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
...@@ -856,6 +857,11 @@ std::string runtime::gpu::GPU_Emitter::emit_Parameter(EMIT_ARGS) ...@@ -856,6 +857,11 @@ std::string runtime::gpu::GPU_Emitter::emit_Parameter(EMIT_ARGS)
return ""; return "";
} }
std::string runtime::gpu::GPU_Emitter::emit_Passthrough(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_Power(EMIT_ARGS) std::string runtime::gpu::GPU_Emitter::emit_Power(EMIT_ARGS)
{ {
return emit_elementwise<ngraph::op::Power>(compiled_function, function_name, node, args, out); return emit_elementwise<ngraph::op::Power>(compiled_function, function_name, node, args, out);
......
...@@ -1912,6 +1912,7 @@ shared_ptr<runtime::Executable> ...@@ -1912,6 +1912,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::StopGradient: case OP_TYPEID::StopGradient:
case OP_TYPEID::TopK: case OP_TYPEID::TopK:
case OP_TYPEID::EmbeddingLookup: case OP_TYPEID::EmbeddingLookup:
case OP_TYPEID::Passthrough:
{ {
throw unsupported_op("Unsupported op '" + op->description() + throw unsupported_op("Unsupported op '" + op->description() +
"' in IntelGPU back end."); "' in IntelGPU back end.");
......
...@@ -44,6 +44,7 @@ ...@@ -44,6 +44,7 @@
#include "ngraph/op/min.hpp" #include "ngraph/op/min.hpp"
#include "ngraph/op/one_hot.hpp" #include "ngraph/op/one_hot.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
...@@ -922,6 +923,11 @@ private: ...@@ -922,6 +923,11 @@ private:
break; break;
} }
case OP_TYPEID::Parameter: break; case OP_TYPEID::Parameter: break;
case OP_TYPEID::Passthrough:
{
const op::Passthrough* passthrough = static_cast<const op::Passthrough*>(&node);
throw unsupported_op{"Unsupported operation language: " + passthrough->language()};
}
case OP_TYPEID::Pad: case OP_TYPEID::Pad:
{ {
const op::Pad* pad = static_cast<const op::Pad*>(&node); const op::Pad* pad = static_cast<const op::Pad*>(&node);
......
...@@ -37,6 +37,7 @@ set(SRC ...@@ -37,6 +37,7 @@ set(SRC
plaidml_ops_local_response_norm.cpp plaidml_ops_local_response_norm.cpp
plaidml_ops_logical.cpp plaidml_ops_logical.cpp
plaidml_ops_one_hot.cpp plaidml_ops_one_hot.cpp
plaidml_ops_passthrough.cpp
plaidml_ops_pool.cpp plaidml_ops_pool.cpp
plaidml_ops_reduce.cpp plaidml_ops_reduce.cpp
plaidml_ops_replace_slice.cpp plaidml_ops_replace_slice.cpp
...@@ -44,7 +45,6 @@ set(SRC ...@@ -44,7 +45,6 @@ set(SRC
plaidml_ops_reverse.cpp plaidml_ops_reverse.cpp
plaidml_ops_slice.cpp plaidml_ops_slice.cpp
plaidml_ops_softmax.cpp plaidml_ops_softmax.cpp
plaidml_ops_tile.cpp
plaidml_ops_transcendental.cpp plaidml_ops_transcendental.cpp
plaidml_ops_winograd.cpp plaidml_ops_winograd.cpp
plaidml_pass_concat_elision.cpp plaidml_pass_concat_elision.cpp
......
...@@ -37,8 +37,8 @@ void ngraph::runtime::plaidml::ImplLRN::Apply() ...@@ -37,8 +37,8 @@ void ngraph::runtime::plaidml::ImplLRN::Apply()
auto rank = dim_limit - 2; auto rank = dim_limit - 2;
auto distance = op().get_nsize() / 2; auto distance = op().get_nsize() / 2;
std::ostringstream div_expr; std::ostringstream div_expr;
div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha() << ".0 / " div_expr << "I / pow(" << op().get_bias() << " + ((" << op().get_alpha() << " / "
<< op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)"; << op().get_nsize() << ") * S), " << op().get_beta() << ")";
set_output( set_output(
start_tile_function() start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims({"N", "C"}).add_dims("D", 0, rank)) .add(builder::Input{op_input(), "I"}.add_dims({"N", "C"}).add_dims("D", 0, rank))
......
//***************************************************************************** //*****************************************************************************
// Copyright 2017-2018 Intel Corporation // Copyright 2017-2019 Intel Corporation
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -14,14 +14,13 @@ ...@@ -14,14 +14,13 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#pragma once #include <utility>
#include <tuple> #include "ngraph/except.hpp"
#include <vector> #include "ngraph/op/passthrough.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include <plaidml/plaidml++.h> namespace vp = vertexai::plaidml;
#include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -29,29 +28,29 @@ namespace ngraph ...@@ -29,29 +28,29 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
namespace op NGRAPH_PLAIDML_OP_CLASS(ImplPassthrough, OpImpl<op::Passthrough>);
{
/// An op directly representing PlaidML Tile code.
class Tile;
}
} }
} }
} }
class ngraph::runtime::plaidml::op::Tile final : public Node void ngraph::runtime::plaidml::ImplPassthrough::Apply()
{ {
public: if (op().language() != "Tile")
Tile(const std::string& node_type, {
vertexai::plaidml::function function, throw unsupported_op{"Unsupported operation language: " + op().language()};
const NodeVector& args, }
std::vector<std::tuple<element::Type, PartialShape>> outputs);
void validate_and_infer_types() final; vertexai::plaidml::function::positional_t inputs;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final; for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
{
inputs.emplace_back(op_input(idx));
}
auto app = vp::function{op().function()}.apply(inputs);
vertexai::plaidml::function func() const { return m_function; } for (std::size_t idx = 0; idx < op().get_output_size(); ++idx)
private: {
vertexai::plaidml::function m_function; set_output(idx, app.get_output(idx));
std::vector<std::tuple<element::Type, PartialShape>> m_output_shapes; }
}; }
...@@ -26,11 +26,11 @@ ...@@ -26,11 +26,11 @@
#include "ngraph/op/not.hpp" #include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp" #include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp" #include "ngraph/op/or.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp" #include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp" #include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_tile.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data() void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data()
...@@ -79,9 +79,10 @@ void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data ...@@ -79,9 +79,10 @@ void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data
ngraph::insert_new_node_between( ngraph::insert_new_node_between(
producer, producer,
consumer, consumer,
std::make_shared<op::Tile>( std::make_shared<op::Passthrough>(
"ConvertLogicalToData", "ConvertLogicalToData",
vertexai::plaidml::function{"function (I) -> (O) { O = as_int(I ? 1 : 0, 8);}"}, "Tile",
"function (I) -> (O) { O = as_int(I ? 1 : 0, 8);}",
NodeVector{producer}, NodeVector{producer},
std::vector<std::tuple<element::Type, PartialShape>>{ std::vector<std::tuple<element::Type, PartialShape>>{
{std::make_tuple(element::i8, PartialShape{producer->get_output_shape(0)})}})); {std::make_tuple(element::i8, PartialShape{producer->get_output_shape(0)})}}));
......
...@@ -61,6 +61,7 @@ generate_mask ...@@ -61,6 +61,7 @@ generate_mask
avg_pool_3d avg_pool_3d
avg_pool_3d_uneven_strided_padded_include_in_computation avg_pool_3d_uneven_strided_padded_include_in_computation
quantize_dynamic_offset # Quantization/Dequantization is unimplemented quantize_dynamic_offset # Quantization/Dequantization is unimplemented
dequantize_dynamic_offset # Quantization/Dequantization is unimplemented
dequantize_int8_zero_offset # Quantization/Dequantization is unimplemented dequantize_int8_zero_offset # Quantization/Dequantization is unimplemented
dequantize_int32 # Quantization/Dequantization is unimplemented dequantize_int32 # Quantization/Dequantization is unimplemented
dequantize_int32_zero_offset # Quantization/Dequantization is unimplemented dequantize_int32_zero_offset # Quantization/Dequantization is unimplemented
...@@ -100,3 +101,4 @@ sum_stable_acc_double # To debug: precision errors ...@@ -100,3 +101,4 @@ sum_stable_acc_double # To debug: precision errors
embedding_lookup_4x5_reverse embedding_lookup_4x5_reverse
embedding_lookup_10x1_arbitrary embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int embedding_lookup_10x1_arbitrary_index_type_int
floor_int32
...@@ -75,6 +75,7 @@ ...@@ -75,6 +75,7 @@
#include "ngraph/op/or.hpp" #include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp" #include "ngraph/op/parameter.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/power.hpp" #include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
...@@ -943,6 +944,22 @@ static shared_ptr<ngraph::Function> ...@@ -943,6 +944,22 @@ static shared_ptr<ngraph::Function>
make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable); make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable);
break; break;
} }
case OP_TYPEID::Passthrough:
{
std::vector<json> outputs_js = node_js.at("output_shapes");
std::vector<std::tuple<element::Type, PartialShape>> outputs;
for (auto output_js : outputs_js)
{
outputs.emplace_back(read_element_type(output_js.at("element_type")),
read_partial_shape(output_js.at("shape")));
}
node = make_shared<op::Passthrough>(node_js.at("logical_type"),
node_js.at("language"),
node_js.at("function"),
args,
std::move(outputs));
break;
}
case OP_TYPEID::Power: case OP_TYPEID::Power:
{ {
node = make_shared<op::Power>(args[0], args[1]); node = make_shared<op::Power>(args[0], args[1]);
...@@ -1557,6 +1574,23 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1557,6 +1574,23 @@ static json write(const Node& n, bool binary_constant_data)
node["element_type"] = write_element_type(tmp->get_element_type()); node["element_type"] = write_element_type(tmp->get_element_type());
break; break;
} }
case OP_TYPEID::Passthrough:
{
auto tmp = dynamic_cast<const op::Passthrough*>(&n);
node["logical_type"] = tmp->logical_type();
node["language"] = tmp->language();
node["function"] = tmp->function();
std::vector<json> outputs_js;
for (const auto& output_shape : tmp->output_shapes())
{
json output_js;
output_js["element_type"] = write_element_type(std::get<0>(output_shape));
output_js["shape"] = write_partial_shape(std::get<1>(output_shape));
outputs_js.emplace_back(std::move(output_js));
}
node["output_shapes"] = std::move(outputs_js);
break;
}
case OP_TYPEID::Product: case OP_TYPEID::Product:
{ {
auto tmp = dynamic_cast<const op::Product*>(&n); auto tmp = dynamic_cast<const op::Product*>(&n);
......
...@@ -17,10 +17,13 @@ ...@@ -17,10 +17,13 @@
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/serializer.hpp" #include "ngraph/serializer.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
...@@ -30,6 +33,10 @@ using namespace std; ...@@ -30,6 +33,10 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
using json = nlohmann::json; using json = nlohmann::json;
using ::testing::ElementsAre;
using ::testing::NotNull;
using ::testing::StrEq;
template <typename T> template <typename T>
T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value) T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value)
{ {
...@@ -154,3 +161,50 @@ TEST(benchmark, serialize) ...@@ -154,3 +161,50 @@ TEST(benchmark, serialize)
timer.stop(); timer.stop();
cout << "deserialize took " << timer.get_milliseconds() << "ms\n"; cout << "deserialize took " << timer.get_milliseconds() << "ms\n";
} }
MATCHER_P2(IsOutputShape, type, shape, "")
{
return std::get<0>(arg) == type && std::get<1>(arg).to_shape() == shape;
}
TEST(serialize, passthrough)
{
const string tmp_file = "serialize_passthrough.json";
using estuple = std::tuple<element::Type, PartialShape>;
Shape shape{2, 2, 2};
auto p = make_shared<op::Passthrough>(
"SerializationTest",
"Plain",
"Hello, world!",
NodeVector{},
std::vector<estuple>{estuple{element::f32, PartialShape{2, 3}},
estuple{element::i8, PartialShape{4, 5}}});
auto f = make_shared<Function>(NodeVector{std::make_shared<op::GetOutputElement>(p, 0),
std::make_shared<op::GetOutputElement>(p, 1)},
ParameterVector{});
serialize(tmp_file, f);
auto g = deserialize(tmp_file);
file_util::remove_file(tmp_file);
ASSERT_THAT(g, NotNull());
std::shared_ptr<op::Passthrough> pt;
for (const auto& op : g->get_ops())
{
pt = dynamic_pointer_cast<op::Passthrough>(op);
if (pt)
{
break;
}
}
ASSERT_THAT(pt.get(), NotNull());
EXPECT_THAT(pt->logical_type(), StrEq("SerializationTest"));
EXPECT_THAT(pt->language(), StrEq("Plain"));
EXPECT_THAT(pt->function(), StrEq("Hello, world!"));
EXPECT_THAT(pt->output_shapes(),
ElementsAre(IsOutputShape(element::f32, Shape{2, 3}),
IsOutputShape(element::i8, Shape{4, 5})));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment