diff options
author | Pat Ferate <pferate@users.noreply.github.com> | 2016-08-12 14:15:31 -0700 |
---|---|---|
committer | Jon Wayne Parrott <jonwayne@google.com> | 2016-08-12 14:15:31 -0700 |
commit | 5137d7e8377266ef4beffe1c59c638c05c82cf10 (patch) | |
tree | 0d3744551ca792253d184e8f65a9721d29e434ec | |
parent | c9b4b07525730338f2e560981b3fbe295d2146ab (diff) | |
download | oauth2client-5137d7e8377266ef4beffe1c59c638c05c82cf10.tar.gz |
Complete branches from partial test coverages (#629)
-rw-r--r-- | tests/contrib/test_sqlalchemy.py | 26 | ||||
-rw-r--r-- | tests/test_client.py | 59 |
2 files changed, 66 insertions, 19 deletions
diff --git a/tests/contrib/test_sqlalchemy.py b/tests/contrib/test_sqlalchemy.py index 67762f6..068aa92 100644 --- a/tests/contrib/test_sqlalchemy.py +++ b/tests/contrib/test_sqlalchemy.py @@ -15,6 +15,7 @@ import datetime import unittest +import mock import sqlalchemy import sqlalchemy.ext.declarative import sqlalchemy.orm @@ -66,7 +67,8 @@ class TestSQLAlchemyStorage(unittest.TestCase): self.assertEqual(result.token_uri, self.credentials.token_uri) self.assertEqual(result.user_agent, self.credentials.user_agent) - def test_get(self): + @mock.patch('oauth2client.client.OAuth2Credentials.set_store') + def test_get(self, set_store): session = self.session() credentials_storage = oauth2client.contrib.sqlalchemy.Storage( session=session, @@ -75,7 +77,21 @@ class TestSQLAlchemyStorage(unittest.TestCase): key_value=1, property_name='credentials', ) + # No credentials stored self.assertIsNone(credentials_storage.get()) + + # Invalid credentials stored + session.add(DummyModel( + key=1, + credentials=oauth2client.client.Credentials(), + )) + session.commit() + bad_credentials = credentials_storage.get() + self.assertIsInstance(bad_credentials, oauth2client.client.Credentials) + set_store.assert_not_called() + + # Valid credentials stored + session.query(DummyModel).filter_by(key=1).delete() session.add(DummyModel( key=1, credentials=self.credentials, @@ -83,16 +99,20 @@ class TestSQLAlchemyStorage(unittest.TestCase): session.commit() self.compare_credentials(credentials_storage.get()) + set_store.assert_called_with(credentials_storage) def test_put(self): session = self.session() - oauth2client.contrib.sqlalchemy.Storage( + storage = oauth2client.contrib.sqlalchemy.Storage( session=session, model_class=DummyModel, key_name='key', key_value=1, property_name='credentials', - ).put(self.credentials) + ) + # Store invalid credentials first to verify overwriting + storage.put(oauth2client.client.Credentials()) + storage.put(self.credentials) session.commit() entity = session.query(DummyModel).filter_by(key=1).first() diff --git a/tests/test_client.py b/tests/test_client.py index 49a9210..27f24d8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1619,6 +1619,9 @@ class OAuth2WebServerFlowTest(unittest.TestCase): user_agent='unittest-sample/1.0', revoke_uri='dummy_revoke_uri', ) + self.bad_verifier = b'__NOT_THE_VERIFIER_YOURE_LOOKING_FOR__' + self.good_verifier = b'__TEST_VERIFIER__' + self.good_challenger = b'__TEST_CHALLENGE__' def test_construct_authorize_url(self): authorize_url = self.flow.step1_get_authorize_url(state='state+1') @@ -1691,19 +1694,42 @@ class OAuth2WebServerFlowTest(unittest.TestCase): @mock.patch('oauth2client.client._pkce.code_challenge') @mock.patch('oauth2client.client._pkce.code_verifier') def test_step1_get_authorize_url_pkce(self, fake_verifier, fake_challenge): - fake_verifier.return_value = b'__TEST_VERIFIER__' - fake_challenge.return_value = b'__TEST_CHALLENGE__' + fake_verifier.return_value = self.good_verifier + fake_challenge.return_value = self.good_challenger flow = client.OAuth2WebServerFlow( - 'client_id+1', - scope='foo', - redirect_uri='http://example.com', - pkce=True) + 'client_id+1', + scope='foo', + redirect_uri='http://example.com', + pkce=True) + auth_url = urllib.parse.urlparse(flow.step1_get_authorize_url()) + self.assertEqual(flow.code_verifier, self.good_verifier) + results = dict(urllib.parse.parse_qsl(auth_url.query)) + self.assertEqual( + results['code_challenge'], self.good_challenger.decode()) + self.assertEqual(results['code_challenge_method'], 'S256') + fake_verifier.assert_called() + fake_challenge.assert_called_with(self.good_verifier) + + @mock.patch('oauth2client.client._pkce.code_challenge') + @mock.patch('oauth2client.client._pkce.code_verifier') + def test_step1_get_authorize_url_pkce_invalid_verifier( + self, fake_verifier, fake_challenge): + fake_verifier.return_value = self.good_verifier + fake_challenge.return_value = self.good_challenger + flow = client.OAuth2WebServerFlow( + 'client_id+1', + scope='foo', + redirect_uri='http://example.com', + pkce=True, + code_verifier=self.bad_verifier) auth_url = urllib.parse.urlparse(flow.step1_get_authorize_url()) - self.assertEqual(flow.code_verifier, b'__TEST_VERIFIER__') + self.assertEqual(flow.code_verifier, self.bad_verifier) results = dict(urllib.parse.parse_qsl(auth_url.query)) - self.assertEqual(results['code_challenge'], '__TEST_CHALLENGE__') + self.assertEqual( + results['code_challenge'], self.good_challenger.decode()) self.assertEqual(results['code_challenge_method'], 'S256') - fake_challenge.assert_called_with(b'__TEST_VERIFIER__') + fake_verifier.assert_not_called() + fake_challenge.assert_called_with(self.bad_verifier) def test_step1_get_authorize_url_without_redirect(self): flow = client.OAuth2WebServerFlow('client_id+1', scope='foo', @@ -1955,17 +1981,18 @@ class OAuth2WebServerFlowTest(unittest.TestCase): ({'status': http_client.OK}, b'access_token=SlAV32hkKG'), ]) flow = client.OAuth2WebServerFlow( - 'client_id+1', - scope='foo', - redirect_uri='http://example.com', - pkce=True, - code_verifier=b'__TEST_VERIFIER__' - ) + 'client_id+1', + scope='foo', + redirect_uri='http://example.com', + pkce=True, + code_verifier=self.good_verifier) flow.step2_exchange(code='some random code', http=http) self.assertEqual(len(http.requests), 1) test_request = http.requests[0] - self.assertIn('code_verifier=__TEST_VERIFIER__', test_request['body']) + self.assertIn( + 'code_verifier={0}'.format(self.good_verifier.decode()), + test_request['body']) def test_exchange_using_authorization_header(self): auth_header = 'Basic Y2xpZW50X2lkKzE6c2Vjexc_managerV0KzE=', |