diff --git a/stacklet/client/sinistral/cli.py b/stacklet/client/sinistral/cli.py index 657ada3..ed0e8e9 100644 --- a/stacklet/client/sinistral/cli.py +++ b/stacklet/client/sinistral/cli.py @@ -5,7 +5,7 @@ from pathlib import Path -from stacklet.client.sinistral.cognito import CognitoClientAuth, CognitoUserManager +from stacklet.client.sinistral.cognito import CognitoUserManager from stacklet.client.sinistral.commands import commands from stacklet.client.sinistral.config import StackletConfig from stacklet.client.sinistral.context import StackletContext @@ -149,22 +149,12 @@ def login(ctx, username, password, *args, **kwargs): # Otherwise, prefer non-interactive auth to interactive. if context.can_project_auth(): - client = CognitoClientAuth(ctx) - token = client.get_access_token( - context.config.auth_url, - context.config.project_client_id, - context.config.project_client_secret, - ) + token = context.do_project_auth() context.write_access_token(token) return if context.can_org_auth(): - client = CognitoClientAuth(ctx) - token = client.get_access_token( - context.config.auth_url, - context.config.org_client_id, - context.config.org_client_secret, - ) + token = context.do_org_auth() context.write_access_token(token) return diff --git a/stacklet/client/sinistral/client.py b/stacklet/client/sinistral/client.py index 84b71fb..20f5300 100644 --- a/stacklet/client/sinistral/client.py +++ b/stacklet/client/sinistral/client.py @@ -182,6 +182,8 @@ def make_request( ): with self.ctx as context: token = context.get_access_token() + if not token: + raise Exception("Unauthorized, check credentials") executor = RestExecutor(context, token) func = getattr(executor, method) if isinstance(_json, str): diff --git a/stacklet/client/sinistral/commands/run.py b/stacklet/client/sinistral/commands/run.py index c38a3b5..1dadf65 100644 --- a/stacklet/client/sinistral/commands/run.py +++ b/stacklet/client/sinistral/commands/run.py @@ -74,10 +74,14 @@ def run(ctx, project, dryrun, *args, **kwargs): policy_collections_client = sinistral.client("policy-collections") results = [] - project_data = projects_client.get(name=SinistralFormat.project) + try: + project_data = projects_client.get(name=SinistralFormat.project) + except Exception as e: + click.echo(f"Unable to get project: {e}", err=True) + sys.exit(1) if not project_data.get("collections"): - click.echo("Project has no policy collections") + click.echo("Project has no policy collections", err=True) sys.exit(1) for c in project_data["collections"]: diff --git a/stacklet/client/sinistral/context.py b/stacklet/client/sinistral/context.py index baad325..869d5b4 100644 --- a/stacklet/client/sinistral/context.py +++ b/stacklet/client/sinistral/context.py @@ -4,6 +4,8 @@ import click +from stacklet.client.sinistral.cognito import CognitoClientAuth + class StackletContext: """ @@ -13,6 +15,9 @@ class StackletContext: CREDENTIALS_FILE = "credentials" ID_FILE = "id" + # cache access token + _ACCESS_TOKEN = None + def __init__(self, click_context): self._click_context = click_context self.config = click_context.obj["config"] @@ -24,16 +29,24 @@ def __init__(self, click_context): self.id_token_path = base_path / self.ID_FILE def get_access_token(self): + if StackletContext._ACCESS_TOKEN: + return StackletContext._ACCESS_TOKEN + access_token = None if self.access_token_path.exists(): - return self.access_token_path.read_text() - else: - return None + access_token = self.access_token_path.read_text() + elif self.can_project_auth(): + access_token = self.do_project_auth() + elif self.can_org_auth(): + access_token = self.do_org_auth() + StackletContext._ACCESS_TOKEN = access_token + return access_token def _write_token(self, token_path, token): token_path.parent.mkdir(parents=True, exist_ok=True) token_path.write_text(token) def write_access_token(self, token): + StackletContext._ACCESS_TOKEN = token self._write_token(self.access_token_path, token) def get_id_token(self): @@ -76,6 +89,15 @@ def can_project_auth(self): ] ) + def do_project_auth(self) -> str: + client = CognitoClientAuth(self._click_context) + token = client.get_access_token( + self.config.auth_url, + self.config.project_client_id, + self.config.project_client_secret, + ) + return token + def can_org_auth(self): return all( [ @@ -84,6 +106,15 @@ def can_org_auth(self): ] ) + def do_org_auth(self) -> str: + client = CognitoClientAuth(self._click_context) + token = client.get_access_token( + self.config.auth_url, + self.config.org_client_id, + self.config.org_client_secret, + ) + return token + class StackletCredentialWriter: def write_id_token(self, token): diff --git a/tests/test_client.py b/tests/test_client.py index 72cfce6..75d2e6c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -16,6 +16,7 @@ sinistral_client, ) from stacklet.client.sinistral.executor import RestExecutor +from stacklet.client.sinistral.context import StackletContext @pytest.fixture(autouse=True, scope="module") @@ -24,11 +25,56 @@ def mock_current_context(): ctx().obj = { "output": "raw", "formatter": None, - "config": MagicMock(), + "config": MagicMock( + project_client_id=None, + project_client_secret=None, + org_client_id=None, + org_client_secret=None, + ), } yield ctx +@pytest.fixture(autouse=True, scope="module") +def mock_access_token(): + with patch.object(StackletContext, "_ACCESS_TOKEN", new="token"): + yield + + +@pytest.fixture() +def empty_access_token(): + with patch.object(StackletContext, "_ACCESS_TOKEN", new=None): + yield + + +@pytest.fixture() +def mock_project_creds(mock_current_context): + mock_config = MagicMock( + project_client_id="client-id", + project_client_secret="client-secret", + org_client_id=None, + org_client_secret=None, + ) + with patch.dict(mock_current_context().obj, config=mock_config): + with patch.object(StackletContext, "do_project_auth", return_value="token"): + with patch.object(StackletContext, "do_org_auth", return_value=None): + yield + + +@pytest.fixture() +def mock_org_creds(mock_current_context): + mock_config = MagicMock( + project_client_id=None, + project_client_secret=None, + org_client_id="client-id", + org_client_secret="client-secret", + ) + with patch.dict(mock_current_context().obj, config=mock_config): + with patch.object(StackletContext, "do_project_auth", return_value=None): + with patch.object(StackletContext, "do_org_auth", return_value="token"): + yield + + sample_schema = { "$id": "https://example.com/person.schema.json", "$schema": "https://json-schema.org/draft/2020-12/schema", @@ -209,6 +255,28 @@ def test_client_command_unauthorized(): client.delete(name="foo") +def test_client_command_auto_project_auth(empty_access_token, mock_project_creds): + client = sinistral_client().client("projects") + with patch.object( + RestExecutor, "get", return_value=get_mock_response(json={"foo": "bar"}) + ): + res = client.list() + assert res == {"foo": "bar"} + assert StackletContext.do_project_auth.called + assert not StackletContext.do_org_auth.called + + +def test_client_command_auto_org_auth(empty_access_token, mock_org_creds): + client = sinistral_client().client("projects") + with patch.object( + RestExecutor, "get", return_value=get_mock_response(json={"foo": "bar"}) + ): + res = client.list() + assert res == {"foo": "bar"} + assert not StackletContext.do_project_auth.called + assert StackletContext.do_org_auth.called + + def test_client_command_other_error(): client = sinistral_client().client("projects") with patch.object(