Commit 4ecdb791 authored by Jai Menon's avatar Jai Menon Committed by GitHub

Merge branch 'master' into jmenon/codegen

parents d185b48c 65aeb4b5
...@@ -51,6 +51,11 @@ namespace ngraph ...@@ -51,6 +51,11 @@ namespace ngraph
/// With non-linear buffers, this will need to be something other than size_t. /// With non-linear buffers, this will need to be something other than size_t.
virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0; virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0;
const element::Type& get_element_type() const
{
return m_tensor_view.get_tensor_view_type()->get_element_type();
}
const Shape& get_shape() const const Shape& get_shape() const
{ {
return m_tensor_view.get_tensor_view_type()->get_shape(); return m_tensor_view.get_tensor_view_type()->get_shape();
......
...@@ -17,9 +17,10 @@ ...@@ -17,9 +17,10 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
void Convert::propagate_types() const element::Type& Convert::propagate_element_types(const element::Type& arg_element_type) const
{ {
throw ngraph_error("NIY"); return m_element_type;
} }
...@@ -27,9 +27,9 @@ namespace ngraph ...@@ -27,9 +27,9 @@ namespace ngraph
{ {
} }
virtual const element::Type&
propagate_element_types(const element::Type& arg_element_type) const override;
virtual std::string description() const override { return "Convert"; } virtual std::string description() const override { return "Convert"; }
virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
}; };
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ETI, typename ETO>
class ConvertInstruction : public Instruction
{
public:
ConvertInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ETO>(call_frame, m_out) =
EigenArray1d<ETI>(call_frame, m_arg).template cast<typename ETO::type>();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
This diff is collapsed.
This diff is collapsed.
...@@ -237,9 +237,49 @@ TEST(type_prop, concat_deduce_elem_type_mismatch) ...@@ -237,9 +237,49 @@ TEST(type_prop, concat_deduce_elem_type_mismatch)
} }
} }
// TEST(type_prop, convert_deduce)
// Tests for dot product. {
// // Deduce type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_correct)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->set_value_type(make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 3, 4}));
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, convert_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->set_value_type(make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 14, 4}));
try
{
c->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dot_deduce_scalar_2d) TEST(type_prop, dot_deduce_scalar_2d)
{ {
// Deduce type for scalar/matrix arguments // Deduce type for scalar/matrix arguments
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment