Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 179 additions & 40 deletions src/bitn/bit64.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,31 @@ local bit64 = {}

local bit32 = require("bitn.bit32")

-- Private metatable for Int64 type identification
local Int64Meta = { __name = "Int64" }

-- Type definitions
--- @alias Int64HighLow [integer, integer] Array with [1]=high 32 bits, [2]=low 32 bits

--------------------------------------------------------------------------------
-- Constructor and type checking
--------------------------------------------------------------------------------

--- Create a new Int64 value with metatable marker.
--- @param high? integer Upper 32 bits (default: 0)
--- @param low? integer Lower 32 bits (default: 0)
--- @return Int64HighLow value Int64 value with metatable marker
function bit64.new(high, low)
return setmetatable({ high or 0, low or 0 }, Int64Meta)
end

--- Check if a value is an Int64 (created by bit64 functions).
--- @param value any Value to check
--- @return boolean isInt64 True if value is an Int64
function bit64.isInt64(value)
return type(value) == "table" and getmetatable(value) == Int64Meta
end

--------------------------------------------------------------------------------
-- Bitwise operations
--------------------------------------------------------------------------------
Expand All @@ -21,42 +43,30 @@ local bit32 = require("bitn.bit32")
--- @param b Int64HighLow Second operand {high, low}
--- @return Int64HighLow result {high, low} AND result
function bit64.band(a, b)
return {
bit32.band(a[1], b[1]),
bit32.band(a[2], b[2]),
}
return bit64.new(bit32.band(a[1], b[1]), bit32.band(a[2], b[2]))
end

--- Bitwise OR operation.
--- @param a Int64HighLow First operand {high, low}
--- @param b Int64HighLow Second operand {high, low}
--- @return Int64HighLow result {high, low} OR result
function bit64.bor(a, b)
return {
bit32.bor(a[1], b[1]),
bit32.bor(a[2], b[2]),
}
return bit64.new(bit32.bor(a[1], b[1]), bit32.bor(a[2], b[2]))
end

--- Bitwise XOR operation.
--- @param a Int64HighLow First operand {high, low}
--- @param b Int64HighLow Second operand {high, low}
--- @return Int64HighLow result {high, low} XOR result
function bit64.bxor(a, b)
return {
bit32.bxor(a[1], b[1]),
bit32.bxor(a[2], b[2]),
}
return bit64.new(bit32.bxor(a[1], b[1]), bit32.bxor(a[2], b[2]))
end

--- Bitwise NOT operation.
--- @param a Int64HighLow Operand {high, low}
--- @return Int64HighLow result {high, low} NOT result
function bit64.bnot(a)
return {
bit32.bnot(a[1]),
bit32.bnot(a[2]),
}
return bit64.new(bit32.bnot(a[1]), bit32.bnot(a[2]))
end

--------------------------------------------------------------------------------
Expand All @@ -69,17 +79,17 @@ end
--- @return Int64HighLow result {high, low} shifted value
function bit64.lshift(x, n)
if n == 0 then
return { x[1], x[2] }
return bit64.new(x[1], x[2])
elseif n >= 64 then
return { 0, 0 }
return bit64.new(0, 0)
elseif n >= 32 then
-- Shift by 32 or more: low becomes 0, high gets bits from low
return { bit32.lshift(x[2], n - 32), 0 }
return bit64.new(bit32.lshift(x[2], n - 32), 0)
else
-- Shift by less than 32
local new_high = bit32.bor(bit32.lshift(x[1], n), bit32.rshift(x[2], 32 - n))
local new_low = bit32.lshift(x[2], n)
return { new_high, new_low }
return bit64.new(new_high, new_low)
end
end

Expand All @@ -89,17 +99,17 @@ end
--- @return Int64HighLow result {high, low} shifted value
function bit64.rshift(x, n)
if n == 0 then
return { x[1], x[2] }
return bit64.new(x[1], x[2])
elseif n >= 64 then
return { 0, 0 }
return bit64.new(0, 0)
elseif n >= 32 then
-- Shift by 32 or more: high becomes 0, low gets bits from high
return { 0, bit32.rshift(x[1], n - 32) }
return bit64.new(0, bit32.rshift(x[1], n - 32))
else
-- Shift by less than 32
local new_low = bit32.bor(bit32.rshift(x[2], n), bit32.lshift(x[1], 32 - n))
local new_high = bit32.rshift(x[1], n)
return { new_high, new_low }
return bit64.new(new_high, new_low)
end
end

Expand All @@ -109,7 +119,7 @@ end
--- @return Int64HighLow result {high, low} shifted value
function bit64.arshift(x, n)
if n == 0 then
return { x[1], x[2] }
return bit64.new(x[1], x[2])
end

-- Check sign bit (bit 31 of high word)
Expand All @@ -118,20 +128,20 @@ function bit64.arshift(x, n)
if n >= 64 then
-- All bits shift out, result is all 1s if negative, all 0s if positive
if is_negative then
return { 0xFFFFFFFF, 0xFFFFFFFF }
return bit64.new(0xFFFFFFFF, 0xFFFFFFFF)
else
return { 0, 0 }
return bit64.new(0, 0)
end
elseif n >= 32 then
-- High word shifts into low, high fills with sign
local new_low = bit32.arshift(x[1], n - 32)
local new_high = is_negative and 0xFFFFFFFF or 0
return { new_high, new_low }
return bit64.new(new_high, new_low)
else
-- Shift by less than 32
local new_low = bit32.bor(bit32.rshift(x[2], n), bit32.lshift(x[1], 32 - n))
local new_high = bit32.arshift(x[1], n)
return { new_high, new_low }
return bit64.new(new_high, new_low)
end
end

Expand All @@ -146,25 +156,25 @@ end
function bit64.rol(x, n)
n = n % 64
if n == 0 then
return { x[1], x[2] }
return bit64.new(x[1], x[2])
end

local high, low = x[1], x[2]

if n == 32 then
-- Special case: swap high and low
return { low, high }
return bit64.new(low, high)
elseif n < 32 then
-- Rotate within 32-bit boundaries
local new_high = bit32.bor(bit32.lshift(high, n), bit32.rshift(low, 32 - n))
local new_low = bit32.bor(bit32.lshift(low, n), bit32.rshift(high, 32 - n))
return { new_high, new_low }
return bit64.new(new_high, new_low)
else
-- n > 32: rotate by (n - 32) after swapping
n = n - 32
local new_high = bit32.bor(bit32.lshift(low, n), bit32.rshift(high, 32 - n))
local new_low = bit32.bor(bit32.lshift(high, n), bit32.rshift(low, 32 - n))
return { new_high, new_low }
return bit64.new(new_high, new_low)
end
end

Expand All @@ -175,25 +185,25 @@ end
function bit64.ror(x, n)
n = n % 64
if n == 0 then
return { x[1], x[2] }
return bit64.new(x[1], x[2])
end

local high, low = x[1], x[2]

if n == 32 then
-- Special case: swap high and low
return { low, high }
return bit64.new(low, high)
elseif n < 32 then
-- Rotate within 32-bit boundaries
local new_low = bit32.bor(bit32.rshift(low, n), bit32.lshift(high, 32 - n))
local new_high = bit32.bor(bit32.rshift(high, n), bit32.lshift(low, 32 - n))
return { new_high, new_low }
return bit64.new(new_high, new_low)
else
-- n > 32: rotate by (n - 32) after swapping
n = n - 32
local new_low = bit32.bor(bit32.rshift(high, n), bit32.lshift(low, 32 - n))
local new_high = bit32.bor(bit32.rshift(low, n), bit32.lshift(high, 32 - n))
return { new_high, new_low }
return bit64.new(new_high, new_low)
end
end

Expand All @@ -218,7 +228,7 @@ function bit64.add(a, b)
-- Keep high within 32 bits
high = high % 0x100000000

return { high, low }
return bit64.new(high, low)
end

