Commit dd23b0cb authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Split runtime backend and executable source files (#2544)

* split runtime backend and executable source files

* style
parent a8559a67
......@@ -156,12 +156,13 @@ set (SRC
runtime/aligned_buffer.cpp
runtime/backend.cpp
runtime/backend_manager.cpp
state/rng_state.cpp
runtime/executable.cpp
runtime/host_tensor.cpp
runtime/tensor.cpp
serializer.cpp
shape.cpp
shape_util.cpp
state/rng_state.cpp
strides.cpp
type/bfloat16.cpp
type/element_type.cpp
......
......@@ -54,99 +54,6 @@ bool runtime::Backend::is_supported(const Node& node) const
return false;
}
runtime::Executable::Executable()
{
}
runtime::Executable::~Executable()
{
}
bool runtime::Executable::call_with_validate(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
validate(outputs, inputs);
return call(outputs, inputs);
}
void runtime::Executable::validate(const vector<std::shared_ptr<runtime::Tensor>>& outputs,
const vector<std::shared_ptr<runtime::Tensor>>& inputs)
{
const ParameterVector& parameters = get_parameters();
const ResultVector& results = get_results();
if (parameters.size() != inputs.size())
{
stringstream ss;
ss << "Call input count " << inputs.size() << " does not match Function's Parameter count "
<< parameters.size();
throw runtime_error(ss.str());
}
if (results.size() != outputs.size())
{
stringstream ss;
ss << "Call output count " << outputs.size() << " does not match Function's Result count "
<< results.size();
throw runtime_error(ss.str());
}
for (size_t i = 0; i < parameters.size(); i++)
{
if (parameters[i]->get_element_type() != inputs[i]->get_element_type())
{
stringstream ss;
ss << "Input " << i << " type '" << inputs[i]->get_element_type()
<< "' does not match Parameter type '" << parameters[i]->get_element_type() << "'";
throw runtime_error(ss.str());
}
if (parameters[i]->get_shape() != inputs[i]->get_shape())
{
stringstream ss;
ss << "Input " << i << " shape {" << join(inputs[i]->get_shape())
<< "} does not match Parameter shape {" << join(parameters[i]->get_shape()) << "}";
throw runtime_error(ss.str());
}
}
for (size_t i = 0; i < results.size(); i++)
{
if (results[i]->get_element_type() != outputs[i]->get_element_type())
{
stringstream ss;
ss << "Output " << i << " type '" << outputs[i]->get_element_type()
<< "' does not match Result type '" << results[i]->get_element_type() << "'";
throw runtime_error(ss.str());
}
if (results[i]->get_shape() != outputs[i]->get_shape())
{
stringstream ss;
ss << "Output " << i << " shape {" << join(outputs[i]->get_shape())
<< "} does not match Result shape {" << join(results[i]->get_shape()) << "}";
throw runtime_error(ss.str());
}
}
}
const ngraph::ParameterVector& runtime::Executable::get_parameters() const
{
return m_parameters;
}
const ngraph::ResultVector& runtime::Executable::get_results() const
{
return m_results;
}
void runtime::Executable::set_parameters_and_results(const Function& func)
{
m_parameters = func.get_parameters();
m_results = func.get_results();
}
vector<runtime::PerformanceCounter> runtime::Executable::get_performance_data() const
{
return vector<PerformanceCounter>();
}
bool runtime::Backend::is_supported_property(const Property prop) const
{
return false;
......
......@@ -20,6 +20,7 @@
#include "ngraph/function.hpp"
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/runtime/executable.hpp"
#include "ngraph/runtime/performance_counter.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -28,10 +29,8 @@ namespace ngraph
{
namespace runtime
{
class ExternalFunction;
class Tensor;
class Backend;
class Executable;
}
}
......@@ -111,51 +110,3 @@ public:
virtual void remove_compiled_function(std::shared_ptr<Executable> exec);
};
class ngraph::runtime::Executable
{
public:
Executable();
virtual ~Executable();
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
/// \returns true if iteration is successful, false otherwise
virtual bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) = 0;
/// \brief Executes a single iteration of a Function.
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
/// \returns true if iteration is successful, false otherwise
bool call_with_validate(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
/// \brief Collect performance information gathered on a Function.
/// \returns Vector of PerformanceCounter information.
virtual std::vector<PerformanceCounter> get_performance_data() const;
/// \brief Validates a Function.
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
void validate(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
/// \brief Query the input Parameters
/// \returns an ngraph::op::ParameterVector of all input parameters
const ngraph::ParameterVector& get_parameters() const;
/// \brief Query the output Results
/// \returns an ngraph::ResultVector of all input parameters
const ngraph::ResultVector& get_results() const;
protected:
/// \brief Called at the end of compile to the values to be returned by get_parameters
/// and get_results
/// \param func The function with Results fully resolved.
void set_parameters_and_results(const Function& func);
private:
ngraph::ParameterVector m_parameters;
ngraph::ResultVector m_results;
};
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <sstream>
#include "ngraph/file_util.hpp"
#include "ngraph/runtime/executable.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
runtime::Executable::Executable()
{
}
runtime::Executable::~Executable()
{
}
bool runtime::Executable::call_with_validate(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
validate(outputs, inputs);
return call(outputs, inputs);
}
void runtime::Executable::validate(const vector<std::shared_ptr<runtime::Tensor>>& outputs,
const vector<std::shared_ptr<runtime::Tensor>>& inputs)
{
const ParameterVector& parameters = get_parameters();
const ResultVector& results = get_results();
if (parameters.size() != inputs.size())
{
stringstream ss;
ss << "Call input count " << inputs.size() << " does not match Function's Parameter count "
<< parameters.size();
throw runtime_error(ss.str());
}
if (results.size() != outputs.size())
{
stringstream ss;
ss << "Call output count " << outputs.size() << " does not match Function's Result count "
<< results.size();
throw runtime_error(ss.str());
}
for (size_t i = 0; i < parameters.size(); i++)
{
if (parameters[i]->get_element_type() != inputs[i]->get_element_type())
{
stringstream ss;
ss << "Input " << i << " type '" << inputs[i]->get_element_type()
<< "' does not match Parameter type '" << parameters[i]->get_element_type() << "'";
throw runtime_error(ss.str());
}
if (parameters[i]->get_shape() != inputs[i]->get_shape())
{
stringstream ss;
ss << "Input " << i << " shape {" << join(inputs[i]->get_shape())
<< "} does not match Parameter shape {" << join(parameters[i]->get_shape()) << "}";
throw runtime_error(ss.str());
}
}
for (size_t i = 0; i < results.size(); i++)
{
if (results[i]->get_element_type() != outputs[i]->get_element_type())
{
stringstream ss;
ss << "Output " << i << " type '" << outputs[i]->get_element_type()
<< "' does not match Result type '" << results[i]->get_element_type() << "'";
throw runtime_error(ss.str());
}
if (results[i]->get_shape() != outputs[i]->get_shape())
{
stringstream ss;
ss << "Output " << i << " shape {" << join(outputs[i]->get_shape())
<< "} does not match Result shape {" << join(results[i]->get_shape()) << "}";
throw runtime_error(ss.str());
}
}
}
const ngraph::ParameterVector& runtime::Executable::get_parameters() const
{
return m_parameters;
}
const ngraph::ResultVector& runtime::Executable::get_results() const
{
return m_results;
}
void runtime::Executable::set_parameters_and_results(const Function& func)
{
m_parameters = func.get_parameters();
m_results = func.get_results();
}
vector<runtime::PerformanceCounter> runtime::Executable::get_performance_data() const
{
return vector<PerformanceCounter>();
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/function.hpp"
#include "ngraph/runtime/performance_counter.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
namespace runtime
{
class Tensor;
class Executable;
}
}
class ngraph::runtime::Executable
{
public:
Executable();
virtual ~Executable();
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
/// \returns true if iteration is successful, false otherwise
virtual bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) = 0;
/// \brief Executes a single iteration of a Function.
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
/// \returns true if iteration is successful, false otherwise
bool call_with_validate(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
/// \brief Collect performance information gathered on a Function.
/// \returns Vector of PerformanceCounter information.
virtual std::vector<PerformanceCounter> get_performance_data() const;
/// \brief Validates a Function.
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
void validate(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
/// \brief Query the input Parameters
/// \returns an ngraph::op::ParameterVector of all input parameters
const ngraph::ParameterVector& get_parameters() const;
/// \brief Query the output Results
/// \returns an ngraph::ResultVector of all input parameters
const ngraph::ResultVector& get_results() const;
protected:
/// \brief Called at the end of compile to the values to be returned by get_parameters
/// and get_results
/// \param func The function with Results fully resolved.
void set_parameters_and_results(const Function& func);
private:
ngraph::ParameterVector m_parameters;
ngraph::ResultVector m_results;
};
......@@ -21,7 +21,7 @@
#include <string>
#include <vector>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/executable.hpp"
namespace ngraph
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment