aboutsummaryrefslogtreecommitdiff
path: root/mobly/snippet/callback_handler_base.py
blob: 50465d1a0675e1480e8836172b0587247b3b4c89 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# Copyright 2022 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the base class to handle Mobly Snippet Lib's callback events."""
import abc
import time

from mobly.snippet import callback_event
from mobly.snippet import errors


class CallbackHandlerBase(abc.ABC):
  """Base class for handling Mobly Snippet Lib's callback events.

  All the events handled by a callback handler are originally triggered by one
  async RPC call. All the events are tagged with a callback_id specific to a
  call to an async RPC method defined on the server side.

  The raw message representing an event looks like:

  .. code-block:: python

    {
      'callbackId': <string, callbackId>,
      'name': <string, name of the event>,
      'time': <long, epoch time of when the event was created on the
        server side>,
      'data': <dict, extra data from the callback on the server side>
    }

  Each message is then used to create a CallbackEvent object on the client
  side.

  Subclasses must implement `callEventWaitAndGetRpc` and `callEventGetAllRpc`
  to issue the actual RPCs for a specific snippet lib implementation.

  Attributes:
    ret_value: any, the direct return value of the async RPC call.
  """

  def __init__(self,
               callback_id,
               event_client,
               ret_value,
               method_name,
               device,
               rpc_max_timeout_sec,
               default_timeout_sec=120):
    """Initializes a callback handler base object.

    Args:
      callback_id: str, the callback ID which associates with a group of
        callback events.
      event_client: SnippetClientV2, the client object used to send RPC to the
        server and receive response.
      ret_value: any, the direct return value of the async RPC call.
      method_name: str, the name of the executed Async snippet function.
      device: DeviceController, the device object associated with this handler.
      rpc_max_timeout_sec: float, maximum time for sending a single RPC call.
      default_timeout_sec: float, the default timeout for this handler. It
        must be no longer than rpc_max_timeout_sec.

    Raises:
      ValueError: if rpc_max_timeout_sec is smaller than default_timeout_sec.
    """
    self._id = callback_id
    self.ret_value = ret_value
    self._device = device
    self._event_client = event_client
    self._method_name = method_name

    # A default wait must be satisfiable by a single RPC, otherwise
    # `waitAndGet` with no explicit timeout would always be rejected.
    if rpc_max_timeout_sec < default_timeout_sec:
      raise ValueError('The max timeout of a single RPC must be no smaller '
                       'than the default timeout of the callback handler. '
                       f'Got rpc_max_timeout_sec={rpc_max_timeout_sec}, '
                       f'default_timeout_sec={default_timeout_sec}.')
    self._rpc_max_timeout_sec = rpc_max_timeout_sec
    self._default_timeout_sec = default_timeout_sec

  @property
  def rpc_max_timeout_sec(self):
    """Maximum time for sending a single RPC call."""
    return self._rpc_max_timeout_sec

  @property
  def default_timeout_sec(self):
    """Default timeout used by this callback handler."""
    return self._default_timeout_sec

  @property
  def callback_id(self):
    """The callback ID which associates a group of callback events."""
    return self._id

  @abc.abstractmethod
  def callEventWaitAndGetRpc(self, callback_id, event_name, timeout_sec):
    """Calls snippet lib's RPC to wait for a callback event.

    Override this method to use this class with various snippet lib
    implementations.

    This function waits and gets a CallbackEvent with the specified identifier
    from the server. It will raise a timeout error if the expected event does
    not occur within the time limit.

    Args:
      callback_id: str, the callback identifier.
      event_name: str, the callback name.
      timeout_sec: float, the number of seconds to wait for the event. It is
        already checked that this argument is no longer than the max timeout
        of a single RPC.

    Returns:
      The event dictionary.

    Raises:
      errors.CallbackHandlerTimeoutError: Raised if the expected event does not
        occur within the time limit.
    """

  @abc.abstractmethod
  def callEventGetAllRpc(self, callback_id, event_name):
    """Calls snippet lib's RPC to get all existing snippet events.

    Override this method to use this class with various snippet lib
    implementations.

    This function gets all existing events in the server with the specified
    identifier without waiting.

    Args:
      callback_id: str, the callback identifier.
      event_name: str, the callback name.

    Returns:
      A list of event dictionaries.
    """

  def waitAndGet(self, event_name, timeout=None):
    """Waits and gets a CallbackEvent with the specified identifier.

    It will raise a timeout error if the expected event does not occur within
    the time limit.

    Args:
      event_name: str, the name of the event to get.
      timeout: float, the number of seconds to wait before giving up. If None,
        it will be set to self.default_timeout_sec.

    Returns:
      CallbackEvent, the oldest entry of the specified event.

    Raises:
      errors.CallbackHandlerBaseError: If the specified timeout is longer than
        the max timeout supported.
      errors.CallbackHandlerTimeoutError: The expected event does not occur
        within the time limit.
    """
    if timeout is None:
      timeout = self.default_timeout_sec

    # A falsy timeout (e.g. 0) skips the bound check and is forwarded to the
    # RPC unchanged.
    if timeout and timeout > self.rpc_max_timeout_sec:
      raise errors.CallbackHandlerBaseError(
          self._device,
          f'Specified timeout {timeout} is longer than max timeout '
          f'{self.rpc_max_timeout_sec}.')

    raw_event = self.callEventWaitAndGetRpc(self._id, event_name, timeout)
    return callback_event.from_dict(raw_event)

  def waitForEvent(self, event_name, predicate, timeout=None):
    """Waits for an event of the specific name that satisfies the predicate.

    This call will block until the expected event has been received or time
    out.

    The predicate function defines the condition the event is expected to
    satisfy. It takes an event and returns True if the condition is
    satisfied, False otherwise.

    Note all events of the same name that are received but don't satisfy
    the predicate will be discarded and not be available for further
    consumption.

    Args:
      event_name: str, the name of the event to wait for.
      predicate: callable, takes a CallbackEvent and returns a bool.
      timeout: float, the number of seconds to wait before giving up. If None,
        it will be set to self.default_timeout_sec. It may be longer than
        self.rpc_max_timeout_sec; the wait is then split across several RPCs.

    Returns:
      CallbackEvent, the first received event that satisfies the predicate.

    Raises:
      errors.CallbackHandlerTimeoutError: raised if no event that satisfies the
        predicate is received after timeout seconds.
    """
    if timeout is None:
      timeout = self.default_timeout_sec

    deadline = time.perf_counter() + timeout
    while True:
      remaining_sec = deadline - time.perf_counter()
      if remaining_sec < 0:
        break

      # A single RPC may not wait longer than rpc_max_timeout_sec, so long
      # waits are split into several RPCs until the overall deadline passes.
      single_rpc_timeout = min(remaining_sec, self.rpc_max_timeout_sec)
      try:
        event = self.waitAndGet(event_name, single_rpc_timeout)
      except errors.CallbackHandlerTimeoutError:
        # Ignoring errors.CallbackHandlerTimeoutError since we need to throw
        # one with a more specific message.
        break
      if predicate(event):
        return event

    # Not every callable has `__name__` (e.g. functools.partial objects or
    # instances defining __call__); fall back to repr so the documented
    # timeout error is raised instead of an AttributeError.
    predicate_name = getattr(predicate, '__name__', repr(predicate))
    raise errors.CallbackHandlerTimeoutError(
        self._device,
        f'Timed out after {timeout}s waiting for an "{event_name}" event that '
        f'satisfies the predicate "{predicate_name}".')

  def getAll(self, event_name):
    """Gets all existing events in the server with the specified identifier.

    This is a non-blocking call.

    Args:
      event_name: str, the name of the event to get.

    Returns:
      A list of CallbackEvent, each representing an event from the Server side.
    """
    raw_events = self.callEventGetAllRpc(self._id, event_name)
    return [callback_event.from_dict(msg) for msg in raw_events]