Commit 0f2ff7ea authored by Derek Bailey's avatar Derek Bailey Committed by Wouter van Oortmerssen

Lua cleanup (#5624)

parent dda09502
...@@ -29,12 +29,12 @@ local getAlignSize = compat.GetAlignSize ...@@ -29,12 +29,12 @@ local getAlignSize = compat.GetAlignSize
local function vtableEqual(a, objectStart, b) local function vtableEqual(a, objectStart, b)
UOffsetT:EnforceNumber(objectStart) UOffsetT:EnforceNumber(objectStart)
if (#a * VOffsetT.bytewidth) ~= #b then if (#a * 2) ~= #b then
return false return false
end end
for i, elem in ipairs(a) do for i, elem in ipairs(a) do
local x = string.unpack(VOffsetT.packFmt, b, 1 + (i - 1) * VOffsetT.bytewidth) local x = string.unpack(VOffsetT.packFmt, b, 1 + (i - 1) * 2)
if x ~= 0 or elem ~= 0 then if x ~= 0 or elem ~= 0 then
local y = objectStart - elem local y = objectStart - elem
if x ~= y then if x ~= y then
...@@ -121,7 +121,7 @@ function mt:WriteVtable() ...@@ -121,7 +121,7 @@ function mt:WriteVtable()
local vt2lenstr = self.bytes:Slice(vt2Start, vt2Start+1) local vt2lenstr = self.bytes:Slice(vt2Start, vt2Start+1)
local vt2Len = string.unpack(VOffsetT.packFmt, vt2lenstr, 1) local vt2Len = string.unpack(VOffsetT.packFmt, vt2lenstr, 1)
local metadata = VtableMetadataFields * VOffsetT.bytewidth local metadata = VtableMetadataFields * 2
local vt2End = vt2Start + vt2Len local vt2End = vt2Start + vt2Len
local vt2 = self.bytes:Slice(vt2Start+metadata,vt2End) local vt2 = self.bytes:Slice(vt2Start+metadata,vt2End)
...@@ -150,7 +150,7 @@ function mt:WriteVtable() ...@@ -150,7 +150,7 @@ function mt:WriteVtable()
self:PrependVOffsetT(objectSize) self:PrependVOffsetT(objectSize)
local vBytes = #self.currentVTable + VtableMetadataFields local vBytes = #self.currentVTable + VtableMetadataFields
vBytes = vBytes * VOffsetT.bytewidth vBytes = vBytes * 2
self:PrependVOffsetT(vBytes) self:PrependVOffsetT(vBytes)
local objectStart = #self.bytes - objectOffset local objectStart = #self.bytes - objectOffset
...@@ -225,17 +225,17 @@ function mt:Prep(size, additionalBytes) ...@@ -225,17 +225,17 @@ function mt:Prep(size, additionalBytes)
end end
function mt:PrependSOffsetTRelative(off) function mt:PrependSOffsetTRelative(off)
self:Prep(SOffsetT.bytewidth, 0) self:Prep(4, 0)
assert(off <= self:Offset(), "Offset arithmetic error") assert(off <= self:Offset(), "Offset arithmetic error")
local off2 = self:Offset() - off + SOffsetT.bytewidth local off2 = self:Offset() - off + 4
self:Place(off2, SOffsetT) self:Place(off2, SOffsetT)
end end
function mt:PrependUOffsetTRelative(off) function mt:PrependUOffsetTRelative(off)
self:Prep(UOffsetT.bytewidth, 0) self:Prep(4, 0)
local soffset = self:Offset() local soffset = self:Offset()
if off <= soffset then if off <= soffset then
local off2 = soffset - off + UOffsetT.bytewidth local off2 = soffset - off + 4
self:Place(off2, UOffsetT) self:Place(off2, UOffsetT)
else else
error("Offset arithmetic error") error("Offset arithmetic error")
...@@ -245,8 +245,9 @@ end ...@@ -245,8 +245,9 @@ end
function mt:StartVector(elemSize, numElements, alignment) function mt:StartVector(elemSize, numElements, alignment)
assert(not self.nested) assert(not self.nested)
self.nested = true self.nested = true
self:Prep(Uint32.bytewidth, elemSize * numElements) local elementSize = elemSize * numElements
self:Prep(alignment, elemSize * numElements) self:Prep(4, elementSize) -- Uint32 length
self:Prep(alignment, elementSize)
return self:Offset() return self:Offset()
end end
...@@ -263,7 +264,7 @@ function mt:CreateString(s) ...@@ -263,7 +264,7 @@ function mt:CreateString(s)
assert(type(s) == "string") assert(type(s) == "string")
self:Prep(UOffsetT.bytewidth, (#s + 1)*Uint8.bytewidth) self:Prep(4, #s + 1)
self:Place(0, Uint8) self:Place(0, Uint8)
local l = #s local l = #s
...@@ -271,20 +272,21 @@ function mt:CreateString(s) ...@@ -271,20 +272,21 @@ function mt:CreateString(s)
self.bytes:Set(s, self.head, self.head + l) self.bytes:Set(s, self.head, self.head + l)
return self:EndVector(#s) return self:EndVector(l)
end end
function mt:CreateByteVector(x) function mt:CreateByteVector(x)
assert(not self.nested) assert(not self.nested)
self.nested = true self.nested = true
self:Prep(UOffsetT.bytewidth, #x*Uint8.bytewidth)
local l = #x local l = #x
self:Prep(4, l)
self.head = self.head - l self.head = self.head - l
self.bytes:Set(x, self.head, self.head + l) self.bytes:Set(x, self.head, self.head + l)
return self:EndVector(#x) return self:EndVector(l)
end end
function mt:Slot(slotnum) function mt:Slot(slotnum)
...@@ -295,12 +297,7 @@ end ...@@ -295,12 +297,7 @@ end
local function finish(self, rootTable, sizePrefix) local function finish(self, rootTable, sizePrefix)
UOffsetT:EnforceNumber(rootTable) UOffsetT:EnforceNumber(rootTable)
local prepSize = UOffsetT.bytewidth self:Prep(self.minalign, sizePrefix and 8 or 4)
if sizePrefix then
prepSize = prepSize + Int32.bytewidth
end
self:Prep(self.minalign, prepSize)
self:PrependUOffsetTRelative(rootTable) self:PrependUOffsetTRelative(rootTable)
if sizePrefix then if sizePrefix then
local size = #self.bytes - self.head local size = #self.bytes - self.head
...@@ -325,8 +322,9 @@ function mt:Prepend(flags, off) ...@@ -325,8 +322,9 @@ function mt:Prepend(flags, off)
end end
function mt:PrependSlot(flags, o, x, d) function mt:PrependSlot(flags, o, x, d)
flags:EnforceNumber(x) flags:EnforceNumbers(x,d)
flags:EnforceNumber(d) -- flags:EnforceNumber(x)
-- flags:EnforceNumber(d)
if x ~= d then if x ~= d then
self:Prepend(flags, x) self:Prepend(flags, x)
self:Slot(o) self:Slot(o)
......
...@@ -34,6 +34,20 @@ function type_mt:EnforceNumber(n) ...@@ -34,6 +34,20 @@ function type_mt:EnforceNumber(n)
error("Number is not in the valid range") error("Number is not in the valid range")
end end
function type_mt:EnforceNumbers(a,b)
-- duplicate code since the overhead of function calls
-- for such a popular method is time consuming
if not self.min_value and not self.max_value then
return
end
if self.min_value <= a and a <= self.max_value and self.min_value <= b and b <= self.max_value then
return
end
error("Number is not in the valid range")
end
function type_mt:EnforceNumberAndPack(n) function type_mt:EnforceNumberAndPack(n)
return bpack(self.packFmt, n) return bpack(self.packFmt, n)
end end
...@@ -58,6 +72,7 @@ local bool_mt = ...@@ -58,6 +72,7 @@ local bool_mt =
Unpack = function(self, buf, pos) return buf[pos] == "1" end, Unpack = function(self, buf, pos) return buf[pos] == "1" end,
ValidNumber = function(self, n) return true end, -- anything is a valid boolean in Lua ValidNumber = function(self, n) return true end, -- anything is a valid boolean in Lua
EnforceNumber = function(self, n) end, -- anything is a valid boolean in Lua EnforceNumber = function(self, n) end, -- anything is a valid boolean in Lua
EnforceNumbers = function(self, a, b) end, -- anything is a valid boolean in Lua
EnforceNumberAndPack = function(self, n) return self:Pack(value) end, EnforceNumberAndPack = function(self, n) return self:Pack(value) end,
} }
......
...@@ -6,69 +6,83 @@ local mt_name = "flatbuffers.view.mt" ...@@ -6,69 +6,83 @@ local mt_name = "flatbuffers.view.mt"
local N = require("flatbuffers.numTypes") local N = require("flatbuffers.numTypes")
local binaryarray = require("flatbuffers.binaryarray") local binaryarray = require("flatbuffers.binaryarray")
local function enforceOffset(off)
if off < 0 or off > 42949672951 then
error("Offset is not valid")
end
end
local unpack = string.unpack
local function unPackUoffset(bytes, off)
return unpack("<I4", bytes.str, off + 1)
end
local function unPackVoffset(bytes, off)
return unpack("<I2", bytes.str, off + 1)
end
function m.New(buf, pos) function m.New(buf, pos)
N.UOffsetT:EnforceNumber(pos) enforceOffset(pos)
-- need to convert from a string buffer into -- need to convert from a string buffer into
-- a binary array -- a binary array
local o = { local o = {
bytes = type(buf) == "string" and binaryarray.New(buf) or buf, bytes = type(buf) == "string" and binaryarray.New(buf) or buf,
pos = pos pos = pos,
} }
setmetatable(o, {__index = mt, __metatable = mt_name}) setmetatable(o, {__index = mt, __metatable = mt_name})
return o return o
end end
function mt:Offset(vtableOffset) function mt:Offset(vtableOffset)
local vtable = self.pos - self:Get(N.SOffsetT, self.pos) local vtable = self.vtable
local vtableEnd = self:Get(N.VOffsetT, vtable) if not vtable then
if vtableOffset < vtableEnd then vtable = self.pos - self:Get(N.SOffsetT, self.pos)
return self:Get(N.VOffsetT, vtable + vtableOffset) self.vtable = vtable
self.vtableEnd = self:Get(N.VOffsetT, vtable)
end
if vtableOffset < self.vtableEnd then
return unPackVoffset(self.bytes, vtable + vtableOffset)
end end
return 0 return 0
end end
function mt:Indirect(off) function mt:Indirect(off)
N.UOffsetT:EnforceNumber(off) enforceOffset(off)
return off + N.UOffsetT:Unpack(self.bytes, off) return off + unPackUoffset(self.bytes, off)
end end
function mt:String(off) function mt:String(off)
N.UOffsetT:EnforceNumber(off) enforceOffset(off)
off = off + N.UOffsetT:Unpack(self.bytes, off) off = off + unPackUoffset(self.bytes, off)
local start = off + N.UOffsetT.bytewidth local start = off + 4
local length = N.UOffsetT:Unpack(self.bytes, off) local length = unPackUoffset(self.bytes, off)
return self.bytes:Slice(start, start+length) return self.bytes:Slice(start, start+length)
end end
function mt:VectorLen(off) function mt:VectorLen(off)
N.UOffsetT:EnforceNumber(off) enforceOffset(off)
off = off + self.pos off = off + self.pos
off = off + N.UOffsetT:Unpack(self.bytes, off) off = off + unPackUoffset(self.bytes, off)
return N.UOffsetT:Unpack(self.bytes, off) return unPackUoffset(self.bytes, off)
end end
function mt:Vector(off) function mt:Vector(off)
N.UOffsetT:EnforceNumber(off) enforceOffset(off)
off = off + self.pos off = off + self.pos
local x = off + self:Get(N.UOffsetT, off) return off + self:Get(N.UOffsetT, off) + 4
x = x + N.UOffsetT.bytewidth
return x
end end
function mt:Union(t2, off) function mt:Union(t2, off)
assert(getmetatable(t2) == mt_name) assert(getmetatable(t2) == mt_name)
N.UOffsetT:EnforceNumber(off) enforceOffset(off)
off = off + self.pos off = off + self.pos
t2.pos = off + self:Get(N.UOffsetT, off) t2.pos = off + self:Get(N.UOffsetT, off)
t2.bytes = self.bytes t2.bytes = self.bytes
end end
function mt:Get(flags, off) function mt:Get(flags, off)
N.UOffsetT:EnforceNumber(off) enforceOffset(off)
return flags:Unpack(self.bytes, off) return flags:Unpack(self.bytes, off)
end end
...@@ -85,8 +99,7 @@ function mt:GetSlot(slot, d, validatorFlags) ...@@ -85,8 +99,7 @@ function mt:GetSlot(slot, d, validatorFlags)
end end
function mt:GetVOffsetTSlot(slot, d) function mt:GetVOffsetTSlot(slot, d)
N.VOffsetT:EnforceNumber(slot) N.VOffsetT:EnforceNumbers(slot, d)
N.VOffsetT:EnforceNumber(d)
local off = self:Offset(slot) local off = self:Offset(slot)
if off == 0 then if off == 0 then
return d return d
......
...@@ -82,6 +82,7 @@ local function checkReadBuffer(buf, offset, sizePrefix) ...@@ -82,6 +82,7 @@ local function checkReadBuffer(buf, offset, sizePrefix)
end end
local function generateMonster(sizePrefix, b) local function generateMonster(sizePrefix, b)
if b then b:Clear() end
b = b or flatbuffers.Builder(0) b = b or flatbuffers.Builder(0)
local str = b:CreateString("MyMonster") local str = b:CreateString("MyMonster")
local test1 = b:CreateString("test1") local test1 = b:CreateString("test1")
...@@ -208,26 +209,16 @@ local function testCanonicalData() ...@@ -208,26 +209,16 @@ local function testCanonicalData()
checkReadBuffer(wireData) checkReadBuffer(wireData)
end end
local function benchmarkMakeMonster(count) local function benchmarkMakeMonster(count, reuseBuilder)
local length = #(generateMonster()) local fbb = reuseBuilder and flatbuffers.Builder(0)
local length = #(generateMonster(false, fbb))
--require("flatbuffers.profiler")
--profiler = newProfiler("call")
--profiler:start()
local s = os.clock() local s = os.clock()
for i=1,count do for i=1,count do
generateMonster() generateMonster(false, fbb)
end end
local e = os.clock() local e = os.clock()
--profiler:stop()
--local outfile = io.open( "profile.txt", "w+" )
--profiler:report( outfile, true)
--outfile:close()
local dur = (e - s) local dur = (e - s)
local rate = count / (dur * 1000) local rate = count / (dur * 1000)
local data = (length * count) / (1024 * 1024) local data = (length * count) / (1024 * 1024)
...@@ -279,6 +270,7 @@ local tests = ...@@ -279,6 +270,7 @@ local tests =
{100}, {100},
{1000}, {1000},
{10000}, {10000},
{10000, true}
} }
}, },
{ {
...@@ -290,7 +282,7 @@ local tests = ...@@ -290,7 +282,7 @@ local tests =
{10000}, {10000},
-- uncomment following to run 1 million to compare. -- uncomment following to run 1 million to compare.
-- Took ~141 seconds on my machine -- Took ~141 seconds on my machine
--{1000000}, --{1000000},
} }
}, },
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment