"""udevmonitor.py - enumerates and monitors devices using (e)udev.

Copyright (C) 2018 by Kozec

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License version 2 as published by
the Free Software Foundation

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
"""
from __future__ import annotations

import ctypes
import errno
import os
from ctypes.util import find_library
from typing import NamedTuple, TypeVar


class Eudev:
	"""Wrapped udev system library.

	Loads libudev through ctypes, owns one udev context (created with
	udev_new, released with udev_unref in __del__) and creates Enumerator /
	Monitor wrappers bound to that context.
	"""

	LIB_NAME = "udev"

	def __init__(self) -> None:
		"""Load libudev and create a udev context.

		Raises:
			ImportError: libudev.so.1 cannot be loaded and find_library finds
				no library named "udev".
			OSError: udev_new() returned NULL.
		"""
		# Set before anything can fail so __del__ never sees a missing attribute.
		self._ctx = None
		try:
			self._lib = ctypes.cdll.LoadLibrary("libudev.so.1")
		except OSError:
			lib_name = find_library(self.LIB_NAME)
			if lib_name is None:
				raise ImportError("No library named udev")
			self._lib = ctypes.CDLL(lib_name)
		Eudev._setup_lib(self._lib)
		self._ctx = self._lib.udev_new()
		if self._ctx is None:
			raise OSError("Failed to initialize udev context")

	@staticmethod
	def _setup_lib(lib: ctypes.CDLL) -> None:
		"""Declare argtypes/restype for every libudev function this module calls.

		ctypes otherwise assumes C int arguments and results, which would
		truncate pointer values on 64-bit platforms.
		"""
		# udev
		lib.udev_new.restype = ctypes.c_void_p
		lib.udev_unref.argtypes = [ ctypes.c_void_p ]
		# enumeration
		lib.udev_enumerate_new.argtypes = [ ctypes.c_void_p ]
		lib.udev_enumerate_new.restype = ctypes.c_void_p
		lib.udev_enumerate_unref.argtypes = [ ctypes.c_void_p ]
		lib.udev_enumerate_scan_devices.argtypes = [ ctypes.c_void_p ]
		lib.udev_enumerate_scan_devices.restype = ctypes.c_int
		lib.udev_enumerate_get_list_entry.argtypes = [ ctypes.c_void_p ]
		lib.udev_enumerate_get_list_entry.restype = ctypes.c_void_p
		lib.udev_list_entry_get_next.argtypes = [ ctypes.c_void_p ]
		lib.udev_list_entry_get_next.restype = ctypes.c_void_p
		lib.udev_list_entry_get_value.argtypes = [ ctypes.c_void_p ]
		lib.udev_list_entry_get_value.restype = ctypes.c_char_p
		lib.udev_list_entry_get_name.argtypes = [ ctypes.c_void_p ]
		lib.udev_list_entry_get_name.restype = ctypes.c_char_p
		# monitoring
		lib.udev_monitor_new_from_netlink.argtypes = [ ctypes.c_void_p, ctypes.c_char_p ]
		lib.udev_monitor_new_from_netlink.restype = ctypes.c_void_p
		lib.udev_monitor_unref.argtypes = [ ctypes.c_void_p ]
		lib.udev_monitor_enable_receiving.argtypes = [ ctypes.c_void_p ]
		lib.udev_monitor_enable_receiving.restype = ctypes.c_int
		lib.udev_monitor_set_receive_buffer_size.argtypes = [ ctypes.c_void_p, ctypes.c_int ]
		lib.udev_monitor_set_receive_buffer_size.restype = ctypes.c_int
		lib.udev_monitor_get_fd.argtypes = [ ctypes.c_void_p ]
		lib.udev_monitor_get_fd.restype = ctypes.c_int
		lib.udev_monitor_receive_device.argtypes = [ ctypes.c_void_p ]
		lib.udev_monitor_receive_device.restype = ctypes.c_void_p
		lib.udev_monitor_filter_update.argtypes = [ ctypes.c_void_p ]
		lib.udev_monitor_filter_update.restype = ctypes.c_int
		lib.udev_monitor_filter_add_match_subsystem_devtype.argtypes = [ ctypes.c_void_p, ctypes.c_char_p, ctypes.c_char_p ]
		lib.udev_monitor_filter_add_match_subsystem_devtype.restype = ctypes.c_int
		lib.udev_monitor_filter_add_match_tag.argtypes = [ ctypes.c_void_p, ctypes.c_char_p ]
		lib.udev_monitor_filter_add_match_tag.restype = ctypes.c_int
		# device
		lib.udev_device_get_action.argtypes = [ ctypes.c_void_p ]
		lib.udev_device_get_action.restype = ctypes.c_char_p
		lib.udev_device_get_devnode.argtypes = [ ctypes.c_void_p ]
		lib.udev_device_get_devnode.restype = ctypes.c_char_p
		lib.udev_device_get_subsystem.argtypes = [ ctypes.c_void_p ]
		lib.udev_device_get_subsystem.restype = ctypes.c_char_p
		lib.udev_device_get_devtype.argtypes = [ ctypes.c_void_p ]
		lib.udev_device_get_devtype.restype = ctypes.c_char_p
		lib.udev_device_get_syspath.argtypes = [ ctypes.c_void_p ]
		lib.udev_device_get_syspath.restype = ctypes.c_char_p
		lib.udev_device_get_sysname.argtypes = [ ctypes.c_void_p ]
		lib.udev_device_get_sysname.restype = ctypes.c_char_p
		lib.udev_device_get_is_initialized.argtypes = [ ctypes.c_void_p ]
		lib.udev_device_get_is_initialized.restype = ctypes.c_int
		lib.udev_device_get_devnum.argtypes = [ ctypes.c_void_p ]
		# The C function returns dev_t, an unsigned 64-bit integer on Linux.
		# c_int would truncate it and could make large device numbers negative.
		lib.udev_device_get_devnum.restype = ctypes.c_uint64
		lib.udev_device_unref.argtypes = [ ctypes.c_void_p ]

		# Each Enumerator match_*/nomatch_* method maps to udev_enumerate_add_<name>.
		# The C function takes the enumerator handle followed by as many strings
		# as the Python method takes after self: two for @twoargs methods, none
		# for match_is_initialized, and one otherwise.
		for name in dir(Enumerator):
			if "match_" not in name:
				continue
			method = getattr(Enumerator, name)
			if getattr(method, "twoargs", False):
				nargs = 2
			else:
				nargs = method.__code__.co_argcount - 1
			fn = getattr(lib, "udev_enumerate_add_" + name)
			fn.argtypes = [ ctypes.c_void_p ] + [ ctypes.c_char_p ] * nargs
			fn.restype = ctypes.c_int


	def __del__(self) -> None:
		"""Release the udev context if one was created."""
		if self._ctx is not None:
			self._lib.udev_unref(self._ctx)
			self._ctx = None

	enumerate_typevar = TypeVar("enumerate_typevar", bound="Enumerator")
	def enumerate(self, subclass: type [enumerate_typevar] | None = None) -> enumerate_typevar | Enumerator:
		"""Return a new Enumerator bound to this udev context.

		Args:
			subclass: optional Enumerator subclass to instantiate instead.

		Raises:
			OSError: udev_enumerate_new() returned NULL.
		"""
		enumerator: ctypes.c_void_p | None = self._lib.udev_enumerate_new(self._ctx)
		if enumerator is None:
			raise OSError("Failed to initialize enumerator")
		if subclass is not None:
			assert issubclass(subclass, Enumerator), f"subclass must be a subclass of Enumerator but {subclass} was provided"
		final_class = subclass or Enumerator
		return final_class(self, enumerator)

	monitor_typevar = TypeVar("monitor_typevar", bound="Monitor")
	def monitor(self, subclass: type [monitor_typevar] | None = None) -> monitor_typevar | Monitor:
		"""Return a new Monitor listening on the "udev" netlink source.

		Args:
			subclass: optional Monitor subclass to instantiate instead.

		Raises:
			OSError: udev_monitor_new_from_netlink() returned NULL.
		"""
		monitor: ctypes.c_void_p | None = self._lib.udev_monitor_new_from_netlink(self._ctx, b"udev")
		if monitor is None:
			raise OSError("Failed to initialize monitor")
		if subclass is not None:
			assert issubclass(subclass, Monitor), f"subclass must be a subclass of Monitor but {subclass} was provided"
		final_class = subclass or Monitor
		return final_class(self, monitor)


def twoargs(fn):
	"""Flag an Enumerator match method as taking two string arguments.

	Eudev._setup_lib reads the ``twoargs`` attribute to build the ctypes
	argtypes of the corresponding udev_enumerate_add_* function. The function
	object itself is returned, not a wrapper.
	"""
	setattr(fn, "twoargs", True)
	return fn


class Enumerator:
	"""Iterable over the syspaths (str) of devices known to udev.

	Filters are added with the match_* / nomatch_* methods, each of which
	returns the enumerator itself so calls can be chained. Filters can only
	be added before iteration starts, and an Enumerator can be iterated once.
	"""

	def __init__(self, eudev: Eudev, enumerator: ctypes.c_void_p) -> None:
		self._eudev = eudev
		self._enumerator = enumerator
		# ctypes string buffers handed to libudev, held so they stay referenced.
		self._keep_in_mem = []
		self._enumeration_started = False
		# Current udev_list_entry pointer; None once the list is exhausted.
		self._next: ctypes.c_void_p | None = None


	def __del__(self) -> None:
		handle = self._enumerator
		if handle is None:
			return
		self._eudev._lib.udev_enumerate_unref(handle)
		self._enumerator = None


	def _add_match(self, whichone: str, *pars) -> Enumerator:
		"""Call udev_enumerate_add_<whichone> with *pars* encoded as C strings."""
		if self._enumeration_started:
			raise RuntimeError("Cannot add match after enumeration is started")
		add_fn = getattr(self._eudev._lib, "udev_enumerate_add_" + whichone)
		c_args = []
		for raw in pars:
			# str is UTF-8 encoded; bytes and None are passed through as-is.
			c_args.append(ctypes.c_char_p(raw.encode("utf-8") if type(raw) is str else raw))
		self._keep_in_mem.extend(c_args)
		rc = add_fn(self._enumerator, *c_args)
		if rc < 0:
			raise OSError("udev_enumerate_add_%s: error %s" % (whichone, rc))
		return self


	@twoargs
	def match_sysattr(self, sysattr, value):
		return self._add_match("match_sysattr", sysattr, value)

	@twoargs
	def nomatch_sysattr(self, sysattr, value):
		return self._add_match("nomatch_sysattr", sysattr, value)

	@twoargs
	def match_property(self, property, value):
		return self._add_match("match_property", property, value)

	def match_subsystem(self, subsystem: str):
		return self._add_match("match_subsystem", subsystem)

	def nomatch_subsystem(self, subsystem: str):
		return self._add_match("nomatch_subsystem", subsystem)

	def match_sysname(self, sysname: str):
		return self._add_match("match_sysname", sysname)

	def match_tag(self, tag):
		return self._add_match("match_tag", tag)

	def match_is_initialized(self):
		return self._add_match("match_is_initialized")

	# match_parent is not implemented


	def __iter__(self) -> Enumerator:
		"""Scan devices and position at the first list entry (only once)."""
		if self._enumeration_started:
			raise RuntimeError("Cannot iterate same Enumerator twice")
		self._enumeration_started = True
		lib = self._eudev._lib
		rc = lib.udev_enumerate_scan_devices(self._enumerator)
		if rc < 0:
			raise OSError("udev_enumerate_scan_devices: error %s" % (rc, ))
		self._next = lib.udev_enumerate_get_list_entry(self._enumerator)
		return self


	def next(self) -> str:
		return self.__next__()


	def __next__(self) -> str:
		"""Return the next syspath, starting the scan on first use."""
		if not self._enumeration_started:
			self.__iter__()
		entry = self._next
		if entry is None:
			raise StopIteration
		lib = self._eudev._lib
		raw_name: ctypes.c_char_p | None = lib.udev_list_entry_get_name(entry)
		if raw_name is None:
			raise OSError("udev_list_entry_get_name failed, can't get syspath")
		self._next = lib.udev_list_entry_get_next(entry)
		return str(raw_name, "utf-8")

class DeviceEvent(NamedTuple):
	"""One device event as decoded by Monitor.receive_device."""

	action: str             # udev action name, decoded from UTF-8
	node: str | None        # device node path, or None when udev reports none
	initialized: bool       # udev_device_get_is_initialized(...) == 1
	subsystem: str          # subsystem name, decoded from UTF-8
	devtype: str            # device type; "" when udev reports none
	syspath: str            # sysfs path of the device
	devnum: int             # raw device number (split it with os.major / os.minor)

class Monitor:
	"""Monitor object that receives device events.

	receive_device returns a DeviceEvent, or None when libudev has no event
	ready. In practice udev_monitor_receive_device does not block. To avoid
	busy-looping, wait with select/poll until the descriptor returned by
	get_fd (alias fileno) is readable, then call receive_device.

	All match_* methods, enable_receiving/start and set_receive_buffer_size
	return self for chaining.
	"""

	def __init__(self, eudev: Eudev, monitor: ctypes.c_void_p) -> None:
		self._eudev = eudev
		self._monitor = monitor
		self._monitor_started = False
		# ctypes string buffers passed to libudev, kept referenced while the
		# monitor is alive.
		self._keep_in_mem = []
		# (function suffix, *args) keys of filters already installed.
		self._enabled_matches = set()


	def __del__(self) -> None:
		"""Release the libudev monitor handle."""
		if self._monitor is not None:
			self._eudev._lib.udev_monitor_unref(self._monitor)
			self._monitor = None


	@staticmethod
	def _strerror(err):
		"""Return the errno symbol (e.g. "EINVAL") for a negative return code.

		libudev returns negative errno values on failure, so the sign is
		flipped before the lookup. Unknown codes are returned unchanged.
		"""
		return errno.errorcode.get(-err, err)


	def _add_match(self, whichone, *pars):
		"""Install udev_monitor_filter_add_<whichone>(*pars) once.

		Repeated identical calls are no-ops. If receiving has already started,
		the kernel-side filter is refreshed with udev_monitor_filter_update.

		Raises:
			OSError: libudev reported an error adding or updating the filter.
		"""
		key = (whichone, *pars)
		if key in self._enabled_matches:
			# Already done
			return self
		fn = getattr(self._eudev._lib, "udev_monitor_filter_add_" + whichone)
		# str is UTF-8 encoded; bytes and None (NULL) are passed through.
		c_pars = [ ctypes.c_char_p(p.encode("utf-8") if type(p) is str else p) for p in pars ]
		self._keep_in_mem += c_pars
		err = fn(self._monitor, *c_pars)
		if err < 0:
			raise OSError("udev_monitor_filter_add_%s: error %s" % (whichone, self._strerror(err)))
		self._enabled_matches.add(key)
		if self._monitor_started:
			err = self._eudev._lib.udev_monitor_filter_update(self._monitor)
			if err < 0:
				raise OSError("udev_monitor_filter_update: error %s" % (self._strerror(err), ))
		return self


	def match_subsystem_devtype(self, subsystem: str, devtype=None):
		"""Only receive events for *subsystem* (and *devtype*, if given)."""
		return self._add_match("match_subsystem_devtype", subsystem, devtype)
	def match_subsystem(self, subsystem: str):
		"""Only receive events for *subsystem*, with any devtype."""
		return self._add_match("match_subsystem_devtype", subsystem, None)
	def match_tag(self, tag):
		"""Only receive events for devices carrying udev tag *tag*."""
		return self._add_match("match_tag", tag)

	def is_started(self):
		return self._monitor_started


	def get_fd(self):
		"""Return the monitor's file descriptor, for select/poll.

		Raises:
			OSError: udev_monitor_get_fd returned a negative value.
		"""
		fileno = self._eudev._lib.udev_monitor_get_fd(self._monitor)
		if fileno < 0:
			raise OSError("udev_monitor_get_fd: error %s" % (self._strerror(fileno), ))
		return fileno


	def enable_receiving(self):
		"""Start receiving events. Returns self for chaining.

		Calling it again once started is a no-op that still returns self.
		"""
		if self._monitor_started:
			return self
		err = self._eudev._lib.udev_monitor_enable_receiving(self._monitor)
		if err < 0:
			raise OSError("udev_monitor_enable_receiving: error %s" % (self._strerror(err), ))
		self._monitor_started = True
		return self


	def set_receive_buffer_size(self, size):
		"""Set the monitor socket's receive buffer size. Returns self for chaining."""
		err = self._eudev._lib.udev_monitor_set_receive_buffer_size(self._monitor, size)
		if err < 0:
			raise OSError("udev_monitor_set_receive_buffer_size: error %s" % (self._strerror(err), ))
		return self


	fileno = get_fd          # python stuff likes this name better
	start = enable_receiving # I like this name better


	def receive_device(self) -> DeviceEvent | None:
		"""Read one pending event, or return None if none is ready.

		Starts receiving on first use. Missing action, subsystem, devtype or
		syspath strings are reported as "", and a missing devnode as None.
		"""
		if not self._monitor_started:
			self.enable_receiving()

		lib = self._eudev._lib
		dev = lib.udev_monitor_receive_device(self._monitor)
		if dev is None:
			# udev_monitor_receive_device is _supposed_ to be blocking,
			# but in practice it returns NULL when nothing is pending.
			return None

		try:
			devnode = lib.udev_device_get_devnode(dev)
			return DeviceEvent(
				str(lib.udev_device_get_action(dev) or b"", "utf-8"),
				str(devnode, "utf-8") if devnode else None,
				lib.udev_device_get_is_initialized(dev) == 1,
				str(lib.udev_device_get_subsystem(dev) or b"", "utf-8"),
				str(lib.udev_device_get_devtype(dev) or b"", "utf-8"),
				str(lib.udev_device_get_syspath(dev) or b"", "utf-8"),
				lib.udev_device_get_devnum(dev),
			)
		finally:
			# Always drop our reference, even if decoding above raised.
			lib.udev_device_unref(dev)


if __name__ == "__main__":
	# Demo: print the syspaths of current hidraw devices, then print every
	# hidraw event until interrupted with Ctrl+C.
	import select

	udev = Eudev()
	for syspath in udev.enumerate().match_subsystem("hidraw"):
		print(syspath)

	m = udev.monitor().match_subsystem("hidraw").start()
	if m is None:
		msg = "Expected to receive a Monitor device, not None."
		raise RuntimeError(msg)

	fd = m.fileno()
	try:
		while True:
			# receive_device() returns None instead of blocking when nothing
			# is pending, so sleep in select() until the monitor fd is
			# readable rather than spinning at 100% CPU.
			select.select([fd], [], [])
			dev = m.receive_device()
			if dev is not None:
				print(f"{os.major(dev.devnum)} {os.minor(dev.devnum)} {dev}")
	except KeyboardInterrupt:
		pass
