-- Copyright (c) 2015  Phil Leblanc  -- see LICENSE file
------------------------------------------------------------
--[[

Poly1305 message authentication code (MAC), created by Dan Bernstein
...
]]

-----------------------------------------------------------
-- poly1305

local sunp = string.unpack
local bit32 = bit32 -- alias

local function poly_init(k)
  -- Create a fresh Poly1305 state from a one-time key.
  -- k: 32-byte key string. Bytes 1..16 give the multiplier r,
  --    bytes 17..32 give the final addend s (called "pad" here).
  -- Returns the state table used by poly_update / poly_finish:
  --   r     : clamped r as five 26-bit limbs, least significant first
  --   h     : accumulator as five 26-bit limbs, initially 0
  --   pad   : s as four little-endian 32-bit words
  --   buffer, leftover : initialized but not read by any function here
  --   final : becomes true once the last (padded) block has been absorbed
  local st = {
    -- Each limb is read from an overlapping little-endian 32-bit word,
    -- then shifted and masked; the masks also clear the bits that the
    -- RFC's clamping of r requires to be zero.
    r = {
      bit32.band(sunp('<I4', k, 1), 0x3ffffff),                   -- r0
      bit32.band(bit32.rshift(sunp('<I4', k, 4), 2), 0x3ffff03),  -- r1
      bit32.band(bit32.rshift(sunp('<I4', k, 7), 4), 0x3ffc0ff),  -- r2
      bit32.band(bit32.rshift(sunp('<I4', k, 10), 6), 0x3f03fff), -- r3
      bit32.band(bit32.rshift(sunp('<I4', k, 13), 8), 0x00fffff), -- r4
    },
    h = { 0, 0, 0, 0, 0 },
    -- 's' in rfc. string.unpack also returns the next read position; the
    -- last call is parenthesized so that extra value is not appended to
    -- the table as pad[5] (a call in the last constructor slot expands
    -- to all of its results).
    pad = {
      sunp('<I4', k, 17),
      sunp('<I4', k, 21),
      sunp('<I4', k, 25),
      (sunp('<I4', k, 29)),
    },
    buffer = "",
    leftover = 0,
    final = false,
  } --st
  return st
end --poly_init()

local function poly_blocks(st, m)
  -- Absorb every complete 16-byte block of `m` into the accumulator:
  -- for each block, h = (h + block) * r, partially reduced mod 2^130 - 5.
  -- st: state table from poly_init. When st.final is false, the 2^128 bit
  --     (the rfc's appended 0x01 byte) is added to each block here; when
  --     true, the caller has already appended the 0x01 byte itself.
  -- m : message string. Trailing bytes past the last full block are not
  --     consumed.
  -- Side effects: writes st.h, sets st.bytes (number of unconsumed trailing
  -- bytes, < 16) and st.midx (index in m of the first such byte).
  -- Returns st.
  --
  -- NOTE: the products d0..d4 are far larger than 32 bits (~2^55), so their
  -- carries use integer floor division on Lua 5.3+ 64-bit integers.
  -- bit32.rshift cannot be used for them: bit32 reduces its operand mod 2^32
  -- before shifting, which drops the high bits of the carry.
  local bytes = #m
  local midx = 1
  local hibit = st.final and 0 or 0x01000000 -- 1 << 24
  local r0, r1, r2, r3, r4 = st.r[1], st.r[2], st.r[3], st.r[4], st.r[5]
  -- s_i = 5 * r_i: folds the 2^130 wrap-around back in, since 2^130 == 5 (mod p)
  local s1, s2, s3, s4 = r1 * 5, r2 * 5, r3 * 5, r4 * 5
  local h0, h1, h2, h3, h4 = st.h[1], st.h[2], st.h[3], st.h[4], st.h[5]
  local d0, d1, d2, d3, d4, c
  --
  while bytes >= 16 do -- 16 = poly1305_block_size
    -- h += m[i]  (in rfc:  a += n with 0x01 byte)
    h0    = h0 + bit32.band(sunp('<I4', m, midx), 0x3ffffff)
    h1    = h1 + bit32.band(bit32.rshift(sunp('<I4', m, midx + 3), 2), 0x3ffffff)
    h2    = h2 + bit32.band(bit32.rshift(sunp('<I4', m, midx + 6), 4), 0x3ffffff)
    h3    = h3 + bit32.band(bit32.rshift(sunp('<I4', m, midx + 9), 6), 0x3ffffff)
    h4    = h4 + bit32.bor(bit32.rshift(sunp('<I4', m, midx + 12), 8), hibit) -- 0x01 byte
    --
    -- h *= r % p (partial)
    d0    = h0 * r0 + h1 * s4 + h2 * s3 + h3 * s2 + h4 * s1
    d1    = h0 * r1 + h1 * r0 + h2 * s4 + h3 * s3 + h4 * s2
    d2    = h0 * r2 + h1 * r1 + h2 * r0 + h3 * s4 + h4 * s3
    d3    = h0 * r3 + h1 * r2 + h2 * r1 + h3 * r0 + h4 * s4
    d4    = h0 * r4 + h1 * r3 + h2 * r2 + h3 * r1 + h4 * r0
    --
    -- carry: keep the low 26 bits in each limb, move the rest up
    -- (0x4000000 = 2^26; the 0xffffffff mask mirrors a u32 carry)
    c     = bit32.band(d0 // 0x4000000, 0xffffffff); h0 = bit32.band(d0, 0x3ffffff)
    d1    = d1 + c; c = bit32.band(d1 // 0x4000000, 0xffffffff); h1 = bit32.band(d1, 0x3ffffff)
    d2    = d2 + c; c = bit32.band(d2 // 0x4000000, 0xffffffff); h2 = bit32.band(d2, 0x3ffffff)
    d3    = d3 + c; c = bit32.band(d3 // 0x4000000, 0xffffffff); h3 = bit32.band(d3, 0x3ffffff)
    d4    = d4 + c; c = bit32.band(d4 // 0x4000000, 0xffffffff); h4 = bit32.band(d4, 0x3ffffff)
    -- carry out of h4 counts multiples of 2^130, i.e. 5 each (mod p)
    h0    = h0 + (c * 5); c = bit32.rshift(h0, 26); h0 = bit32.band(h0, 0x3ffffff)
    h1    = h1 + c
    --
    midx  = midx + 16 -- 16 = poly1305_block_size
    bytes = bytes - 16
  end --while
  st.h[1] = h0
  st.h[2] = h1
  st.h[3] = h2
  st.h[4] = h3
  st.h[5] = h4
  st.bytes = bytes -- remaining bytes. must be < 16 here
  st.midx = midx   -- index of first remaining bytes
  return st
end --poly_blocks()

local function poly_update(st, m)
  -- Feed the message `m` into the state in one call.
  -- st: state table from poly_init
  -- m : message string
  -- Complete 16-byte blocks are absorbed as they are. A trailing partial
  -- block is padded with a 0x01 byte and zeros, then absorbed with
  -- st.final = true so that poly_blocks does not also add the 2^128 bit.
  -- The padded block is absorbed immediately; nothing is buffered. More
  -- data therefore cannot follow it, and a further call on the same state
  -- raises an error rather than silently producing a wrong tag.
  -- Returns st.
  assert(not st.final,
    "poly_update: state already absorbed its final partial block")
  st.bytes, st.midx = #m, 1
  -- process full blocks if any
  if st.bytes >= 16 then
    poly_blocks(st, m)
  end
  -- When #m is a multiple of 16 nothing remains: every block has already
  -- had the 2^128 bit added, so no extra padding block is needed.
  if st.bytes > 0 then
    local buffer = string.sub(m, st.midx)
        .. '\x01' .. string.rep('\0', 16 - st.bytes - 1)
    assert(#buffer == 16)
    st.final = true -- this is the last block
    poly_blocks(st, buffer)
  end
  --
  return st
end --poly_update

local function poly_finish(st)
  -- Produce the 16-byte Poly1305 tag from an accumulated state.
  -- st: state after poly_update (reads st.h limbs and st.pad words only;
  --     st itself is not modified).
  -- Steps: fully carry h, reduce it mod p = 2^130 - 5 with a branch-free
  -- select, repack it into four 32-bit words, then add pad mod 2^128.
  -- Returns the tag as a 16-byte little-endian string.
  local c, mask, carry
  local f
  local h0, h1, h2, h3, h4 = st.h[1], st.h[2], st.h[3], st.h[4], st.h[5]
  --
  -- fully carry h
  c = bit32.rshift(h1, 26); h1 = bit32.band(h1, 0x3ffffff)
  h2 = h2 + c; c = bit32.rshift(h2, 26); h2 = bit32.band(h2, 0x3ffffff)
  h3 = h3 + c; c = bit32.rshift(h3, 26); h3 = bit32.band(h3, 0x3ffffff)
  h4 = h4 + c; c = bit32.rshift(h4, 26); h4 = bit32.band(h4, 0x3ffffff)
  h0 = h0 + (c * 5); c = bit32.rshift(h0, 26); h0 = bit32.band(h0, 0x3ffffff)
  h1 = h1 + c
  --
  -- compute g = h + -p  (i.e. h + 5 - 2^130)
  local g0 = (h0 + 5); c = bit32.rshift(g0, 26); g0 = bit32.band(g0, 0x3ffffff)
  local g1 = (h1 + c); c = bit32.rshift(g1, 26); g1 = bit32.band(g1, 0x3ffffff)
  local g2 = (h2 + c); c = bit32.rshift(g2, 26); g2 = bit32.band(g2, 0x3ffffff)
  local g3 = (h3 + c); c = bit32.rshift(g3, 26); g3 = bit32.band(g3, 0x3ffffff)
  local g4 = bit32.band(h4 + c - 0x4000000, 0xffffffff) -- (1 << 26)
  --
  -- select h if h < p, or h + -p if h >= p:
  -- the top bit of g4 is set exactly when g went negative (h < p)
  mask = bit32.band(bit32.rshift(g4, 31) - 1, 0xffffffff)
  g0 = bit32.band(g0, mask)
  g1 = bit32.band(g1, mask)
  g2 = bit32.band(g2, mask)
  g3 = bit32.band(g3, mask)
  g4 = bit32.band(g4, mask)
  mask = bit32.band(bit32.bnot(mask), 0xffffffff)
  h0 = bit32.bor(bit32.band(h0, mask), g0)
  h1 = bit32.bor(bit32.band(h1, mask), g1)
  h2 = bit32.bor(bit32.band(h2, mask), g2)
  h3 = bit32.bor(bit32.band(h3, mask), g3)
  h4 = bit32.bor(bit32.band(h4, mask), g4)
  --
  -- h = h % (2^128), repacked from 26-bit limbs into 32-bit words
  h0 = bit32.band(bit32.bor((h0), bit32.lshift(h1, 26)), 0xffffffff)
  h1 = bit32.band(bit32.bor(bit32.rshift(h1, 6), bit32.lshift(h2, 20)), 0xffffffff)
  h2 = bit32.band(bit32.bor(bit32.rshift(h2, 12), bit32.lshift(h3, 14)), 0xffffffff)
  h3 = bit32.band(bit32.bor(bit32.rshift(h3, 18), bit32.lshift(h4, 8)), 0xffffffff)
  --
  -- mac = (h + pad) % (2^128)
  -- Each word sum f is below 2^33, so the carry into the next word is 0 or 1.
  -- It must not be taken with bit32.rshift(f, 32): bit32 reduces f mod 2^32
  -- before shifting, so that expression is always 0 and the carry is lost.
  f = h0 + st.pad[1]; h0 = bit32.band(f, 0xffffffff)
  carry = f > 0xffffffff and 1 or 0
  f = h1 + st.pad[2] + carry; h1 = bit32.band(f, 0xffffffff)
  carry = f > 0xffffffff and 1 or 0
  f = h2 + st.pad[3] + carry; h2 = bit32.band(f, 0xffffffff)
  carry = f > 0xffffffff and 1 or 0
  f = h3 + st.pad[4] + carry; h3 = bit32.band(f, 0xffffffff) -- top carry dropped (mod 2^128)
  --
  local mac = string.pack('<I4I4I4I4', h0, h1, h2, h3)
  -- NOTE(review): the state is not wiped after use.
  return mac
end --poly_finish()

local function poly_auth(m, k)
  -- One-shot Poly1305 tag computation.
  -- m: message string
  -- k: one-time key string; must be exactly 32 bytes (checked by assert)
  -- Returns the 16-byte tag string.
  assert(#k == 32)
  local state = poly_update(poly_init(k), m)
  return (poly_finish(state))
end --poly_auth()

local function poly_verify(m, k, mac)
  -- Recompute the tag of message `m` under key `k` and report whether it
  -- is equal (Lua string equality) to the supplied `mac`.
  -- Returns a boolean.
  return mac == poly_auth(m, k)
end --poly_verify()

------------------------------------------------------------
-- return poly1305 module

return {
  -- one-shot interface
  auth = poly_auth,
  verify = poly_verify,
  -- lower-level init / update / finish interface
  init = poly_init,
  update = poly_update,
  finish = poly_finish,
}
