Commit a2ccc1a4 authored by Robert Kimball's avatar Robert Kimball

basic API

parent 39cdee0e
...@@ -23,12 +23,14 @@ using namespace ngraph; ...@@ -23,12 +23,14 @@ using namespace ngraph;
op::Parameter::Parameter(const element::Type& element_type, op::Parameter::Parameter(const element::Type& element_type,
const PartialShape& pshape, const PartialShape& pshape,
const bool cacheable) const bool cacheable,
bool can_double_buffer)
: Op("Parameter", {}) : Op("Parameter", {})
, m_cacheable(cacheable) , m_cacheable(cacheable)
, m_partial_shape(pshape) , m_partial_shape(pshape)
, m_element_type(element_type) , m_element_type(element_type)
, m_is_relevant_to_shapes(false) , m_is_relevant_to_shapes(false)
, m_can_double_buffer(can_double_buffer)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -42,7 +42,8 @@ namespace ngraph ...@@ -42,7 +42,8 @@ namespace ngraph
/// \param cacheable True if the parameter is not expected to be frequently updated. /// \param cacheable True if the parameter is not expected to be frequently updated.
Parameter(const ngraph::element::Type& element_type, Parameter(const ngraph::element::Type& element_type,
const PartialShape& pshape, const PartialShape& pshape,
const bool cacheable = false); const bool cacheable = false,
bool can_double_buffer = false);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -52,12 +53,13 @@ namespace ngraph ...@@ -52,12 +53,13 @@ namespace ngraph
bool is_relevant_to_shapes() const; bool is_relevant_to_shapes() const;
void set_is_relevant_to_shapes(bool is_relevant); void set_is_relevant_to_shapes(bool is_relevant);
bool get_can_double_buffer() const { return m_can_double_buffer; }
protected: protected:
bool m_cacheable; bool m_cacheable;
PartialShape m_partial_shape; PartialShape m_partial_shape;
element::Type m_element_type; element::Type m_element_type;
bool m_is_relevant_to_shapes; bool m_is_relevant_to_shapes;
bool m_can_double_buffer;
}; };
} }
using ParameterVector = std::vector<std::shared_ptr<op::Parameter>>; using ParameterVector = std::vector<std::shared_ptr<op::Parameter>>;
......
...@@ -24,8 +24,9 @@ ...@@ -24,8 +24,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Result::Result(const shared_ptr<Node>& arg) op::Result::Result(const shared_ptr<Node>& arg, bool can_double_buffer)
: Op("Result", check_single_output_args({arg})) : Op("Result", check_single_output_args({arg}))
, m_can_double_buffer(can_double_buffer)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
// always borrow the placement conf even the default one // always borrow the placement conf even the default one
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
/// \brief Allows a value to be used as a function result. /// \brief Allows a value to be used as a function result.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Result(const std::shared_ptr<Node>& arg); Result(const std::shared_ptr<Node>& arg, bool can_double_buffer = false);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -40,12 +40,14 @@ namespace ngraph ...@@ -40,12 +40,14 @@ namespace ngraph
virtual bool is_output() const override { return true; } virtual bool is_output() const override { return true; }
void set_needs_default_layout(bool val) { m_needs_default_layout = val; } void set_needs_default_layout(bool val) { m_needs_default_layout = val; }
bool needs_default_layout() const { return m_needs_default_layout; } bool needs_default_layout() const { return m_needs_default_layout; }
bool get_can_double_buffer() const { return m_can_double_buffer; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
private: private:
bool m_needs_default_layout{false}; bool m_needs_default_layout{false};
bool m_can_double_buffer;
}; };
} }
using ResultVector = std::vector<std::shared_ptr<op::Result>>; using ResultVector = std::vector<std::shared_ptr<op::Result>>;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include <future>
#include <memory> #include <memory>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
...@@ -51,6 +52,10 @@ public: ...@@ -51,6 +52,10 @@ public:
bool call_with_validate(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs, bool call_with_validate(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs); const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
virtual std::future<bool>&
begin_execute(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs);
/// \brief Collect performance information gathered on a Function. /// \brief Collect performance information gathered on a Function.
/// \returns Vector of PerformanceCounter information. /// \returns Vector of PerformanceCounter information.
virtual std::vector<PerformanceCounter> get_performance_data() const; virtual std::vector<PerformanceCounter> get_performance_data() const;
...@@ -82,4 +87,5 @@ protected: ...@@ -82,4 +87,5 @@ protected:
private: private:
ngraph::ParameterVector m_parameters; ngraph::ParameterVector m_parameters;
ngraph::ResultVector m_results; ngraph::ResultVector m_results;
std::future<bool> m_future;
}; };
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include <future>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -105,6 +106,18 @@ namespace ngraph ...@@ -105,6 +106,18 @@ namespace ngraph
/// \param n Number of bytes to read, must be integral number of elements. /// \param n Number of bytes to read, must be integral number of elements.
virtual void read(void* p, size_t offset, size_t n) const = 0; virtual void read(void* p, size_t offset, size_t n) const = 0;
/// \brief Write bytes directly into the tensor
/// \param p Pointer to source of data
/// \param n Number of bytes to write, must be integral number of elements.
/// \return std::future to track the operation
virtual std::future<bool> begin_write(const void* p, size_t n);
/// \brief Read bytes directly from the tensor
/// \param p Pointer to destination for data
/// \param n Number of bytes to read, must be integral number of elements.
/// \return std::future to track the operation
virtual std::future<bool> begin_read(void* p, size_t n) const;
/// \brief copy bytes directly from source to this tensor /// \brief copy bytes directly from source to this tensor
/// \param source The source tensor /// \param source The source tensor
virtual void copy_from(const ngraph::runtime::Tensor& source); virtual void copy_from(const ngraph::runtime::Tensor& source);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment