Unverified Commit 40ddf45a authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Allow tensor[view, layout] element type and shape to be modified (#1440)

parent e63ffa29
......@@ -22,16 +22,24 @@
using namespace ngraph;
descriptor::layout::TensorViewLayout::TensorViewLayout(const descriptor::TensorView& tensor_view)
: m_tensor_view_type(tensor_view.get_tensor_view_type())
: m_element_type(tensor_view.get_element_type())
, m_shape(tensor_view.get_shape())
{
}
const element::Type& descriptor::layout::TensorViewLayout::get_element_type() const
{
return m_tensor_view_type->get_element_type();
return m_element_type;
}
const Shape& descriptor::layout::TensorViewLayout::get_shape() const
{
return m_tensor_view_type->get_shape();
return m_shape;
}
void descriptor::layout::TensorViewLayout::set_tensor_view_type(const element::Type& element_type,
const Shape& shape)
{
m_element_type = element_type;
m_shape = shape;
}
......@@ -62,8 +62,11 @@ namespace ngraph
/// @brief Return true if this and other have the same element interpretation
virtual bool operator==(const TensorViewLayout& other) const = 0;
bool operator!=(const TensorViewLayout& other) const { return !(*this == other); }
void set_tensor_view_type(const element::Type& element_type, const Shape& shape);
protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type;
element::Type m_element_type;
Shape m_shape;
};
}
}
......
......@@ -45,6 +45,10 @@ namespace ngraph
std::shared_ptr<Node> get_node() const;
size_t get_index() const { return m_index; }
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void set_tensor_view(const std::shared_ptr<TensorView>& tensor_view)
{
m_tensor_view = tensor_view;
}
void add_input(Input* input);
void remove_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; }
......
......@@ -38,3 +38,9 @@ Tensor& PrimaryTensorView::get_tensor()
{
return m_tensor;
}
void PrimaryTensorView::set_tensor_view_type(const element::Type& element_type, const Shape& shape)
{
TensorView::set_tensor_view_type(element_type, shape);
m_tensor.set_element_type(element_type);
}
......@@ -39,6 +39,8 @@ namespace ngraph
virtual const Tensor& get_tensor() const override;
virtual Tensor& get_tensor() override;
void set_tensor_view_type(const element::Type& element_type,
const Shape& shape) override;
protected:
Tensor m_tensor;
......
......@@ -62,6 +62,11 @@ size_t descriptor::Tensor::get_pool_offset() const
return m_pool_offset;
}
void descriptor::Tensor::set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
ostream& operator<<(ostream& out, const descriptor::Tensor& tensor)
{
out << "Tensor(" << tensor.get_name() << ")";
......
......@@ -58,10 +58,11 @@ public:
void set_pool_offset(size_t);
size_t get_pool_offset() const;
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type);
static std::string make_tensor_name(const Node* node, size_t value_index);
protected:
const element::Type m_element_type;
element::Type m_element_type;
PrimaryTensorView* m_primary_tensor_view;
std::string m_name;
size_t m_next_view_id;
......
......@@ -15,6 +15,7 @@
*******************************************************************************/
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/type/type.hpp"
using namespace ngraph;
......@@ -24,3 +25,23 @@ shared_ptr<const ngraph::TensorViewType> descriptor::TensorView::get_value_type(
{
return m_tensor_view_type;
}
const element::Type& descriptor::TensorView::get_element_type() const
{
return m_tensor_view_type->get_element_type();
}
const Shape& descriptor::TensorView::get_shape() const
{
return m_tensor_view_type->get_shape();
}
void descriptor::TensorView::set_tensor_view_type(const element::Type& element_type,
const Shape& shape)
{
m_tensor_view_type = make_shared<ngraph::TensorViewType>(element_type, shape);
if (nullptr != m_tensor_view_layout)
{
m_tensor_view_layout->set_tensor_view_type(element_type, shape);
}
}
......@@ -20,6 +20,7 @@
#include <string>
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
......@@ -62,6 +63,12 @@ namespace ngraph
return m_tensor_view_type;
}
virtual void set_tensor_view_type(const element::Type& element_type,
const Shape& shape);
const element::Type& get_element_type() const;
const Shape& get_shape() const;
const std::shared_ptr<layout::TensorViewLayout>& get_tensor_view_layout() const
{
return m_tensor_view_layout;
......
......@@ -15,11 +15,13 @@
*******************************************************************************/
#include <cmath>
#include <iostream>
#include "ngraph/type/element_type.hpp"
using namespace ngraph;
const element::Type element::unspecified(0, false, false, "unspecified");
const element::Type element::boolean(8, false, true, "char");
const element::Type element::f32(32, true, true, "float");
const element::Type element::f64(64, true, true, "double");
......@@ -48,14 +50,6 @@ std::vector<const element::Type*> element::Type::get_known_types()
return rc;
}
element::Type::Type()
: m_bitwidth{0}
, m_is_real{0}
, m_is_signed{0}
, m_cname{}
{
}
element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::string& cname)
: m_bitwidth{bitwidth}
, m_is_real{is_real}
......@@ -64,6 +58,15 @@ element::Type::Type(size_t bitwidth, bool is_real, bool is_signed, const std::st
{
}
element::Type& element::Type::operator=(const element::Type& t)
{
m_bitwidth = t.m_bitwidth;
m_is_real = t.m_is_real;
m_is_signed = t.m_is_signed;
m_cname = t.m_cname;
return *this;
}
const std::string& element::Type::c_type_string() const
{
return m_cname;
......@@ -170,7 +173,7 @@ namespace ngraph
std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
{
out << "element::Type(" << obj.m_bitwidth << ", " << obj.m_is_real << ", " << obj.m_is_signed
<< ")";
out << "element::Type{" << obj.m_bitwidth << ", " << obj.m_is_real << ", " << obj.m_is_signed
<< "," << obj.m_cname << "}";
return out;
}
......@@ -33,6 +33,7 @@ namespace ngraph
{
class Type;
extern const Type unspecified;
extern const Type boolean;
extern const Type f32;
extern const Type f64;
......@@ -48,10 +49,10 @@ namespace ngraph
class Type
{
public:
Type();
Type() {}
Type(const Type&) = default;
Type(size_t bitwidth, bool is_real, bool is_signed, const std::string& cname);
Type& operator=(const Type&) = default;
Type& operator=(const Type&);
virtual ~Type() {}
const std::string& c_type_string() const;
size_t size() const;
......@@ -68,10 +69,10 @@ namespace ngraph
/// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; }
private:
size_t m_bitwidth;
bool m_is_real;
bool m_is_signed;
std::string m_cname;
size_t m_bitwidth{0};
bool m_is_real{false};
bool m_is_signed{false};
std::string m_cname{"unspecified"};
};
template <typename T>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment