Commit 7050ad2c authored by Tom Niget's avatar Tom Niget

Close automatically created dup sockets

parent 455179ac
...@@ -21,17 +21,17 @@ import struct ...@@ -21,17 +21,17 @@ import struct
from io import IOBase from io import IOBase
def __check_socket(sock: socket.socket | IOBase) -> socket.socket: def __check_socket(sock: socket.socket | IOBase) -> (socket.socket, bool):
if hasattr(sock, 'family') and sock.family != socket.AF_UNIX: if hasattr(sock, 'family') and sock.family != socket.AF_UNIX:
raise ValueError("Only AF_UNIX sockets are allowed") raise ValueError("Only AF_UNIX sockets are allowed")
if hasattr(sock, 'fileno'): if hasattr(sock, 'fileno'):
sock = socket.fromfd(sock.fileno(), family=socket.AF_UNIX, type=socket.SOCK_STREAM) return socket.fromfd(sock.fileno(), family=socket.AF_UNIX, type=socket.SOCK_STREAM), True
if not isinstance(sock, socket.socket): if not isinstance(sock, socket.socket):
raise TypeError("An socket object or file descriptor was expected") raise TypeError("An socket object or file descriptor was expected")
return sock return sock, False
def __check_fd(fd) -> int: def __check_fd(fd) -> int:
try: try:
...@@ -46,7 +46,10 @@ def __check_fd(fd) -> int: ...@@ -46,7 +46,10 @@ def __check_fd(fd) -> int:
def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096) -> tuple[int, str]: def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096) -> tuple[int, str]:
size = struct.calcsize("@i") size = struct.calcsize("@i")
msg, ancdata, flags, addr = __check_socket(sock).recvmsg(msg_buf, socket.CMSG_SPACE(size)) sock, close = __check_socket(sock)
msg, ancdata, flags, addr = sock.recvmsg(msg_buf, socket.CMSG_SPACE(size))
if close:
sock.close()
cmsg_level, cmsg_type, cmsg_data = ancdata[0] cmsg_level, cmsg_type, cmsg_data = ancdata[0]
if not (cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS): if not (cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS):
raise RuntimeError("The message received did not contain exactly one" + raise RuntimeError("The message received did not contain exactly one" +
...@@ -60,6 +63,11 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096) -> tuple[int, str] ...@@ -60,6 +63,11 @@ def recvfd(sock: socket.socket | IOBase, msg_buf: int = 4096) -> tuple[int, str]
def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE") -> int: def sendfd(sock: socket.socket | IOBase, fd: int, message: bytes = b"NONE") -> int:
return __check_socket(sock).sendmsg( sock, close = __check_socket(sock)
[message], try:
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, struct.pack("@i", fd))]) return sock.sendmsg(
\ No newline at end of file [message],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, struct.pack("@i", fd))])
finally:
if close:
sock.close()
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment