Unverified Commit 40ddf45a authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Allow tensor[view, layout] element type and shape to be modified (#1440)

parent e63ffa29
...@@ -22,16 +22,24 @@ ...@@ -22,16 +22,24 @@
using namespace ngraph; using namespace ngraph;
descriptor::layout::TensorViewLayout::TensorViewLayout(const descriptor::TensorView& tensor_view) descriptor::layout::TensorViewLayout::TensorViewLayout(const descriptor::TensorView& tensor_view)
: m_tensor_view_type(tensor_view.get_tensor_view_type()) : m_element_type(tensor_view.get_element_type())
, m_shape(tensor_view.get_shape())
{ {
} }
const element::Type& descriptor::layout::TensorViewLayout::get_element_type() const const element::Type& descriptor::layout::TensorViewLayout::get_element_type() const
{ {
return m_tensor_view_type->get_element_type(); return m_element_type;
} }
const Shape& descriptor::layout::TensorViewLayout::get_shape() const const Shape& descriptor::layout::TensorViewLayout::get_shape() const
{ {
return m_tensor_view_type->get_shape(); return m_shape;
}
void descriptor::layout::TensorViewLayout::set_tensor_view_type(const element::Type& element_type,
const Shape& shape)
{
m_element_type = element_type;
m_shape = shape;
} }
...@@ -62,8 +62,11 @@ namespace ngraph ...@@ -62,8 +62,11 @@ namespace ngraph
/// @brief Return true if this and other have the same element interpretation /// @brief Return true if this and other have the same element interpretation
virtual bool operator==(const TensorViewLayout& other) const = 0; virtual bool operator==(const TensorViewLayout& other) const = 0;
bool operator!=(const TensorViewLayout& other) const { return !(*this == other); } bool operator!=(const TensorViewLayout& other) const { return !(*this == other); }
void set_tensor_view_type(const element::Type& element_type, const Shape& shape);
protected: protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type; element::Type m_element_type;
Shape m_shape;
}; };
} }
} }
......
...@@ -45,6 +45,10 @@ namespace ngraph ...@@ -45,6 +45,10 @@ namespace ngraph
std::shared_ptr<Node> get_node() const; std::shared_ptr<Node> get_node() const;
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; } std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void set_tensor_view(const std::shared_ptr<TensorView>& tensor_view)
{
m_tensor_view = tensor_view;
}
void add_input(Input* input); void add_input(Input* input);
void remove_input(Input* input); void remove_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; } const std::set<Input*>& get_inputs() const { return m_inputs; }
......
...@@ -38,3 +38,9 @@ Tensor& PrimaryTensorView::get_tensor() ...@@ -38,3 +38,9 @@ Tensor& PrimaryTensorView::get_tensor()
{ {
return m_tensor; return m_tensor;
} }
void PrimaryTensorView::set_tensor_view_type(const element::Type& element_type, const Shape& shape)
{
TensorView::set_tensor_view_type(element_type, shape);
m_tensor.set_element_type(element_type);
}
...@@ -39,6 +39,8 @@ namespace ngraph ...@@ -39,6 +39,8 @@ namespace ngraph
virtual const Tensor& get_tensor() const override; virtual const Tensor& get_tensor() const override;
virtual Tensor& get_tensor() override; virtual Tensor& get_tensor() override;
void set_tensor_view_type(const element::Type& element_type,
const Shape& shape) override;
protected: protected:
Tensor m_tensor; Tensor m_tensor;
......
...@@ -62,6 +62,11 @@ size_t descriptor::Tensor::get_pool_offset() const ...@@ -62,6 +62,11 @@ size_t descriptor::Tensor::get_pool_offset() const
return m_pool_offset; return m_pool_offset;
} }
void descriptor::Tensor::set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
ostream& operator<<(ostream& out, const descriptor::Tensor& tensor) ostream& operator<<(ostream& out, const descriptor::Tensor& tensor)
{ {
out << "Tensor(" << tensor.get_name() << ")"; out << "Tensor(" << tensor.get_name() << ")";
......
...@@ -58,10 +58,11 @@ public: ...@@ -58,10 +58,11 @@ public:
void set_pool_offset(size_t); void set_pool_offset(size_t);
size_t get_pool_offset() const; size_t get_pool_offset() const;
const element::Type& get_element_type() const { return m_element_type; } const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type);
static std::string make_tensor_name(const Node* node, size_t value_index); static std::string make_tensor_name(const Node* node, size_t value_index);
protected: protected:
const element::Type m_element_type; element::Type m_element_type;
PrimaryTensorView* m_primary_tensor_view; PrimaryTensorView* m_primary_tensor_view;
std::string m_name; std::string m_name;
size_t m_next_view_id; size_t m_next_view_id;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
*******************************************************************************/ *******************************************************************************/
#include "ngraph/descriptor/tensor_view.hpp" #include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/type/type.hpp" #include "ngraph/type/type.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -24,3 +25,23 @@ shared_ptr<const ngraph::TensorViewType> descriptor::TensorView::get_value_type( ...@@ -24,3 +25,23 @@ shared_ptr<const ngraph::TensorViewType> descriptor::TensorView::get_value_type(
{ {
return m_tensor_view_type; return m_tensor_view_type;
} }
const element::Type& descriptor::TensorView::get_element_type() const
{
return m_tensor_view_type->get_element_type();
}
const Shape& descriptor::TensorView::get_shape() const
{
return m_tensor_view_type->get_shape();
}
void descriptor::TensorView::set_tensor_view_type(const element::Type& element_type,
const Shape& shape)
{
m_tensor_view_type = make_shared<ngraph::TensorViewType>(element_type, shape);
if (nullptr != m_tensor_view_layout)
{
m_tensor_view_layout->set_tensor_view_type(element_type, shape);
}
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <string> #include <string>
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -62,6 +63,12 @@ namespace ngraph ...@@ -62,6 +63,12 @@ namespace ngraph
return m_tensor_view_type; return m_tensor_view_type;
} }
virtual void set_tensor_view_type(const element::Type& element_type,
const Shape& shape);
const element::Type& get_element_type() const;
const Shape& get_shape() const;
const std::shared_ptr<layout::TensorViewLayout>& get_tensor_view_layout() const const std::shared_ptr<layout::TensorViewLayout>& get_tensor_view_layout() const
{ {
return m_tensor_view_layout; return m_tensor_view_layout;
......
...@@ -15,11 +15,13 @@ ...@@ -15,11 +15,13 @@
*******************************************************************************/ *******************************************************************************/
#include <cmath> #include <cmath>
#include <iostream>
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
using namespace ngraph; using namespace ngraph;
const element::Type element::unspecified(0, false, false, "unspecified");
const element::Type element::boolean(8, false, true, "char"); const element::Type element::boolean(8, false, true, "char");
const element::Type element::f32(32, true, true, "float"); const element::Type element::f32(32, true, true, "float");
const element::Type element::f64(64, true, true, "double"); const element::Type element::f64(64, true, true, "double");
...@@ -48,14 +50,6 @@ std::vector<const element::Type*> element::Type::get_known_types() ...@@ -48,14 +50,6 @@ std::vector<const element::Type*> element::Type::get_known_types()
return rc; return rc;
} }
element::Type::Type()
: m_bitwidth{0}
, m_is_real{0}
, m_is_signed{0}
, m_cname{}
{
}
element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::string& cname) element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::string& cname)
: m_bitwidth{bitwidth} : m_bitwidth{bitwidth}
, m_is_real{is_real} , m_is_real{is_real}
...@@ -64,6 +58,15 @@ element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::st ...@@ -64,6 +58,15 @@ element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::st
{ {
} }
element::Type& element::Type::operator=(const element::Type& t)
{
m_bitwidth = t.m_bitwidth;
m_is_real = t.m_is_real;
m_is_signed = t.m_is_signed;
m_cname = t.m_cname;
return *this;
}
const std::string& element::Type::c_type_string() const const std::string& element::Type::c_type_string() const
{ {
return m_cname; return m_cname;
...@@ -170,7 +173,7 @@ namespace ngraph ...@@ -170,7 +173,7 @@ namespace ngraph
std::ostream& element::operator<<(std::ostream& out, const element::Type& obj) std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
{ {
out << "element::Type(" << obj.m_bitwidth << ", " << obj.m_is_real << ", " << obj.m_is_signed out << "element::Type{" << obj.m_bitwidth << ", " << obj.m_is_real << ", " << obj.m_is_signed
<< ")"; << "," << obj.m_cname << "}";
return out; return out;
} }
...@@ -33,6 +33,7 @@ namespace ngraph ...@@ -33,6 +33,7 @@ namespace ngraph
{ {
class Type; class Type;
extern const Type unspecified;
extern const Type boolean; extern const Type boolean;
extern const Type f32; extern const Type f32;
extern const Type f64; extern const Type f64;
...@@ -48,10 +49,10 @@ namespace ngraph ...@@ -48,10 +49,10 @@ namespace ngraph
class Type class Type
{ {
public: public:
Type(); Type() {}
Type(const Type&) = default; Type(const Type&) = default;
Type(size_t bitwidth, bool is_real, bool is_signed, const std::string& cname); Type(size_t bitwidth, bool is_real, bool is_signed, const std::string& cname);
Type& operator=(const Type&) = default; Type& operator=(const Type&);
virtual ~Type() {} virtual ~Type() {}
const std::string& c_type_string() const; const std::string& c_type_string() const;
size_t size() const; size_t size() const;
...@@ -68,10 +69,10 @@ namespace ngraph ...@@ -68,10 +69,10 @@ namespace ngraph
/// Returns true if the type is floating point, else false. /// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; } bool get_is_real() const { return m_is_real; }
private: private:
size_t m_bitwidth; size_t m_bitwidth{0};
bool m_is_real; bool m_is_real{false};
bool m_is_signed; bool m_is_signed{false};
std::string m_cname; std::string m_cname{"unspecified"};
}; };
template <typename T> template <typename T>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment