aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGilles Boccon-Gibod <boccongibod@google.com>2024-04-04 19:17:07 -0700
committerGitHub <noreply@github.com>2024-04-04 19:17:07 -0700
commit2698d4534e21046a8e3744a9fe59b903ebbc5d65 (patch)
tree8f905bd2dd179a15f46cb6cb9dec98bc6b59e5e9
parentbbcd64286a7010d6e2cd51e0ce5cc0e6ad6d04e2 (diff)
parent1ceeccbbc0ae6f0c36c24ef593082d1f36ca0f5a (diff)
downloadbumble-2698d4534e21046a8e3744a9fe59b903ebbc5d65.tar.gz
Merge pull request #435 from jeru/main
open_tcp_server_transport: allow explicit sock as input.
-rw-r--r--.gitignore2
-rw-r--r--bumble/transport/tcp_server.py29
-rw-r--r--tests/transport_tcp_server_test.py64
3 files changed, 90 insertions, 5 deletions
diff --git a/.gitignore b/.gitignore
index 1a5fb9d..ac9f74d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,8 @@ dist/
docs/mkdocs/site
test-results.xml
__pycache__
+# Vim
+.*.sw*
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json
diff --git a/bumble/transport/tcp_server.py b/bumble/transport/tcp_server.py
index 77d0304..8991ead 100644
--- a/bumble/transport/tcp_server.py
+++ b/bumble/transport/tcp_server.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import asyncio
import logging
+import socket
from .common import Transport, StreamPacketSource
@@ -28,6 +29,12 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
+
+# A pass-through function to ease mock testing.
+async def _create_server(*args, **kw_args):
+ await asyncio.get_running_loop().create_server(*args, **kw_args)
+
+
async def open_tcp_server_transport(spec: str) -> Transport:
'''
Open a TCP server transport.
@@ -38,7 +45,22 @@ async def open_tcp_server_transport(spec: str) -> Transport:
Example: _:9001
'''
+ local_host, local_port = spec.split(':')
+ return await _open_tcp_server_transport_impl(
+ host=local_host if local_host != '_' else None, port=int(local_port)
+ )
+
+
+async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport:
+ '''
+ Open a TCP server transport with an existing socket.
+
+ One reason to use this variant is to let python pick an unused port.
+ '''
+ return await _open_tcp_server_transport_impl(sock=sock)
+
+async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
class TcpServerTransport(Transport):
async def close(self):
await super().close()
@@ -77,13 +99,10 @@ async def open_tcp_server_transport(spec: str) -> Transport:
else:
logger.debug('no client, dropping packet')
- local_host, local_port = spec.split(':')
packet_source = StreamPacketSource()
packet_sink = TcpServerPacketSink()
- await asyncio.get_running_loop().create_server(
- lambda: TcpServerProtocol(packet_source, packet_sink),
- host=local_host if local_host != '_' else None,
- port=int(local_port),
+ await _create_server(
+ lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs
)
return TcpServerTransport(packet_source, packet_sink)
diff --git a/tests/transport_tcp_server_test.py b/tests/transport_tcp_server_test.py
new file mode 100644
index 0000000..a5f015d
--- /dev/null
+++ b/tests/transport_tcp_server_test.py
@@ -0,0 +1,64 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+import pytest
+import socket
+import unittest
+from unittest.mock import ANY, patch
+
+from bumble.transport.tcp_server import (
+ open_tcp_server_transport,
+ open_tcp_server_transport_with_socket,
+)
+
+
+class OpenTcpServerTransportTests(unittest.TestCase):
+ def setUp(self):
+ self.patcher = patch('bumble.transport.tcp_server._create_server')
+ self.mock_create_server = self.patcher.start()
+
+ def tearDown(self):
+ self.patcher.stop()
+
+ def test_open_with_spec(self):
+ asyncio.run(open_tcp_server_transport('localhost:32100'))
+ self.mock_create_server.assert_awaited_once_with(
+ ANY, host='localhost', port=32100
+ )
+
+ def test_open_with_port_only_spec(self):
+ asyncio.run(open_tcp_server_transport('_:32100'))
+ self.mock_create_server.assert_awaited_once_with(ANY, host=None, port=32100)
+
+ def test_open_with_socket(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ asyncio.run(open_tcp_server_transport_with_socket(sock=sock))
+ self.mock_create_server.assert_awaited_once_with(ANY, sock=sock)
+
+
+@pytest.mark.skipif(
+ not os.environ.get('PYTEST_NOSKIP', 0),
+ reason='''\
+Not hermetic. Should only run manually with
+ $ PYTEST_NOSKIP=1 pytest tests
+''',
+)
+def test_open_with_real_socket():
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.bind(('localhost', 0))
+ port = sock.getsockname()[1]
+ assert port != 0
+ asyncio.run(open_tcp_server_transport_with_socket(sock=sock))