#!/usr/bin/python

import os
import sys
import re
import time
import fcntl
import struct
from glob import glob
from subprocess import Popen, PIPE

# Upper bound on re_read() retries; it sleeps one second per "busy" retry.
MAX_SEC_TRY = 30
# Linux block-device ioctl request codes (ioctl type 0x12), spelled out as
# their bit encodings:
#   BLKGETSIZE64 = _IOR(0x12, 114, <8-byte argument>) == 0x80081272
#   BLKDISCARD   = _IO(0x12, 119)                      == 0x1277
IOCTL_BLKGETSIZE64 = (2 << 30) | (8 << 16) | (0x12 << 8) | 114
IOCTL_BLKDISCARD = (0x12 << 8) | 119

def kmgs_write(data):
    """Append ``data`` plus a newline to the kernel log and flush at once.

    Writes through the module-level ``kmsg`` handle, which is opened on
    /dev/kmsg during start-up. Flushing after every message makes each one
    reach the kernel as its own write instead of sitting in a Python buffer.
    """
    kmsg.write("%s\n" % data)
    kmsg.flush()

# Correctly spelled alias. Existing callers use the historical "kmgs_write"
# spelling; code written against "kmsg_write" must resolve too, not NameError.
kmsg_write = kmgs_write

def re_read(drive, num_try):
    """Ask the kernel to flush buffers and re-read ``drive``'s partition table.

    Runs ``blockdev --flushbufs --rereadpt``. While blockdev fails with
    "BLKRRPART: Device or resource busy", the call is retried once per second.
    Attempts are counted upward from ``num_try``; once the count exceeds
    MAX_SEC_TRY, a failure is logged and the function gives up. Any other
    blockdev failure is logged and not retried.
    """
    attempt = num_try
    while attempt <= MAX_SEC_TRY:
        proc = Popen(["/usr/sbin/blockdev", "--flushbufs", "--rereadpt", drive],
                     stdout=PIPE, stderr=PIPE)
        out, err = proc.communicate()
        if proc.returncode == 0:
            kmgs_write("zerombr: %s kernel forced to re-read partition table" % drive)
            return
        kmgs_write("Utility blockdev failed: exited with code %d" % proc.returncode)
        kmgs_write("zerombr: blockdev stdout: %s" % out)
        kmgs_write("zerombr: blockdev stderr: %s" % err)
        if "BLKRRPART: Device or resource busy" not in err:
            # Not the transient "busy" case: retrying would not help.
            return
        # Drive was busy: wait a second and try again.
        time.sleep(1)
        attempt += 1
    kmgs_write("Failed to re-read partition table for %s in %i sec" % (drive, MAX_SEC_TRY))

def sync_partition(fd, drive, timeout=30):
    """Flush pending writes on ``fd`` to the device, retrying once per second.

    Args:
        fd: OS-level file descriptor open on the device.
        drive: device path, used only in log messages.
        timeout: number of one-second attempts before giving up (default 30).

    Returns:
        0 as soon as ``os.fsync`` succeeds. 1 if every attempt failed; in that
        case ``fd`` has already been closed, so the caller just skips the
        device.
    """
    for _ in range(timeout):
        try:
            os.fsync(fd)
            return 0
        except OSError:
            # Previously this called the undefined name kmsg_write, so the
            # first failed fsync crashed the script with NameError.
            kmgs_write("zerombr: %s sync failed..." % drive)
        time.sleep(1)
    kmgs_write("zerombr: %s failed to sync in %i seconds..." % (drive, timeout))
    os.close(fd)
    return 1

def check_zeros(fd, to_zero=17 << 10):
    """Return True if the next ``to_zero`` bytes readable from ``fd`` are all NUL.

    Reads from the current file offset and loops over short reads until
    ``to_zero`` bytes have been read or EOF is reached. An empty read (for
    example a zero-sized device) counts as zeroed.

    Args:
        fd: OS-level file descriptor opened for reading.
        to_zero: number of bytes to inspect. Defaults to 17 KiB, the largest
            span this script zeroes.
    """
    chunks = []
    remaining = to_zero
    while remaining > 0:
        chunk = os.read(fd, remaining)
        if not chunk:
            break  # EOF: judge whatever was read
        chunks.append(chunk)
        remaining -= len(chunk)
    data = b"".join(chunks)
    # Compare against a NUL run of equal length. This is correct for both
    # Python 2 str and Python 3 bytes; the old per-item "== '\0'" test never
    # matched on Python 3, where iterating bytes yields ints.
    return data == b"\0" * len(data)

def is_zeroed(fd, to_zero):
    """Return True if both the first and the last ``to_zero`` bytes are NUL.

    The head is checked from the current offset of ``fd`` (the start when
    freshly opened). The tail is checked after seeking ``to_zero`` bytes back
    from the end. ``to_zero`` must not exceed the device size, or the seek
    fails.
    """
    if not check_zeros(fd, to_zero):
        return False
    os.lseek(fd, -to_zero, os.SEEK_END)
    return check_zeros(fd, to_zero)

def ioctl(fd, req, fmt, *args):
    """Issue ioctl ``req`` on ``fd`` with a struct-packed argument buffer.

    ``args`` are packed with the ``struct`` format ``fmt``; with no args a
    single 0 is packed. The buffer is handed to ``fcntl.ioctl``, and the first
    field of the buffer it returns is unpacked and returned.
    """
    values = args if args else (0,)
    request_buf = struct.pack(fmt, *values)
    reply_buf = fcntl.ioctl(fd, req, request_buf)
    fields = struct.unpack(fmt, reply_buf)
    return fields[0]

def read_sysfspath(dev):
    """Return the symlink target of ``dev``, or "" if it cannot be read.

    A path that is missing or is not a symlink yields "" instead of raising.
    Only OSError (what ``os.readlink`` raises) is caught, so unrelated errors
    such as KeyboardInterrupt are no longer swallowed.
    """
    try:
        return os.readlink(dev)
    except OSError:
        return ""

# Kernel log handle that every diagnostic is written to via kmgs_write().
kmsg = open("/dev/kmsg", "a+")

# Single-instance guard: take an exclusive, non-blocking lock on a lock file.
# If another zerombr already holds it, log that and exit quietly.
lockfd = os.open("/tmp/zerombr.lock", os.O_CREAT | os.O_WRONLY)
try:
    fcntl.lockf(lockfd, fcntl.LOCK_NB | fcntl.LOCK_EX)
except IOError:
    kmgs_write("zerombr: another zerombr started")
    sys.exit(0)

# Kernel command-line tokens. "zerombr", "zerombr=..." and "allow_usb_hdd"
# are looked up in this list. If it cannot be read, exit without touching
# any device.
try:
    with open("/proc/cmdline", "r") as cmdline_file:
        cmdline = cmdline_file.read().split()
except (IOError, OSError):
    kmgs_write("zerombr: Failed to open /proc/cmdline")
    sys.exit(0)

# Devices selected for zeroing, taken from the first matching kernel
# command-line token:
#   zerombr              -> ['all']  (every eligible device)
#   zerombr=sda,nvme0n1  -> ['/dev/sda', '/dev/nvme0n1']
# With neither token present the list stays empty.
zerodrives = []

for cmd in cmdline:
    if cmd == 'zerombr':
        zerodrives = ['all']
        break
    if cmd.startswith("zerombr="):
        # Strip the prefix exactly once. Skip empty names (e.g. from a
        # trailing comma), which would otherwise become a bare "/dev/".
        zerodrives = ["/dev/%s" % name
                      for name in cmd[len("zerombr="):].split(",") if name]
        break

# /proc/partitions rows naming a whole drive (hdX/sdX/nvmeXnY) ...
drive_re = re.compile(r'^ +[0-9]+ +[0-9]+ +[0-9]+ +(nvme[0-9]+n[0-9]+|[hs]d[a-z]+)$')
# ... and rows naming a partition of one (sdX1, nvmeXnYpZ, ...).
part_re = re.compile(r'^ +[0-9]+ +[0-9]+ +[0-9]+ +(nvme[0-9]+n[0-9]+p[0-9]+|[hs]d[a-z]+[0-9]+)$')

def get_partitions():
    """Return the /proc/partitions rows for partitions and whole drives.

    Rows are returned verbatim, including the trailing newline. All rows
    matching ``part_re`` come first, followed by all rows matching
    ``drive_re``, so partitions are visited before the drive holding them.
    """
    with open("/proc/partitions", 'r') as table:
        rows = table.readlines()
    parts = [row for row in rows if part_re.match(row)]
    disks = [row for row in rows if drive_re.match(row)]
    return parts + disks

partitions = get_partitions()

# `partitions` lists partitions before whole drives, so a drive's partitions
# are handled before the drive itself gets its partition table re-read.
# The list may grow while it is iterated (see the refresh at the end of the
# loop body). Python's list iterator also visits appended items, so devices
# that show up mid-run are processed too.
for line in partitions:
    lsp = line.split()
    # /proc/partitions columns: major, minor, #blocks (1 KiB), name.
    drive = os.path.join("/dev", lsp[3])
    is_drive = drive_re.match(line) is not None

    # Skip devices not selected on the kernel command line.
    if 'all' not in zerodrives and drive not in zerodrives:
        continue

    # Skip device nodes that do not exist.
    if not os.path.exists(drive):
        continue

    # Detect parent sysfspath: a partition appears as a subdirectory of its
    # parent under /sys/block/<parent>/<name>.
    sysfspath = ""
    for d in glob("/sys/block/*"):
        if os.path.isdir(d + "/" + lsp[3]):
            sysfspath = read_sysfspath(d)
            break

    # No parent found: use the device's own /sys/block entry.
    if not sysfspath:
        sysfspath = read_sysfspath("/sys/block/" + lsp[3])

    # Skip USB devices unless allow_usb_hdd is on the kernel command line.
    # A plain substring test matches the same paths as the old ".*\/usb.*"
    # regex, without relying on an invalid "\/" string escape.
    if "allow_usb_hdd" not in cmdline and "/usb" in sysfspath:
        kmgs_write("zerombr: skipped usb drive %s" % drive)
        continue

    # Make sure no drive contains old metadata by filling the first and the
    # last 17 KiB with zeros. Clamping to the device size keeps the seek
    # from the end inside the device.
    to_zero = min(int(lsp[2]), 17) << 10
    kmgs_write("zerombr: %s" % drive)
    try:
        fd = os.open(drive, os.O_RDONLY)
    except OSError:
        kmgs_write("%s open for read failed" % drive)
        continue

    # Skip already zeroed
    if is_zeroed(fd, to_zero):
        kmgs_write("zerombr: %s already zeroed, skipping..." % drive)
        os.close(fd)
        continue

    # Reopen for writing.
    os.close(fd)
    try:
        fd = os.open(drive, os.O_WRONLY)
    except OSError:
        kmgs_write("%s open for write failed" % drive)
        continue

    # bytes literal: a str on Python 2, and the bytes os.write needs on Python 3.
    zeros = b'\0' * to_zero

    os.write(fd, zeros)
    kmgs_write("zerombr: %s first %i bytes ..." % (drive, to_zero))
    # Only whole drives are fsync'ed; sync_partition() closes fd on failure.
    if is_drive and sync_partition(fd, drive) != 0:
        continue
    kmgs_write(" ... zeroed")

    os.lseek(fd, -to_zero, os.SEEK_END)
    kmgs_write("zerombr: %s preparing to zeroend..." % drive)
    os.write(fd, zeros)
    kmgs_write("zerombr: %s last %i bytes ..." % (drive, to_zero))
    if is_drive and sync_partition(fd, drive) != 0:
        continue
    kmgs_write(" ... zeroed")
    kmgs_write("zerombr: %s totally zeroed" % drive)

    if is_drive:
        kmgs_write("zerombr: %s preparing to blkdiscard..." % drive)
        try:
            # Both ioctls exchange 64-bit values: the device size, and the
            # [start, length] discard range. Use the fixed-width 'Q' format
            # instead of native 'L', which is 8 bytes only on LP64 platforms.
            size = ioctl(fd, IOCTL_BLKGETSIZE64, 'Q')
            ioctl(fd, IOCTL_BLKDISCARD, 'QQ', 0, size)
        except (IOError, OSError):
            kmgs_write("zerombr: %s blkdiscard failed" % drive)
        else:
            kmgs_write(" ... blkdiscarded")
        re_read(drive, 0)
    os.close(fd)

    # Check for missed: queue devices that appeared in /proc/partitions since
    # the initial scan (e.g. after a partition-table re-read). Read the table
    # once and keep its order so processing stays deterministic.
    current = get_partitions()
    partitions += [p for p in current if p not in partitions]

# Release the single-instance lock. This is best effort: lockf() failures
# are ignored because POSIX record locks are dropped when the process exits.
# Only the I/O errors lockf() raises are caught; anything else still surfaces.
try:
    fcntl.lockf(lockfd, fcntl.LOCK_UN)
except (IOError, OSError):
    pass

kmsg.close()
sys.exit(0)
