Unverified Commit b804cc90 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Hybrid transformer test (#1938)

* wip

* wip

* simple hybrid test harness

* cleanup

* disable unit test in progress
parent f0acb7da
......@@ -61,7 +61,12 @@ if (NGRAPH_ONNX_IMPORT_ENABLE)
endif()
if (NGRAPH_INTERPRETER_ENABLE)
    # The hybrid tests wrap the INTERPRETER backend, so they are only built
    # when that backend is enabled.
    list(APPEND SRC
        backend_debug_api.cpp
        builder.cpp
        backend_api.cpp
        hybrid_backend.cpp
        hybrid_utils.cpp)
    set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} INTERPRETER)
endif()
......
......@@ -13,3 +13,74 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "gtest/gtest.h"
#include "hybrid_utils.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
// Factory registered with BackendManager to create the "HYBRID1" test backend.
// Builds a priority-ordered backend list: the first entry claims only "Add",
// the second claims every op in the op table and acts as the fallback.
// `config` is unused; the signature is dictated by the creator registration API.
static runtime::Backend* hybrid1_creator(const char* config)
{
    vector<shared_ptr<runtime::Backend>> backend_list;

    // Highest priority: an INTERPRETER wrapper that supports only Add.
    set<string> s0 = {"Add"};
    auto b0 = make_shared<BackendWrapper>("INTERPRETER", s0, "AddOnly");
    backend_list.push_back(b0);

// Expand the master op table into a set containing every op name.
#define NGRAPH_OP(a, b) #a,
    set<string> s1 = {
#include "ngraph/op/op_tbl.hpp"
    };
#undef NGRAPH_OP // don't leak the macro into the rest of this translation unit

    // Fallback: an INTERPRETER wrapper that supports all ops.
    auto b1 = make_shared<BackendWrapper>("INTERPRETER", s1, "AllOps");
    backend_list.push_back(b1);

    // NOTE(review): raw `new` matches the creator-function contract here;
    // presumably BackendManager takes ownership of the pointer — confirm.
    return new TestBackend(backend_list);
}
// Smoke test for the hybrid backend: register the creator under a private
// name, then evaluate (A + B) * C on 2x2 inputs with several argument orders.
TEST(DISABLED_HYBRID, abc)
{
    const string backend_name = "HYBRID1";
    runtime::BackendManager::register_backend(backend_name, hybrid1_creator);

    Shape shape{2, 2};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::Parameter>(element::f32, shape);
    auto C = make_shared<op::Parameter>(element::f32, shape);
    auto f = make_shared<Function>((A + B) * C, op::ParameterVector{A, B, C});

    auto backend = runtime::Backend::create(backend_name);

    // One tensor per parameter plus one for the result.
    auto t_a = backend->create_tensor(element::f32, shape);
    auto t_b = backend->create_tensor(element::f32, shape);
    auto t_c = backend->create_tensor(element::f32, shape);
    auto t_result = backend->create_tensor(element::f32, shape);

    copy_data(t_a, test::NDArray<float, 2>({{1, 2}, {3, 4}}).get_vector());
    copy_data(t_b, test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector());
    copy_data(t_c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());

    // (a + b) * c
    backend->call_with_validate(f, {t_result}, {t_a, t_b, t_c});
    EXPECT_EQ(read_vector<float>(t_result),
              (test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());

    // (b + a) * c — addition commutes, so the expected value is unchanged.
    backend->call_with_validate(f, {t_result}, {t_b, t_a, t_c});
    EXPECT_EQ(read_vector<float>(t_result),
              (test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());

    // (a + c) * b
    backend->call_with_validate(f, {t_result}, {t_a, t_c, t_b});
    EXPECT_EQ(read_vector<float>(t_result),
              (test::NDArray<float, 2>({{50, 72}, {98, 128}})).get_vector());
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "hybrid_utils.hpp"
using namespace std;
using namespace ngraph;
// Constructs the hybrid test backend from a priority-ordered backend list.
// The list must be non-empty: element 0 is used for tensor creation and
// compilation, so an empty list would make every later call invalid.
// Throws std::runtime_error if `backend_list` is empty.
TestBackend::TestBackend(const vector<shared_ptr<runtime::Backend>>& backend_list)
    : m_backend_list{backend_list}
{
    if (m_backend_list.empty())
    {
        throw runtime_error("TestBackend backend list empty");
    }
}
// Allocates a tensor via the highest-priority backend in the list.
shared_ptr<runtime::Tensor> TestBackend::create_tensor(const element::Type& element_type,
                                                       const Shape& shape)
{
    auto& primary = m_backend_list[0];
    return primary->create_tensor(element_type, shape);
}
// Wraps caller-provided storage in a tensor via the highest-priority backend.
shared_ptr<runtime::Tensor> TestBackend::create_tensor(const element::Type& element_type,
                                                       const Shape& shape,
                                                       void* memory_pointer)
{
    auto& primary = m_backend_list[0];
    return primary->create_tensor(element_type, shape, memory_pointer);
}
// Delegates compilation to the highest-priority backend.
bool TestBackend::compile(shared_ptr<Function> func)
{
    auto& primary = m_backend_list[0];
    return primary->compile(func);
}
// Execution is not implemented yet; this always throws.
// The eventual design (sketched in the original commit): walk m_backend_list
// in priority order and dispatch each node to the first backend whose
// is_supported() accepts it.
bool TestBackend::call(shared_ptr<Function> func,
                       const vector<shared_ptr<runtime::Tensor>>& outputs,
                       const vector<shared_ptr<runtime::Tensor>>& inputs)
{
    throw runtime_error("TestBackend call not supported");
}
// Wraps the backend created by runtime::Backend::create(backend_name),
// restricting it (via is_supported) to the ops named in `supported_ops`.
// `name` is a human-readable label stored for identification.
BackendWrapper::BackendWrapper(const string& backend_name,
const set<string>& supported_ops,
const string& name)
: m_backend{runtime::Backend::create(backend_name)}
, m_supported_ops{supported_ops}
, m_name{name}
{
}
// Forwards tensor allocation to the wrapped backend.
shared_ptr<runtime::Tensor> BackendWrapper::create_tensor(const element::Type& element_type,
                                                          const Shape& shape)
{
    auto tensor = m_backend->create_tensor(element_type, shape);
    return tensor;
}
// Forwards creation of a tensor over caller-provided storage to the
// wrapped backend.
shared_ptr<runtime::Tensor> BackendWrapper::create_tensor(const element::Type& element_type,
                                                          const Shape& shape,
                                                          void* memory_pointer)
{
    auto tensor = m_backend->create_tensor(element_type, shape, memory_pointer);
    return tensor;
}
// Forwards compilation to the wrapped backend.
bool BackendWrapper::compile(shared_ptr<Function> func)
{
    bool rc = m_backend->compile(func);
    return rc;
}
// Pure pass-through: the wrapper adds no behavior on the execution path.
bool BackendWrapper::call(shared_ptr<Function> func,
                          const vector<shared_ptr<runtime::Tensor>>& outputs,
                          const vector<shared_ptr<runtime::Tensor>>& inputs)
{
    bool rc = m_backend->call(func, outputs, inputs);
    return rc;
}
// A node is supported iff its description (op name) appears in the
// op set this wrapper was configured with.
bool BackendWrapper::is_supported(const Node& node) const
{
    return m_supported_ops.count(node.description()) > 0;
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
// Hybrid test backend: holds a priority-ordered list of real backends.
// Tensor creation and compilation are delegated to the first (highest
// priority) backend; call() is not implemented yet and throws.
class TestBackend : public ngraph::runtime::Backend
{
public:
// Requires a non-empty list; throws std::runtime_error otherwise.
TestBackend(const std::vector<std::shared_ptr<ngraph::runtime::Backend>>& backend_list);
// Allocates a tensor via the highest-priority backend.
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const ngraph::Shape& shape) override;
// Same, but wraps caller-provided storage at `memory_pointer`.
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const ngraph::Shape& shape,
void* memory_pointer) override;
// Delegates to the highest-priority backend.
bool compile(std::shared_ptr<ngraph::Function> func) override;
// Not implemented: always throws std::runtime_error.
bool call(std::shared_ptr<ngraph::Function> func,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs) override;
private:
// This list of backends is in order of priority with the first backend higher priority
// than the second.
std::vector<std::shared_ptr<ngraph::runtime::Backend>> m_backend_list;
};
// Wraps a single real backend and restricts it to a fixed set of op names:
// is_supported() answers from the configured set, while all other calls
// forward unchanged to the wrapped backend.
class BackendWrapper : public ngraph::runtime::Backend
{
public:
// Creates the wrapped backend via runtime::Backend::create(backend_name).
// `supported_ops` are op names (Node::description() values); `name` is a
// human-readable label for this wrapper.
BackendWrapper(const std::string& backend_name,
const std::set<std::string>& supported_ops,
const std::string& name);
// Forwards to the wrapped backend.
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const ngraph::Shape& shape) override;
// Forwards to the wrapped backend (caller-provided storage).
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const ngraph::Shape& shape,
void* memory_pointer) override;
// Forwards to the wrapped backend.
bool compile(std::shared_ptr<ngraph::Function> func) override;
// Forwards to the wrapped backend.
bool call(std::shared_ptr<ngraph::Function> func,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs) override;
// True iff node.description() is in m_supported_ops.
bool is_supported(const ngraph::Node& node) const override;
private:
std::shared_ptr<ngraph::runtime::Backend> m_backend;
const std::set<std::string> m_supported_ops;
const std::string m_name;
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment