Commit e2bddf19 authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MLIR] Add support for parsing nGraph tensor type (#3654)

* [MLIR] Add support for parsing nGraph tensor type

Initial commit that enables nGraph parsing. It's needed for testing.

* Clang format

* Fix return in parser
parent 1c9a1996
......@@ -22,6 +22,8 @@
#include "ops.hpp"
#include "type.hpp"
#include <mlir/Parser.h>
using namespace mlir;
NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx)
......@@ -37,6 +39,51 @@ NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx)
>();
}
mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location loc) const
{
StringRef origTypeStr = tyData;
MLIRContext* context = getContext();
if (tyData.consume_front("tensor"))
{
if (!tyData.consume_front("<") || !tyData.consume_back(">"))
{
return (emitError(loc, "expected '<' and '>' enclosing the tensor shape: " + tyData),
Type());
}
// Get x-separated sub-strings.
SmallVector<StringRef, 8> subStrings;
tyData.split(subStrings, "x");
// Parse shape dimensions.
SmallVector<int64_t, 4> shape;
for (unsigned i = 0, end = subStrings.size() - 1; i < end; ++i)
{
StringRef dimStr = subStrings[i];
int64_t dim = -1;
// NOTE: `consumeInteger` returns false if an integer was parsed successfully.
if (dimStr.consumeInteger(/*Radix=*/10, dim) || !dimStr.empty())
{
return (
emitError(loc, "expected a list of '[0-9]+x' dimension specifiers: " + tyData),
Type());
}
shape.push_back(dim);
}
auto elem_ty = mlir::parseType(subStrings.back(), context);
if (!elem_ty)
{
return (emitError(loc, "Unexpected element type in tensor type: " + tyData), Type());
}
return NGTensorType::get(context, elem_ty, shape);
}
return (emitError(loc, "Unknown nGraph type: " + origTypeStr), Type());
}
void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const
{
switch (type.getKind())
......
......@@ -34,11 +34,7 @@ namespace mlir
{
public:
explicit NGraphOpsDialect(mlir::MLIRContext* ctx);
mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
{
NGRAPH_CHECK(false, "Unsupported type parsing.");
return mlir::Type();
}
mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override;
void printType(mlir::Type type, llvm::raw_ostream& os) const override;
static StringRef getDialectNamespace() { return "ng"; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment