diff options
Diffstat (limited to 'tests/mobly/snippet/callback_handler_base_test.py')
-rw-r--r-- | tests/mobly/snippet/callback_handler_base_test.py | 105 |
1 files changed, 67 insertions, 38 deletions
diff --git a/tests/mobly/snippet/callback_handler_base_test.py b/tests/mobly/snippet/callback_handler_base_test.py index 0891fd5..32199b2 100644 --- a/tests/mobly/snippet/callback_handler_base_test.py +++ b/tests/mobly/snippet/callback_handler_base_test.py @@ -28,25 +28,34 @@ MOCK_RAW_EVENT = { 'data': { 'exampleData': "Here's a simple event.", 'successful': True, - 'secretNumber': 12 - } + 'secretNumber': 12, + }, } class FakeCallbackHandler(callback_handler_base.CallbackHandlerBase): """Fake client class for unit tests.""" - def __init__(self, - callback_id=None, - event_client=None, - ret_value=None, - method_name=None, - device=None, - rpc_max_timeout_sec=120, - default_timeout_sec=120): + def __init__( + self, + callback_id=None, + event_client=None, + ret_value=None, + method_name=None, + device=None, + rpc_max_timeout_sec=120, + default_timeout_sec=120, + ): """Initializes a fake callback handler object used for unit tests.""" - super().__init__(callback_id, event_client, ret_value, method_name, device, - rpc_max_timeout_sec, default_timeout_sec) + super().__init__( + callback_id, + event_client, + ret_value, + method_name, + device, + rpc_max_timeout_sec, + default_timeout_sec, + ) self.mock_rpc_func = mock.Mock() def callEventWaitAndGetRpc(self, *args, **kwargs): @@ -66,15 +75,18 @@ class CallbackHandlerBaseTest(unittest.TestCase): self.assertEqual(str(actual_event), str(expected_event)) def test_default_timeout_too_large(self): - err_msg = ('The max timeout of a single RPC must be no smaller than ' - 'the default timeout of the callback handler. ' - 'Got rpc_max_timeout_sec=10, default_timeout_sec=20.') + err_msg = ( + 'The max timeout of a single RPC must be no smaller than ' + 'the default timeout of the callback handler. ' + 'Got rpc_max_timeout_sec=10, default_timeout_sec=20.' + ) with self.assertRaisesRegex(ValueError, err_msg): _ = FakeCallbackHandler(rpc_max_timeout_sec=10, default_timeout_sec=20) def test_timeout_property(self): - handler = FakeCallbackHandler(rpc_max_timeout_sec=20, - default_timeout_sec=10) + handler = FakeCallbackHandler( + rpc_max_timeout_sec=20, default_timeout_sec=10 + ) self.assertEqual(handler.rpc_max_timeout_sec, 20) self.assertEqual(handler.default_timeout_sec, 10) with self.assertRaises(AttributeError): @@ -92,39 +104,48 @@ class CallbackHandlerBaseTest(unittest.TestCase): def test_event_dict_to_snippet_event(self): handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID) handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock( - return_value=MOCK_RAW_EVENT) + return_value=MOCK_RAW_EVENT + ) event = handler.waitAndGet('ha', timeout=10) self.assert_event_correct(event, MOCK_RAW_EVENT) handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with( - MOCK_CALLBACK_ID, 'ha', 10) + MOCK_CALLBACK_ID, 'ha', 10 + ) def test_wait_and_get_timeout_default(self): handler = FakeCallbackHandler(rpc_max_timeout_sec=20, default_timeout_sec=5) handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock( - return_value=MOCK_RAW_EVENT) + return_value=MOCK_RAW_EVENT + ) _ = handler.waitAndGet('ha') handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with( - mock.ANY, mock.ANY, 5) + mock.ANY, mock.ANY, 5 + ) def test_wait_and_get_timeout_ecxeed_threshold(self): rpc_max_timeout_sec = 5 big_timeout_sec = 10 - handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec, - default_timeout_sec=rpc_max_timeout_sec) + handler = FakeCallbackHandler( + rpc_max_timeout_sec=rpc_max_timeout_sec, + default_timeout_sec=rpc_max_timeout_sec, + ) handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock( - return_value=MOCK_RAW_EVENT) + return_value=MOCK_RAW_EVENT + ) expected_msg = ( f'Specified timeout {big_timeout_sec} is longer than max timeout ' - f'{rpc_max_timeout_sec}.') + f'{rpc_max_timeout_sec}.' + ) with self.assertRaisesRegex(errors.CallbackHandlerBaseError, expected_msg): handler.waitAndGet('ha', big_timeout_sec) def test_wait_for_event(self): handler = FakeCallbackHandler() handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock( - return_value=MOCK_RAW_EVENT) + return_value=MOCK_RAW_EVENT + ) def some_condition(event): return event.data['successful'] @@ -135,48 +156,56 @@ class CallbackHandlerBaseTest(unittest.TestCase): def test_wait_for_event_negative(self): handler = FakeCallbackHandler() handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock( - return_value=MOCK_RAW_EVENT) + return_value=MOCK_RAW_EVENT + ) expected_msg = ( 'Timed out after 0.01s waiting for an "AsyncTaskResult" event that' - ' satisfies the predicate "some_condition".') + ' satisfies the predicate "some_condition".' + ) def some_condition(_): return False - with self.assertRaisesRegex(errors.CallbackHandlerTimeoutError, - expected_msg): + with self.assertRaisesRegex( + errors.CallbackHandlerTimeoutError, expected_msg + ): handler.waitForEvent('AsyncTaskResult', some_condition, 0.01) def test_wait_for_event_max_timeout(self): """waitForEvent should not raise the timeout exceed threshold error.""" rpc_max_timeout_sec = 5 big_timeout_sec = 10 - handler = FakeCallbackHandler(rpc_max_timeout_sec=rpc_max_timeout_sec, - default_timeout_sec=rpc_max_timeout_sec) + handler = FakeCallbackHandler( + rpc_max_timeout_sec=rpc_max_timeout_sec, + default_timeout_sec=rpc_max_timeout_sec, + ) handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock( - return_value=MOCK_RAW_EVENT) + return_value=MOCK_RAW_EVENT + ) def some_condition(event): return event.data['successful'] # This line should not raise. - event = handler.waitForEvent('AsyncTaskResult', - some_condition, - timeout=big_timeout_sec) + event = handler.waitForEvent( + 'AsyncTaskResult', some_condition, timeout=big_timeout_sec + ) self.assert_event_correct(event, MOCK_RAW_EVENT) def test_get_all(self): handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID) handler.mock_rpc_func.callEventGetAllRpc = mock.Mock( - return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT]) + return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT] + ) all_events = handler.getAll('ha') for event in all_events: self.assert_event_correct(event, MOCK_RAW_EVENT) handler.mock_rpc_func.callEventGetAllRpc.assert_called_once_with( - MOCK_CALLBACK_ID, 'ha') + MOCK_CALLBACK_ID, 'ha' + ) if __name__ == '__main__': |