--------------------------------------------------------------------------------
Expand Down Expand Up @@ -248,7 +258,7 @@ function bit64.be_bytes_to_u64(str, offset)
assert(#str >= offset + 7, "Insufficient bytes for u64")
local high = bit32.be_bytes_to_u32(str, offset)
local low = bit32.be_bytes_to_u32(str, offset + 4)
return { high, low }
return bit64.new(high, low)
end

--- Convert 8 bytes to 64-bit value (little-endian).
Expand All @@ -260,7 +270,7 @@ function bit64.le_bytes_to_u64(str, offset)
assert(#str >= offset + 7, "Insufficient bytes for u64")
local low = bit32.le_bytes_to_u32(str, offset)
local high = bit32.le_bytes_to_u32(str, offset + 4)
return { high, low }
return bit64.new(high, low)
end

--------------------------------------------------------------------------------
Expand Down Expand Up @@ -614,6 +624,135 @@ function bit64.selftest()
end
end

-- Int64 type identification tests
print("\nRunning Int64 type identification tests...")

-- Test bit64.new() creates Int64 values
total = total + 1
local new_val = bit64.new(0x12345678, 0x9ABCDEF0)
if bit64.isInt64(new_val) and new_val[1] == 0x12345678 and new_val[2] == 0x9ABCDEF0 then
print(" PASS: new() creates Int64 with correct values")
passed = passed + 1
else
print(" FAIL: new() creates Int64 with correct values")
end

-- Test bit64.new() with defaults
total = total + 1
local zero_val = bit64.new()
if bit64.isInt64(zero_val) and zero_val[1] == 0 and zero_val[2] == 0 then
print(" PASS: new() with no args creates {0, 0}")
passed = passed + 1
else
print(" FAIL: new() with no args creates {0, 0}")
end

-- Test isInt64() returns false for regular tables
total = total + 1
local plain_table = { 0x12345678, 0x9ABCDEF0 }
if not bit64.isInt64(plain_table) then
print(" PASS: isInt64() returns false for plain table")
passed = passed + 1
else
print(" FAIL: isInt64() returns false for plain table")
end

-- Test isInt64() returns false for non-tables
total = total + 1
if not bit64.isInt64(123) and not bit64.isInt64("string") and not bit64.isInt64(nil) then
print(" PASS: isInt64() returns false for non-tables")
passed = passed + 1
else
print(" FAIL: isInt64() returns false for non-tables")
end

-- Test all operations return Int64 values
local ops_returning_int64 = {
{
name = "band",
fn = function()
return bit64.band({ 1, 2 }, { 3, 4 })
end,
},
{
name = "bor",
fn = function()
return bit64.bor({ 1, 2 }, { 3, 4 })
end,
},
{
name = "bxor",
fn = function()
return bit64.bxor({ 1, 2 }, { 3, 4 })
end,
},
{
name = "bnot",
fn = function()
return bit64.bnot({ 1, 2 })
end,
},
{
name = "lshift",
fn = function()
return bit64.lshift({ 1, 2 }, 1)
end,
},
{
name = "rshift",
fn = function()
return bit64.rshift({ 1, 2 }, 1)
end,
},
{
name = "arshift",
fn = function()
return bit64.arshift({ 1, 2 }, 1)
end,
},
{
name = "rol",
fn = function()
return bit64.rol({ 1, 2 }, 1)
end,
},
{
name = "ror",
fn = function()
return bit64.ror({ 1, 2 }, 1)
end,
},
{
name = "add",
fn = function()
return bit64.add({ 1, 2 }, { 3, 4 })
end,
},
{
name = "be_bytes_to_u64",
fn = function()
return bit64.be_bytes_to_u64("\0\0\0\1\0\0\0\2")
end,
},
{
name = "le_bytes_to_u64",
fn = function()
return bit64.le_bytes_to_u64("\2\0\0\0\1\0\0\0")
end,
},
}

for _, op in ipairs(ops_returning_int64) do
total = total + 1
local result = op.fn()
if bit64.isInt64(result) then
print(" PASS: " .. op.name .. "() returns Int64")
passed = passed + 1
else
print(" FAIL: " .. op.name .. "() returns Int64")
end
end

print(string.format("\n64-bit operations: %d/%d tests passed\n", passed, total))
return passed == total
end
Expand Down
Loading