aboutsummaryrefslogtreecommitdiff
path: root/tests/mobly/snippet/callback_handler_base_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/mobly/snippet/callback_handler_base_test.py')
-rw-r--r--tests/mobly/snippet/callback_handler_base_test.py105
1 files changed, 67 insertions, 38 deletions
diff --git a/tests/mobly/snippet/callback_handler_base_test.py b/tests/mobly/snippet/callback_handler_base_test.py
index 0891fd5..32199b2 100644
--- a/tests/mobly/snippet/callback_handler_base_test.py
+++ b/tests/mobly/snippet/callback_handler_base_test.py
@@ -28,25 +28,34 @@ MOCK_RAW_EVENT = {
'data': {
'exampleData': "Here's a simple event.",
'successful': True,
- 'secretNumber': 12
- }
+ 'secretNumber': 12,
+ },
}
class FakeCallbackHandler(callback_handler_base.CallbackHandlerBase):
"""Fake client class for unit tests."""
- def __init__(self,
- callback_id=None,
- event_client=None,
- ret_value=None,
- method_name=None,
- device=None,
- rpc_max_timeout_sec=120,
- default_timeout_sec=120):
+ def __init__(
+ self,
+ callback_id=None,
+ event_client=None,
+ ret_value=None,
+ method_name=None,
+ device=None,
+ rpc_max_timeout_sec=120,
+ default_timeout_sec=120,
+ ):
"""Initializes a fake callback handler object used for unit tests."""
- super().__init__(callback_id, event_client, ret_value, method_name, device,
- rpc_max_timeout_sec, default_timeout_sec)
+ super().__init__(
+ callback_id,
+ event_client,
+ ret_value,
+ method_name,
+ device,
+ rpc_max_timeout_sec,
+ default_timeout_sec,
+ )
self.mock_rpc_func = mock.Mock()
def callEventWaitAndGetRpc(self, *args, **kwargs):
@@ -66,15 +75,18 @@ class CallbackHandlerBaseTest(unittest.TestCase):
self.assertEqual(str(actual_event), str(expected_event))
def test_default_timeout_too_large(self):
- err_msg = ('The max timeout of a single RPC must be no smaller than '
- 'the default timeout of the callback handler. '
- 'Got rpc_max_timeout_sec=10, default_timeout_sec=20.')
+ err_msg = (
+ 'The max timeout of a single RPC must be no smaller than '
+ 'the default timeout of the callback handler. '
+ 'Got rpc_max_timeout_sec=10, default_timeout_sec=20.'
+ )
with self.assertRaisesRegex(ValueError, err_msg):
_ = FakeCallbackHandler(rpc_max_timeout_sec=10, default_timeout_sec=20)
def test_timeout_property(self):
- handler = FakeCallbackHandler(rpc_max_timeout_sec=20,
- default_timeout_sec=10)
+ handler = FakeCallbackHandler(
+ rpc_max_timeout_sec=20, default_timeout_sec=10
+ )
self.assertEqual(handler.rpc_max_timeout_sec, 20)
self.assertEqual(handler.default_timeout_sec, 10)
with self.assertRaises(AttributeError):
@@ -92,39 +104,48 @@ class CallbackHandlerBaseTest(unittest.TestCase):
def test_event_dict_to_snippet_event(self):
handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
- return_value=MOCK_RAW_EVENT)
+ return_value=MOCK_RAW_EVENT
+ )
event = handler.waitAndGet('ha', timeout=10)
self.assert_event_correct(event, MOCK_RAW_EVENT)
handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
- MOCK_CALLBACK_ID, 'ha', 10)
+ MOCK_CALLBACK_ID, 'ha', 10
+ )
def test_wait_and_get_timeout_default(self):
handler = FakeCallbackHandler(rpc_max_timeout_sec=20, default_timeout_sec=5)
handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
- return_value=MOCK_RAW_EVENT)
+ return_value=MOCK_RAW_EVENT
+ )
_ = handler.waitAndGet('ha')
handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
- mock.ANY, mock.ANY, 5)
+ mock.ANY, mock.ANY, 5
+ )
def test_wait_and_get_timeout_ecxeed_threshold(self):
rpc_max_timeout_sec = 5
big_timeout_sec = 10
- handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec,
- default_timeout_sec=rpc_max_timeout_sec)
+ handler = FakeCallbackHandler(
+ rpc_max_timeout_sec=rpc_max_timeout_sec,
+ default_timeout_sec=rpc_max_timeout_sec,
+ )
handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
- return_value=MOCK_RAW_EVENT)
+ return_value=MOCK_RAW_EVENT
+ )
expected_msg = (
f'Specified timeout {big_timeout_sec} is longer than max timeout '
- f'{rpc_max_timeout_sec}.')
+ f'{rpc_max_timeout_sec}.'
+ )
with self.assertRaisesRegex(errors.CallbackHandlerBaseError, expected_msg):
handler.waitAndGet('ha', big_timeout_sec)
def test_wait_for_event(self):
handler = FakeCallbackHandler()
handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
- return_value=MOCK_RAW_EVENT)
+ return_value=MOCK_RAW_EVENT
+ )
def some_condition(event):
return event.data['successful']
@@ -135,48 +156,56 @@ class CallbackHandlerBaseTest(unittest.TestCase):
def test_wait_for_event_negative(self):
handler = FakeCallbackHandler()
handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
- return_value=MOCK_RAW_EVENT)
+ return_value=MOCK_RAW_EVENT
+ )
expected_msg = (
'Timed out after 0.01s waiting for an "AsyncTaskResult" event that'
- ' satisfies the predicate "some_condition".')
+ ' satisfies the predicate "some_condition".'
+ )
def some_condition(_):
return False
- with self.assertRaisesRegex(errors.CallbackHandlerTimeoutError,
- expected_msg):
+ with self.assertRaisesRegex(
+ errors.CallbackHandlerTimeoutError, expected_msg
+ ):
handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
def test_wait_for_event_max_timeout(self):
"""waitForEvent should not raise the timeout exceed threshold error."""
rpc_max_timeout_sec = 5
big_timeout_sec = 10
- handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec,
- default_timeout_sec=rpc_max_timeout_sec)
+ handler = FakeCallbackHandler(
+ rpc_max_timeout_sec=rpc_max_timeout_sec,
+ default_timeout_sec=rpc_max_timeout_sec,
+ )
handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
- return_value=MOCK_RAW_EVENT)
+ return_value=MOCK_RAW_EVENT
+ )
def some_condition(event):
return event.data['successful']
# This line should not raise.
- event = handler.waitForEvent('AsyncTaskResult',
- some_condition,
- timeout=big_timeout_sec)
+ event = handler.waitForEvent(
+ 'AsyncTaskResult', some_condition, timeout=big_timeout_sec
+ )
self.assert_event_correct(event, MOCK_RAW_EVENT)
def test_get_all(self):
handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
handler.mock_rpc_func.callEventGetAllRpc = mock.Mock(
- return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT])
+ return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT]
+ )
all_events = handler.getAll('ha')
for event in all_events:
self.assert_event_correct(event, MOCK_RAW_EVENT)
handler.mock_rpc_func.callEventGetAllRpc.assert_called_once_with(
- MOCK_CALLBACK_ID, 'ha')
+ MOCK_CALLBACK_ID, 'ha'
+ )
if __name__ == '__main__':