Commit 7e5a5f76 authored by Adam Procter's avatar Adam Procter Committed by Yixing Lao

Implement scalar and tensor constants

* This is kind of a cheap hack for the moment, at least in the tensor
  constant case: each constant instruction carries around its value,
  and copies that to the result buffer everytime we need it. Ultimately
  we will probably want a pass that allocates space in the call frame
  for constants, similar to what we do for parameters.
parent 1138d1b9
......@@ -115,6 +115,8 @@ namespace ngraph
typename std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> get_value() const { return m_value; }
void set_value(const std::vector<type>& value) const { m_value->get_vector() = value; }
protected:
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<T>> m_value;
};
......
......@@ -38,9 +38,8 @@ namespace ngraph
size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
void
operator()(const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outpus);
void operator()(const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outpus);
void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor(size_t i) { return m_tensors[i]; }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET,typename T>
void assign_constant(const std::vector<ET>& value, T out)
{
out->get_vector() = value;
}
template <typename ET>
class ConstantInstruction : public Instruction
{
public:
ConstantInstruction(const std::vector<typename ET::type> value, size_t out)
: m_value(value)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::assign_constant(
m_value,
call_frame.get_parameterized_tensor<ET>(m_out));
}
protected:
const std::vector<typename ET::type> m_value;
size_t m_out;
};
}
}
}
......@@ -24,6 +24,7 @@
#include "ngraph/node.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/equal.hpp"
#include "ngraph/ops/less.hpp"
......@@ -39,6 +40,7 @@
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/eigen/abs.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/constant.hpp"
#include "ngraph/runtime/eigen/divide.hpp"
#include "ngraph/runtime/eigen/equal.hpp"
#include "ngraph/runtime/eigen/less_than.hpp"
......@@ -116,6 +118,16 @@ std::unordered_map<std::type_index,
const std::vector<size_t>& in,
const std::vector<size_t>& out) {};
REGISTER_INSTRUCTION(op::ScalarConstant<element::Float32>,
runtime::eigen::ConstantInstruction<element::Float32>,
std::vector<element::Float32::type>{dynamic_cast<op::ScalarConstant<element::Float32>*>(n)->get_value()},
out[0]);
REGISTER_INSTRUCTION(op::TensorConstant<element::Float32>,
runtime::eigen::ConstantInstruction<element::Float32>,
dynamic_cast<op::TensorConstant<element::Float32>*>(n)->get_value()->get_vector(),
out[0]);
initialized = true;
}
return op_map;
......
......@@ -253,3 +253,52 @@ TEST(execute, test_subtract)
ASSERT_EQ((vector<float>{1, 2, 4, 8}), result->get_vector());
}
TEST(execute, test_scalar_constant)
{
auto shape = Shape{};
auto A = make_shared<op::ScalarConstant<element::Float32>>(-3.0f);
auto f = make_shared<Function>(A, op::Parameters{});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto result = ngraph::runtime::make_tensor<element::Float32>(shape);
(*cf)({}, {result});
ASSERT_EQ((vector<float>{-3.0f}), result->get_vector());
}
TEST(execute, test_tensor_constant)
{
auto shape = Shape{2,2,2};
auto A = make_shared<op::TensorConstant<element::Float32>>(shape);
A->set_value(vector<float>{1,2,3,4,5,6,7,8});
auto f = make_shared<Function>(A, op::Parameters{});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto result = ngraph::runtime::make_tensor<element::Float32>(shape);
(*cf)({}, {result});
ASSERT_EQ((vector<float>{1,2,3,4,5,6,7,8}), result->get_vector());
}
TEST(execute, test_tensor_constant_with_op)
{
auto shape = Shape{2,2,2};
auto A = make_shared<op::TensorConstant<element::Float32>>(shape);
A->set_value(vector<float>{-1,2,3,-4,5,-6,-7,8});
auto f = make_shared<Function>(make_shared<op::Abs>(A), op::Parameters{});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto result = ngraph::runtime::make_tensor<element::Float32>(shape);
(*cf)({}, {result});
ASSERT_EQ((vector<float>{1,2,3,4,5,6,7,8}), result->get_vector());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment