From ce02ea2fdf07f33f23a0321f6b10407b3b51c9cd Mon Sep 17 00:00:00 2001 From: skbeh <60107333+skbeh@users.noreply.github.com> Date: Thu, 21 Aug 2025 15:08:48 +0000 Subject: [PATCH] refactor: use 'with' statement for socket handling Use a 'with' statement to handle socket creation and cleanup. This ensures that sockets are always closed properly, even when exceptions occur, preventing potential resource leaks. --- miio/miioprotocol.py | 182 ++++++++++++++++++------------------- miio/push_server/server.py | 13 +-- 2 files changed, 98 insertions(+), 97 deletions(-) diff --git a/miio/miioprotocol.py b/miio/miioprotocol.py index c05d92882..ad90b34e7 100644 --- a/miio/miioprotocol.py +++ b/miio/miioprotocol.py @@ -116,34 +116,34 @@ def discover(addr: Optional[str] = None, timeout: int = 5) -> Any: "21310020ffffffffffffffffffffffffffffffffffffffffffffffffffffffff" ) - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) - s.settimeout(timeout) - for _ in range(3): - s.sendto(helobytes, (addr, 54321)) - while True: - try: - data, recv_addr = s.recvfrom(1024) - m: Message = Message.parse(data) - _LOGGER.debug("Got a response: %s", m) - if not is_broadcast: - return m - - if recv_addr[0] not in seen_addrs: - _LOGGER.info( - " IP %s (ID: %s) - token: %s", - recv_addr[0], - binascii.hexlify(m.header.value.device_id).decode(), - codecs.encode(m.checksum, "hex"), - ) - seen_addrs.append(recv_addr[0]) - except socket.timeout: - if is_broadcast: - _LOGGER.info("Discovery done") - return # ignore timeouts on discover - except Exception as ex: - _LOGGER.warning("error while reading discover results: %s", ex) - break + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + s.settimeout(timeout) + for _ in range(3): + s.sendto(helobytes, (addr, 54321)) + while True: + try: + data, recv_addr = s.recvfrom(1024) + m: Message = Message.parse(data) + _LOGGER.debug("Got a response: %s", m) + if not is_broadcast: + return m + + if recv_addr[0] not in seen_addrs: + _LOGGER.info( + " IP %s (ID: %s) - token: %s", + recv_addr[0], + binascii.hexlify(m.header.value.device_id).decode(), + codecs.encode(m.checksum, "hex"), + ) + seen_addrs.append(recv_addr[0]) + except socket.timeout: + if is_broadcast: + _LOGGER.info("Discovery done") + return # ignore timeouts on discover + except Exception as ex: + _LOGGER.warning("error while reading discover results: %s", ex) + break def send( self, @@ -187,80 +187,80 @@ def send( Message.parse(m, token=self.token), ) - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.settimeout(self._timeout) + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.settimeout(self._timeout) - try: - s.sendto(m, (self.ip, self.port)) - except OSError as ex: - _LOGGER.error("failed to send msg: %s", ex) - raise DeviceException from ex - - try: - data, addr = s.recvfrom(4096) - m = Message.parse(data, token=self.token) + try: + s.sendto(m, (self.ip, self.port)) + except OSError as ex: + _LOGGER.error("failed to send msg: %s", ex) + raise DeviceException from ex - if self.debug > 1: - _LOGGER.debug("recv from %s: %s", addr[0], m) + try: + data, addr = s.recvfrom(4096) + m = Message.parse(data, token=self.token) - header = m.header.value - payload = m.data.value + if self.debug > 1: + _LOGGER.debug("recv from %s: %s", addr[0], m) - self.__id = payload["id"] - self._device_ts = header["ts"] # type: ignore # ts uses timeadapter + header = m.header.value + payload = m.data.value - _LOGGER.debug( - "%s:%s (ts: %s, id: %s) << %s", - self.ip, - self.port, - header["ts"], - payload["id"], - pf(payload), - ) - if "error" in payload: - self._handle_error(payload["error"]) + self.__id = payload["id"] + self._device_ts = header["ts"] # type: ignore # ts uses timeadapter - try: - return payload["result"] - except KeyError: - return payload - except construct.core.ChecksumError as ex: - raise InvalidTokenException( - "Got checksum error which indicates use " - "of an invalid token. " - "Please check your token!" - ) from ex - except OSError as ex: - if retry_count > 0: _LOGGER.debug( - "Retrying with incremented id, retries left: %s", retry_count - ) - self.__id += 100 - self._discovered = False - return self.send( - command, - parameters, - retry_count - 1, - extra_parameters=extra_parameters, + "%s:%s (ts: %s, id: %s) << %s", + self.ip, + self.port, + header["ts"], + payload["id"], + pf(payload), ) + if "error" in payload: + self._handle_error(payload["error"]) + + try: + return payload["result"] + except KeyError: + return payload + except construct.core.ChecksumError as ex: + raise InvalidTokenException( + "Got checksum error which indicates use " + "of an invalid token. " + "Please check your token!" + ) from ex + except OSError as ex: + if retry_count > 0: + _LOGGER.debug( + "Retrying with incremented id, retries left: %s", retry_count + ) + self.__id += 100 + self._discovered = False + return self.send( + command, + parameters, + retry_count - 1, + extra_parameters=extra_parameters, + ) - _LOGGER.error("Got error when receiving: %s", ex) - raise DeviceException("No response from the device") from ex + _LOGGER.error("Got error when receiving: %s", ex) + raise DeviceException("No response from the device") from ex - except RecoverableError as ex: - if retry_count > 0: - _LOGGER.debug( - "Retrying to send failed command, retries left: %s", retry_count - ) - return self.send( - command, - parameters, - retry_count - 1, - extra_parameters=extra_parameters, - ) + except RecoverableError as ex: + if retry_count > 0: + _LOGGER.debug( + "Retrying to send failed command, retries left: %s", retry_count + ) + return self.send( + command, + parameters, + retry_count - 1, + extra_parameters=extra_parameters, + ) - _LOGGER.error("Got error when receiving: %s", ex) - raise DeviceException("Unable to recover failed command") from ex + _LOGGER.error("Got error when receiving: %s", ex) + raise DeviceException("Unable to recover failed command") from ex @property def _id(self) -> int: diff --git a/miio/push_server/server.py b/miio/push_server/server.py index 9080f3ee1..1689c870e 100644 --- a/miio/push_server/server.py +++ b/miio/push_server/server.py @@ -210,12 +210,13 @@ async def unsubscribe_event(self, device: Device, event_id: str): async def _get_server_ip(self): """Connect to the miio device to get server_ip using a one time use socket.""" - get_ip_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) - get_ip_socket.bind((self._address, SERVER_PORT)) - get_ip_socket.setblocking(False) - await self._loop.sock_connect(get_ip_socket, (self._device_ip, SERVER_PORT)) - server_ip = get_ip_socket.getsockname()[0] - get_ip_socket.close() + with socket.socket( + family=socket.AF_INET, type=socket.SOCK_DGRAM + ) as get_ip_socket: + get_ip_socket.bind((self._address, SERVER_PORT)) + get_ip_socket.setblocking(False) + await self._loop.sock_connect(get_ip_socket, (self._device_ip, SERVER_PORT)) + server_ip = get_ip_socket.getsockname()[0] _LOGGER.debug("Miio push server device ip=%s", server_ip) return server_ip