Commit d4e64381 authored by Adam Procter's avatar Adam Procter

Remove obsolete Eigen vector-broadcast instructions; remove some commented-out code

parent b485bb33
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class BroadcastVectorColwiseInstruction : public Instruction
{
public:
BroadcastVectorColwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out).colwise() =
EigenVector<ET>(call_frame, m_arg);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class BroadcastVectorRowwiseInstruction : public Instruction
{
public:
BroadcastVectorRowwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out).rowwise() =
EigenVector<ET>(call_frame, m_arg).transpose();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -71,8 +71,6 @@
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/ngvm/eigen/broadcast_scalar.hpp"
#include "ngraph/runtime/ngvm/eigen/broadcast_vector_colwise.hpp"
#include "ngraph/runtime/ngvm/eigen/broadcast_vector_rowwise.hpp"
#include "ngraph/runtime/ngvm/eigen/concat_matrix.hpp"
#include "ngraph/runtime/ngvm/eigen/concat_vector.hpp"
#include "ngraph/runtime/ngvm/eigen/dot.hpp"
......@@ -437,52 +435,6 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
arg_shape,
result_shape,
broadcast->get_broadcast_axes());
/*
if (broadcast->get_broadcast_axes().empty())
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
instruction::CopyInstruction,
in[0],
out[0]);
}
else if (arg_shape.size() == 0)
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
eigen::BroadcastScalarInstruction,
in[0],
out[0]);
}
else if (arg_shape.size() == 1 && result_shape.size() == 2)
{
if (broadcast->get_broadcast_axes() == AxisSet{1})
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
eigen::BroadcastVectorColwiseInstruction,
in[0],
out[0]);
}
else if (broadcast->get_broadcast_axes() == AxisSet{0})
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
eigen::BroadcastVectorRowwiseInstruction,
in[0],
out[0]);
}
else
{
throw ngraph_error(
"Internal error: axis set for vector-matrix broadcast is neither {0} nor "
"{1}");
}
}
else
{
throw ngraph_error("Broadcast not implemented for rank>2 in VM yet");
}*/
};
REGISTER_TO_OP_MAP(op::Concat)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment