aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDanny Hermes <daniel.j.hermes@gmail.com>2016-08-17 15:59:57 -0700
committerGitHub <noreply@github.com>2016-08-17 15:59:57 -0700
commit51ae8761eaf5ec044aa5ab221020b05be8fe2a71 (patch)
tree4bfe3e9c6ccd639ae4a598456daa0ec437e0221c
parent4c7b3be5a101454e2c641a9835e652a92d16800e (diff)
parentebe9ed0bbbe4ce51c1a76de694c795e38906d690 (diff)
downloadoauth2client-51ae8761eaf5ec044aa5ab221020b05be8fe2a71.tar.gz
Merge pull request #622 from dhermes/allow-repeated-params
Correct query loss when using parse_qsl to dict
-rw-r--r--oauth2client/_helpers.py54
-rw-r--r--oauth2client/client.py28
-rw-r--r--oauth2client/tools.py12
-rw-r--r--tests/test__helpers.py40
-rw-r--r--tests/test_client.py19
5 files changed, 104 insertions, 49 deletions
diff --git a/oauth2client/_helpers.py b/oauth2client/_helpers.py
index 79586a5..e912397 100644
--- a/oauth2client/_helpers.py
+++ b/oauth2client/_helpers.py
@@ -179,6 +179,54 @@ def string_to_scopes(scopes):
return scopes
+def parse_unique_urlencoded(content):
+ """Parses unique key-value parameters from urlencoded content.
+
+ Args:
+ content: string, URL-encoded key-value pairs.
+
+ Returns:
+ dict, The key-value pairs from ``content``.
+
+ Raises:
+ ValueError: if one of the keys is repeated.
+ """
+ urlencoded_params = urllib.parse.parse_qs(content)
+ params = {}
+ for key, value in six.iteritems(urlencoded_params):
+ if len(value) != 1:
+ msg = ('URL-encoded content contains a repeated value:'
+ '%s -> %s' % (key, ', '.join(value)))
+ raise ValueError(msg)
+ params[key] = value[0]
+ return params
+
+
+def update_query_params(uri, params):
+ """Updates a URI with new query parameters.
+
+ If a given key from ``params`` is repeated in the ``uri``, then
+ the URI will be considered invalid and an error will occur.
+
+ If the URI is valid, then each value from ``params`` will
+ replace the corresponding value in the query parameters (if
+ it exists).
+
+ Args:
+ uri: string, A valid URI, with potential existing query parameters.
+ params: dict, A dictionary of query parameters.
+
+ Returns:
+ The same URI but with the new query parameters added.
+ """
+ parts = urllib.parse.urlparse(uri)
+ query_params = parse_unique_urlencoded(parts.query)
+ query_params.update(params)
+ new_query = urllib.parse.urlencode(query_params)
+ new_parts = parts._replace(query=new_query)
+ return urllib.parse.urlunparse(new_parts)
+
+
def _add_query_parameter(url, name, value):
"""Adds a query parameter to a url.
@@ -195,11 +243,7 @@ def _add_query_parameter(url, name, value):
if value is None:
return url
else:
- parsed = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(parsed[4]))
- query[name] = value
- parsed[4] = urllib.parse.urlencode(query)
- return urllib.parse.urlunparse(parsed)
+ return update_query_params(url, {name: value})
def validate_file(filename):
diff --git a/oauth2client/client.py b/oauth2client/client.py
index 0497d07..704c610 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -438,23 +438,6 @@ class Storage(object):
self.release_lock()
-def _update_query_params(uri, params):
- """Updates a URI with new query parameters.
-
- Args:
- uri: string, A valid URI, with potential existing query parameters.
- params: dict, A dictionary of query parameters.
-
- Returns:
- The same URI but with the new query parameters added.
- """
- parts = urllib.parse.urlparse(uri)
- query_params = dict(urllib.parse.parse_qsl(parts.query))
- query_params.update(params)
- new_parts = parts._replace(query=urllib.parse.urlencode(query_params))
- return urllib.parse.urlunparse(new_parts)
-
-
class OAuth2Credentials(Credentials):
"""Credentials object for OAuth 2.0.
@@ -850,7 +833,8 @@ class OAuth2Credentials(Credentials):
"""
logger.info('Revoking token')
query_params = {'token': token}
- token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
+ token_revoke_uri = _helpers.update_query_params(
+ self.revoke_uri, query_params)
resp, content = transport.request(http, token_revoke_uri)
if resp.status == http_client.OK:
self.invalid = True
@@ -889,8 +873,8 @@ class OAuth2Credentials(Credentials):
"""
logger.info('Refreshing scopes')
query_params = {'access_token': token, 'fields': 'scope'}
- token_info_uri = _update_query_params(self.token_info_uri,
- query_params)
+ token_info_uri = _helpers.update_query_params(
+ self.token_info_uri, query_params)
resp, content = transport.request(http, token_info_uri)
content = _helpers._from_bytes(content)
if resp.status == http_client.OK:
@@ -1610,7 +1594,7 @@ def _parse_exchange_token_response(content):
except Exception:
# different JSON libs raise different exceptions,
# so we just do a catch-all here
- resp = dict(urllib.parse.parse_qsl(content))
+ resp = _helpers.parse_unique_urlencoded(content)
# some providers respond with 'expires', others with 'expires_in'
if resp and 'expires' in resp:
@@ -1943,7 +1927,7 @@ class OAuth2WebServerFlow(Flow):
query_params['code_challenge_method'] = 'S256'
query_params.update(self.params)
- return _update_query_params(self.auth_uri, query_params)
+ return _helpers.update_query_params(self.auth_uri, query_params)
@_helpers.positional(1)
def step1_get_device_and_user_codes(self, http=None):
diff --git a/oauth2client/tools.py b/oauth2client/tools.py
index 0aa671b..b882429 100644
--- a/oauth2client/tools.py
+++ b/oauth2client/tools.py
@@ -122,16 +122,16 @@ class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler):
if an error occurred.
"""
self.send_response(http_client.OK)
- self.send_header("Content-type", "text/html")
+ self.send_header('Content-type', 'text/html')
self.end_headers()
- query = self.path.split('?', 1)[-1]
- query = dict(urllib.parse.parse_qsl(query))
+ parts = urllib.parse.urlparse(self.path)
+ query = _helpers.parse_unique_urlencoded(parts.query)
self.server.query_params = query
self.wfile.write(
- b"<html><head><title>Authentication Status</title></head>")
+ b'<html><head><title>Authentication Status</title></head>')
self.wfile.write(
- b"<body><p>The authentication flow has completed.</p>")
- self.wfile.write(b"</body></html>")
+ b'<body><p>The authentication flow has completed.</p>')
+ self.wfile.write(b'</body></html>')
def log_message(self, format, *args):
"""Do not log messages to stdout while running as cmd. line program."""
diff --git a/tests/test__helpers.py b/tests/test__helpers.py
index aac5f8d..00cd38a 100644
--- a/tests/test__helpers.py
+++ b/tests/test__helpers.py
@@ -19,6 +19,7 @@ import unittest
import mock
from oauth2client import _helpers
+from tests import test_client
class PositionalTests(unittest.TestCase):
@@ -242,3 +243,42 @@ class Test__urlsafe_b64decode(unittest.TestCase):
bad_string = b'+'
with self.assertRaises((TypeError, binascii.Error)):
_helpers._urlsafe_b64decode(bad_string)
+
+
+class Test_update_query_params(unittest.TestCase):
+
+ def test_update_query_params_no_params(self):
+ uri = 'http://www.google.com'
+ updated = _helpers.update_query_params(uri, {'a': 'b'})
+ self.assertEqual(updated, uri + '?a=b')
+
+ def test_update_query_params_existing_params(self):
+ uri = 'http://www.google.com?x=y'
+ updated = _helpers.update_query_params(uri, {'a': 'b', 'c': 'd&'})
+ hardcoded_update = uri + '&a=b&c=d%26'
+ test_client.assertUrisEqual(self, updated, hardcoded_update)
+
+ def test_update_query_params_replace_param(self):
+ base_uri = 'http://www.google.com'
+ uri = base_uri + '?x=a'
+ updated = _helpers.update_query_params(uri, {'x': 'b', 'y': 'c'})
+ hardcoded_update = base_uri + '?x=b&y=c'
+ test_client.assertUrisEqual(self, updated, hardcoded_update)
+
+ def test_update_query_params_repeated_params(self):
+ uri = 'http://www.google.com?x=a&x=b'
+ with self.assertRaises(ValueError):
+ _helpers.update_query_params(uri, {'a': 'c'})
+
+
+class Test_parse_unique_urlencoded(unittest.TestCase):
+
+ def test_without_repeats(self):
+ content = 'a=b&c=d'
+ result = _helpers.parse_unique_urlencoded(content)
+ self.assertEqual(result, {'a': 'b', 'c': 'd'})
+
+ def test_with_repeats(self):
+ content = 'a=b&a=d'
+ with self.assertRaises(ValueError):
+ _helpers.parse_unique_urlencoded(content)
diff --git a/tests/test_client.py b/tests/test_client.py
index dbe11eb..a3268ba 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1364,7 +1364,7 @@ class BasicCredentialsTests(unittest.TestCase):
self.assertEqual(credentials.scopes, set())
self.assertEqual(exc_manager.exception.args, (error_msg,))
- token_uri = client._update_query_params(
+ token_uri = _helpers.update_query_params(
oauth2client.GOOGLE_TOKEN_INFO_URI,
{'fields': 'scope', 'access_token': token})
@@ -1558,19 +1558,6 @@ class TestAssertionCredentials(unittest.TestCase):
credentials.sign_blob(b'blob')
-class UpdateQueryParamsTest(unittest.TestCase):
- def test_update_query_params_no_params(self):
- uri = 'http://www.google.com'
- updated = client._update_query_params(uri, {'a': 'b'})
- self.assertEqual(updated, uri + '?a=b')
-
- def test_update_query_params_existing_params(self):
- uri = 'http://www.google.com?x=y'
- updated = client._update_query_params(uri, {'a': 'b', 'c': 'd&'})
- hardcoded_update = uri + '&a=b&c=d%26'
- assertUrisEqual(self, updated, hardcoded_update)
-
-
class ExtractIdTokenTest(unittest.TestCase):
"""Tests client._extract_id_token()."""
@@ -1670,7 +1657,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
'access_type': 'offline',
'response_type': 'code',
}
- expected = client._update_query_params(flow.auth_uri, query_params)
+ expected = _helpers.update_query_params(flow.auth_uri, query_params)
assertUrisEqual(self, expected, result)
# Check stubs.
self.assertEqual(logger.warning.call_count, 1)
@@ -1735,7 +1722,7 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
'access_type': 'offline',
'response_type': 'code',
}
- expected = client._update_query_params(flow.auth_uri, query_params)
+ expected = _helpers.update_query_params(flow.auth_uri, query_params)
assertUrisEqual(self, expected, result)
def test_step1_get_device_and_user_codes_wo_device_uri(self):