Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions stacklet/client/sinistral/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pathlib import Path

from stacklet.client.sinistral.cognito import CognitoClientAuth, CognitoUserManager
from stacklet.client.sinistral.cognito import CognitoUserManager
from stacklet.client.sinistral.commands import commands
from stacklet.client.sinistral.config import StackletConfig
from stacklet.client.sinistral.context import StackletContext
Expand Down Expand Up @@ -149,22 +149,12 @@ def login(ctx, username, password, *args, **kwargs):

# Otherwise, prefer non-interactive auth to interactive.
if context.can_project_auth():
client = CognitoClientAuth(ctx)
token = client.get_access_token(
context.config.auth_url,
context.config.project_client_id,
context.config.project_client_secret,
)
token = context.do_project_auth()
context.write_access_token(token)
return

if context.can_org_auth():
client = CognitoClientAuth(ctx)
token = client.get_access_token(
context.config.auth_url,
context.config.org_client_id,
context.config.org_client_secret,
)
token = context.do_org_auth()
context.write_access_token(token)
return

Expand Down
2 changes: 2 additions & 0 deletions stacklet/client/sinistral/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def make_request(
):
with self.ctx as context:
token = context.get_access_token()
if not token:
raise Exception("Unauthorized, check credentials")
executor = RestExecutor(context, token)
func = getattr(executor, method)
if isinstance(_json, str):
Expand Down
8 changes: 6 additions & 2 deletions stacklet/client/sinistral/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,14 @@ def run(ctx, project, dryrun, *args, **kwargs):
policy_collections_client = sinistral.client("policy-collections")

results = []
project_data = projects_client.get(name=SinistralFormat.project)
try:
project_data = projects_client.get(name=SinistralFormat.project)
except Exception as e:
click.echo(f"Unable to get project: {e}", err=True)
sys.exit(1)

if not project_data.get("collections"):
click.echo("Project has no policy collections")
click.echo("Project has no policy collections", err=True)
sys.exit(1)

for c in project_data["collections"]:
Expand Down
37 changes: 34 additions & 3 deletions stacklet/client/sinistral/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import click

from stacklet.client.sinistral.cognito import CognitoClientAuth


class StackletContext:
"""
Expand All @@ -13,6 +15,9 @@ class StackletContext:
CREDENTIALS_FILE = "credentials"
ID_FILE = "id"

# cache access token
_ACCESS_TOKEN = None

def __init__(self, click_context):
self._click_context = click_context
self.config = click_context.obj["config"]
Expand All @@ -24,16 +29,24 @@ def __init__(self, click_context):
self.id_token_path = base_path / self.ID_FILE

def get_access_token(self):
if StackletContext._ACCESS_TOKEN:
return StackletContext._ACCESS_TOKEN
access_token = None
if self.access_token_path.exists():
return self.access_token_path.read_text()
else:
return None
access_token = self.access_token_path.read_text()
elif self.can_project_auth():
access_token = self.do_project_auth()
elif self.can_org_auth():
access_token = self.do_org_auth()
StackletContext._ACCESS_TOKEN = access_token
return access_token

def _write_token(self, token_path, token):
token_path.parent.mkdir(parents=True, exist_ok=True)
token_path.write_text(token)

def write_access_token(self, token):
StackletContext._ACCESS_TOKEN = token
self._write_token(self.access_token_path, token)

def get_id_token(self):
Expand Down Expand Up @@ -76,6 +89,15 @@ def can_project_auth(self):
]
)

def do_project_auth(self) -> str:
client = CognitoClientAuth(self._click_context)
token = client.get_access_token(
self.config.auth_url,
self.config.project_client_id,
self.config.project_client_secret,
)
return token

def can_org_auth(self):
return all(
[
Expand All @@ -84,6 +106,15 @@ def can_org_auth(self):
]
)

def do_org_auth(self) -> str:
client = CognitoClientAuth(self._click_context)
token = client.get_access_token(
self.config.auth_url,
self.config.org_client_id,
self.config.org_client_secret,
)
return token


class StackletCredentialWriter:
def write_id_token(self, token):
Expand Down
70 changes: 69 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
sinistral_client,
)
from stacklet.client.sinistral.executor import RestExecutor
from stacklet.client.sinistral.context import StackletContext


@pytest.fixture(autouse=True, scope="module")
Expand All @@ -24,11 +25,56 @@ def mock_current_context():
ctx().obj = {
"output": "raw",
"formatter": None,
"config": MagicMock(),
"config": MagicMock(
project_client_id=None,
project_client_secret=None,
org_client_id=None,
org_client_secret=None,
),
}
yield ctx


@pytest.fixture(autouse=True, scope="module")
def mock_access_token():
with patch.object(StackletContext, "_ACCESS_TOKEN", new="token"):
yield


@pytest.fixture()
def empty_access_token():
with patch.object(StackletContext, "_ACCESS_TOKEN", new=None):
yield


@pytest.fixture()
def mock_project_creds(mock_current_context):
mock_config = MagicMock(
project_client_id="client-id",
project_client_secret="client-secret",
org_client_id=None,
org_client_secret=None,
)
with patch.dict(mock_current_context().obj, config=mock_config):
with patch.object(StackletContext, "do_project_auth", return_value="token"):
with patch.object(StackletContext, "do_org_auth", return_value=None):
yield


@pytest.fixture()
def mock_org_creds(mock_current_context):
mock_config = MagicMock(
project_client_id=None,
project_client_secret=None,
org_client_id="client-id",
org_client_secret="client-secret",
)
with patch.dict(mock_current_context().obj, config=mock_config):
with patch.object(StackletContext, "do_project_auth", return_value=None):
with patch.object(StackletContext, "do_org_auth", return_value="token"):
yield


sample_schema = {
"$id": "https://example.com/person.schema.json",
"$schema": "https://json-schema.org/draft/2020-12/schema",
Expand Down Expand Up @@ -209,6 +255,28 @@ def test_client_command_unauthorized():
client.delete(name="foo")


def test_client_command_auto_project_auth(empty_access_token, mock_project_creds):
client = sinistral_client().client("projects")
with patch.object(
RestExecutor, "get", return_value=get_mock_response(json={"foo": "bar"})
):
res = client.list()
assert res == {"foo": "bar"}
assert StackletContext.do_project_auth.called
assert not StackletContext.do_org_auth.called


def test_client_command_auto_org_auth(empty_access_token, mock_org_creds):
client = sinistral_client().client("projects")
with patch.object(
RestExecutor, "get", return_value=get_mock_response(json={"foo": "bar"})
):
res = client.list()
assert res == {"foo": "bar"}
assert not StackletContext.do_project_auth.called
assert StackletContext.do_org_auth.called


def test_client_command_other_error():
client = sinistral_client().client("projects")
with patch.object(
Expand Down