diff options
author | Danny Hermes <daniel.j.hermes@gmail.com> | 2016-08-17 15:18:29 -0700 |
---|---|---|
committer | Danny Hermes <daniel.j.hermes@gmail.com> | 2016-08-17 15:56:41 -0700 |
commit | ebe9ed0bbbe4ce51c1a76de694c795e38906d690 (patch) | |
tree | 4bfe3e9c6ccd639ae4a598456daa0ec437e0221c | |
parent | 4c7b3be5a101454e2c641a9835e652a92d16800e (diff) | |
download | oauth2client-ebe9ed0bbbe4ce51c1a76de694c795e38906d690.tar.gz |
Correct query loss when using parse_qsl to dict
-rw-r--r-- | oauth2client/_helpers.py | 54 | ||||
-rw-r--r-- | oauth2client/client.py | 28 | ||||
-rw-r--r-- | oauth2client/tools.py | 12 | ||||
-rw-r--r-- | tests/test__helpers.py | 40 | ||||
-rw-r--r-- | tests/test_client.py | 19 |
5 files changed, 104 insertions, 49 deletions
diff --git a/oauth2client/_helpers.py b/oauth2client/_helpers.py index 79586a5..e912397 100644 --- a/oauth2client/_helpers.py +++ b/oauth2client/_helpers.py @@ -179,6 +179,54 @@ def string_to_scopes(scopes): return scopes +def parse_unique_urlencoded(content): + """Parses unique key-value parameters from urlencoded content. + + Args: + content: string, URL-encoded key-value pairs. + + Returns: + dict, The key-value pairs from ``content``. + + Raises: + ValueError: if one of the keys is repeated. + """ + urlencoded_params = urllib.parse.parse_qs(content) + params = {} + for key, value in six.iteritems(urlencoded_params): + if len(value) != 1: + msg = ('URL-encoded content contains a repeated value:' + '%s -> %s' % (key, ', '.join(value))) + raise ValueError(msg) + params[key] = value[0] + return params + + +def update_query_params(uri, params): + """Updates a URI with new query parameters. + + If a given key from ``params`` is repeated in the ``uri``, then + the URI will be considered invalid and an error will occur. + + If the URI is valid, then each value from ``params`` will + replace the corresponding value in the query parameters (if + it exists). + + Args: + uri: string, A valid URI, with potential existing query parameters. + params: dict, A dictionary of query parameters. + + Returns: + The same URI but with the new query parameters added. + """ + parts = urllib.parse.urlparse(uri) + query_params = parse_unique_urlencoded(parts.query) + query_params.update(params) + new_query = urllib.parse.urlencode(query_params) + new_parts = parts._replace(query=new_query) + return urllib.parse.urlunparse(new_parts) + + def _add_query_parameter(url, name, value): """Adds a query parameter to a url. @@ -195,11 +243,7 @@ def _add_query_parameter(url, name, value): if value is None: return url else: - parsed = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(parsed[4])) - query[name] = value - parsed[4] = urllib.parse.urlencode(query) - return urllib.parse.urlunparse(parsed) + return update_query_params(url, {name: value}) def validate_file(filename): diff --git a/oauth2client/client.py b/oauth2client/client.py index 0497d07..704c610 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -438,23 +438,6 @@ class Storage(object): self.release_lock() -def _update_query_params(uri, params): - """Updates a URI with new query parameters. - - Args: - uri: string, A valid URI, with potential existing query parameters. - params: dict, A dictionary of query parameters. - - Returns: - The same URI but with the new query parameters added. - """ - parts = urllib.parse.urlparse(uri) - query_params = dict(urllib.parse.parse_qsl(parts.query)) - query_params.update(params) - new_parts = parts._replace(query=urllib.parse.urlencode(query_params)) - return urllib.parse.urlunparse(new_parts) - - class OAuth2Credentials(Credentials): """Credentials object for OAuth 2.0. @@ -850,7 +833,8 @@ class OAuth2Credentials(Credentials): """ logger.info('Revoking token') query_params = {'token': token} - token_revoke_uri = _update_query_params(self.revoke_uri, query_params) + token_revoke_uri = _helpers.update_query_params( + self.revoke_uri, query_params) resp, content = transport.request(http, token_revoke_uri) if resp.status == http_client.OK: self.invalid = True @@ -889,8 +873,8 @@ class OAuth2Credentials(Credentials): """ logger.info('Refreshing scopes') query_params = {'access_token': token, 'fields': 'scope'} - token_info_uri = _update_query_params(self.token_info_uri, - query_params) + token_info_uri = _helpers.update_query_params( + self.token_info_uri, query_params) resp, content = transport.request(http, token_info_uri) content = _helpers._from_bytes(content) if resp.status == http_client.OK: @@ -1610,7 +1594,7 @@ def _parse_exchange_token_response(content): except Exception: # different JSON libs raise different exceptions, # so we just do a catch-all here - resp = dict(urllib.parse.parse_qsl(content)) + resp = _helpers.parse_unique_urlencoded(content) # some providers respond with 'expires', others with 'expires_in' if resp and 'expires' in resp: @@ -1943,7 +1927,7 @@ class OAuth2WebServerFlow(Flow): query_params['code_challenge_method'] = 'S256' query_params.update(self.params) - return _update_query_params(self.auth_uri, query_params) + return _helpers.update_query_params(self.auth_uri, query_params) @_helpers.positional(1) def step1_get_device_and_user_codes(self, http=None): diff --git a/oauth2client/tools.py b/oauth2client/tools.py index 0aa671b..b882429 100644 --- a/oauth2client/tools.py +++ b/oauth2client/tools.py @@ -122,16 +122,16 @@ class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler): if an error occurred. """ self.send_response(http_client.OK) - self.send_header("Content-type", "text/html") + self.send_header('Content-type', 'text/html') self.end_headers() - query = self.path.split('?', 1)[-1] - query = dict(urllib.parse.parse_qsl(query)) + parts = urllib.parse.urlparse(self.path) + query = _helpers.parse_unique_urlencoded(parts.query) self.server.query_params = query self.wfile.write( - b"<html><head><title>Authentication Status</title></head>") + b'<html><head><title>Authentication Status</title></head>') self.wfile.write( - b"<body><p>The authentication flow has completed.</p>") - self.wfile.write(b"</body></html>") + b'<body><p>The authentication flow has completed.</p>') + self.wfile.write(b'</body></html>') def log_message(self, format, *args): """Do not log messages to stdout while running as cmd. line program.""" diff --git a/tests/test__helpers.py b/tests/test__helpers.py index aac5f8d..00cd38a 100644 --- a/tests/test__helpers.py +++ b/tests/test__helpers.py @@ -19,6 +19,7 @@ import unittest import mock from oauth2client import _helpers +from tests import test_client class PositionalTests(unittest.TestCase): @@ -242,3 +243,42 @@ class Test__urlsafe_b64decode(unittest.TestCase): bad_string = b'+' with self.assertRaises((TypeError, binascii.Error)): _helpers._urlsafe_b64decode(bad_string) + + +class Test_update_query_params(unittest.TestCase): + + def test_update_query_params_no_params(self): + uri = 'http://www.google.com' + updated = _helpers.update_query_params(uri, {'a': 'b'}) + self.assertEqual(updated, uri + '?a=b') + + def test_update_query_params_existing_params(self): + uri = 'http://www.google.com?x=y' + updated = _helpers.update_query_params(uri, {'a': 'b', 'c': 'd&'}) + hardcoded_update = uri + '&a=b&c=d%26' + test_client.assertUrisEqual(self, updated, hardcoded_update) + + def test_update_query_params_replace_param(self): + base_uri = 'http://www.google.com' + uri = base_uri + '?x=a' + updated = _helpers.update_query_params(uri, {'x': 'b', 'y': 'c'}) + hardcoded_update = base_uri + '?x=b&y=c' + test_client.assertUrisEqual(self, updated, hardcoded_update) + + def test_update_query_params_repeated_params(self): + uri = 'http://www.google.com?x=a&x=b' + with self.assertRaises(ValueError): + _helpers.update_query_params(uri, {'a': 'c'}) + + +class Test_parse_unique_urlencoded(unittest.TestCase): + + def test_without_repeats(self): + content = 'a=b&c=d' + result = _helpers.parse_unique_urlencoded(content) + self.assertEqual(result, {'a': 'b', 'c': 'd'}) + + def test_with_repeats(self): + content = 'a=b&a=d' + with self.assertRaises(ValueError): + _helpers.parse_unique_urlencoded(content) diff --git a/tests/test_client.py b/tests/test_client.py index dbe11eb..a3268ba 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1364,7 +1364,7 @@ class BasicCredentialsTests(unittest.TestCase): self.assertEqual(credentials.scopes, set()) self.assertEqual(exc_manager.exception.args, (error_msg,)) - token_uri = client._update_query_params( + token_uri = _helpers.update_query_params( oauth2client.GOOGLE_TOKEN_INFO_URI, {'fields': 'scope', 'access_token': token}) @@ -1558,19 +1558,6 @@ class TestAssertionCredentials(unittest.TestCase): credentials.sign_blob(b'blob') -class UpdateQueryParamsTest(unittest.TestCase): - def test_update_query_params_no_params(self): - uri = 'http://www.google.com' - updated = client._update_query_params(uri, {'a': 'b'}) - self.assertEqual(updated, uri + '?a=b') - - def test_update_query_params_existing_params(self): - uri = 'http://www.google.com?x=y' - updated = client._update_query_params(uri, {'a': 'b', 'c': 'd&'}) - hardcoded_update = uri + '&a=b&c=d%26' - assertUrisEqual(self, updated, hardcoded_update) - - class ExtractIdTokenTest(unittest.TestCase): """Tests client._extract_id_token().""" @@ -1670,7 +1657,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase): 'access_type': 'offline', 'response_type': 'code', } - expected = client._update_query_params(flow.auth_uri, query_params) + expected = _helpers.update_query_params(flow.auth_uri, query_params) assertUrisEqual(self, expected, result) # Check stubs. self.assertEqual(logger.warning.call_count, 1) @@ -1735,7 +1722,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase): 'access_type': 'offline', 'response_type': 'code', } - expected = client._update_query_params(flow.auth_uri, query_params) + expected = _helpers.update_query_params(flow.auth_uri, query_params) assertUrisEqual(self, expected, result) def test_step1_get_device_and_user_codes_wo_device_uri(self): |