Skip to content

Commit 0be72e2

Browse files
authored
feat: type check parameter classes (#3406)
1 parent f103af4 commit 0be72e2

File tree

2 files changed

+38
-28
lines changed

2 files changed

+38
-28
lines changed

luigi/parameter.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ class UnconsumedParameterWarning(UserWarning):
142142

143143

144144
T = TypeVar("T", default=str)
145+
_OptT = TypeVar("_OptT")
145146

146147

147148
class ConfigPath(TypedDict):
@@ -426,21 +427,31 @@ def _parser_kwargs(cls, param_name, task_name=None):
426427
}
427428

428429

429-
class OptionalParameterMixin:
430+
class OptionalParameterMixin(Generic[_OptT]):
430431
"""
431432
Mixin to make a parameter class optional and treat empty string as None.
432433
"""
433434

434-
expected_type = type(None)
435+
expected_type: type = type(None)
436+
437+
def __init__(
438+
self,
439+
default: Union[_OptT, None, _NoValueType] = _no_value,
440+
**kwargs: Unpack[_ParameterKwargs],
441+
):
442+
super().__init__(default=default, **kwargs) # type: ignore[arg-type, call-arg, misc]
435443

436444
@overload
437-
def __get__(self: "Parameter[T]", instance: None, owner: Any) -> "Parameter[Optional[T]]": ...
445+
def __get__(self, instance: None, owner: Any) -> "Parameter[Optional[_OptT]]": ...
438446

439447
@overload
440-
def __get__(self: "Parameter[T]", instance: Any, owner: Any) -> Optional[T]: ...
448+
def __get__(self, instance: Any, owner: Any) -> Optional[_OptT]: ...
441449

442450
def __get__(self, instance: Any, owner: Any) -> Any:
443-
return super().__get__(instance, owner)
451+
return super().__get__(instance, owner) # type: ignore[misc]
452+
453+
def __set__(self, instance: Any, value: Optional[_OptT]):
454+
super().__set__(instance, value) # type: ignore[misc]
444455

445456
def serialize(self, x):
446457
"""
@@ -485,13 +496,13 @@ def next_in_enumeration(self, value):
485496
return None
486497

487498

488-
class OptionalParameter(OptionalParameterMixin, Parameter[Optional[str]]):
499+
class OptionalParameter(OptionalParameterMixin[str], Parameter[Optional[str]]):
489500
"""Class to parse optional parameters."""
490501

491502
expected_type = str
492503

493504

494-
class OptionalStrParameter(OptionalParameterMixin, Parameter[Optional[str]]):
505+
class OptionalStrParameter(OptionalParameterMixin[str], Parameter[Optional[str]]):
495506
"""Class to parse optional str parameters."""
496507

497508
expected_type = str
@@ -798,7 +809,7 @@ def next_in_enumeration(self, value):
798809
return value + 1
799810

800811

801-
class OptionalIntParameter(OptionalParameterMixin, IntParameter):
812+
class OptionalIntParameter(OptionalParameterMixin[int], IntParameter): # type: ignore[misc]
802813
"""Class to parse optional int parameters."""
803814

804815
expected_type = int
@@ -816,7 +827,7 @@ def parse(self, x):
816827
return float(x)
817828

818829

819-
class OptionalFloatParameter(OptionalParameterMixin, FloatParameter):
830+
class OptionalFloatParameter(OptionalParameterMixin[float], FloatParameter): # type: ignore[misc]
820831
"""Class to parse optional float parameters."""
821832

822833
expected_type = float
@@ -897,7 +908,7 @@ def _parser_kwargs(self, *args, **kwargs):
897908
return parser_kwargs
898909

899910

900-
class OptionalBoolParameter(OptionalParameterMixin, BoolParameter):
911+
class OptionalBoolParameter(OptionalParameterMixin[bool], BoolParameter): # type: ignore[misc]
901912
"""Class to parse optional bool parameters."""
902913

903914
expected_type = bool
@@ -1299,7 +1310,7 @@ def serialize(self, x):
12991310
return json.dumps(x, cls=_DictParamEncoder)
13001311

13011312

1302-
class OptionalDictParameter(OptionalParameterMixin, DictParameter):
1313+
class OptionalDictParameter(OptionalParameterMixin[FrozenOrderedDict], DictParameter): # type: ignore[misc]
13031314
"""Class to parse optional dict parameters."""
13041315

13051316
expected_type = FrozenOrderedDict
@@ -1454,7 +1465,7 @@ def serialize(self, x):
14541465
return json.dumps(x, cls=_DictParamEncoder)
14551466

14561467

1457-
class OptionalListParameter(OptionalParameterMixin, ListParameter):
1468+
class OptionalListParameter(OptionalParameterMixin[ListT], ListParameter): # type: ignore[misc]
14581469
"""Class to parse optional list parameters."""
14591470

14601471
expected_type = tuple
@@ -1525,7 +1536,7 @@ def _convert_iterable_to_tuple(self, x):
15251536
return tuple(x)
15261537

15271538

1528-
class OptionalTupleParameter(OptionalParameterMixin, TupleParameter):
1539+
class OptionalTupleParameter(OptionalParameterMixin[ListT], TupleParameter): # type: ignore[misc]
15291540
"""Class to parse optional tuple parameters."""
15301541

15311542
expected_type = tuple
@@ -1588,13 +1599,13 @@ def __init__(
15881599
"""
15891600
if var_type is None:
15901601
raise ParameterException("var_type must be specified")
1591-
self._var_type = var_type
1602+
self._var_type: Type[NumericalType] = var_type
15921603
if min_value is None:
15931604
raise ParameterException("min_value must be specified")
1594-
self._min_value = min_value
1605+
self._min_value: NumericalType = min_value
15951606
if max_value is None:
15961607
raise ParameterException("max_value must be specified")
1597-
self._max_value = max_value
1608+
self._max_value: NumericalType = max_value
15981609
self._left_op = left_op
15991610
self._right_op = right_op
16001611
self._permitted_range = "{var_type} in {left_endpoint}{min_value}, {max_value}{right_endpoint}".format(
@@ -1604,7 +1615,7 @@ def __init__(
16041615
left_endpoint="[" if left_op == operator.le else "(",
16051616
right_endpoint=")" if right_op == operator.lt else "]",
16061617
)
1607-
super().__init__(default=default, **kwargs)
1618+
super().__init__(default=default, **kwargs) # type: ignore[arg-type]
16081619
if self.description:
16091620
self.description += " "
16101621
else:
@@ -1619,15 +1630,15 @@ def parse(self, x):
16191630
raise ValueError("{s} is not in the set of {permitted_range}".format(s=x, permitted_range=self._permitted_range))
16201631

16211632

1622-
class OptionalNumericalParameter(OptionalParameterMixin, NumericalParameter):
1633+
class OptionalNumericalParameter(OptionalParameterMixin[NumericalType], NumericalParameter[NumericalType]): # type: ignore[misc]
16231634
"""Class to parse optional numerical parameters."""
16241635

16251636
def __init__(
16261637
self,
16271638
default: Union[Optional[NumericalType], _NoValueType] = _no_value,
16281639
**kwargs: Unpack[_ParameterKwargs],
16291640
):
1630-
super().__init__(default=default, **kwargs)
1641+
NumericalParameter.__init__(self, default=default, **kwargs) # type: ignore[arg-type, misc]
16311642
self.expected_type = self._var_type
16321643

16331644

@@ -1664,7 +1675,7 @@ def __init__(
16641675
default: Union[ChoiceType, _NoValueType] = _no_value,
16651676
*,
16661677
choices: Optional[Sequence[ChoiceType]] = None,
1667-
var_type: Type[ChoiceType] = str,
1678+
var_type: Type[ChoiceType] = str, # type: ignore[assignment]
16681679
**kwargs: Unpack[_ParameterKwargs],
16691680
):
16701681
"""
@@ -1726,7 +1737,7 @@ class MyTask(luigi.Task):
17261737

17271738
_sep = ","
17281739

1729-
@overload
1740+
@overload # type: ignore[override]
17301741
def __get__(self, instance: None, owner: Any) -> "Parameter[Tuple[ChoiceType, ...]]": ...
17311742

17321743
@overload
@@ -1738,7 +1749,7 @@ def __get__(self, instance: Any, owner: Any) -> Any:
17381749
def __init__(
17391750
self,
17401751
default: Union[Tuple[ChoiceType, ...], _NoValueType] = _no_value,
1741-
var_type: Type[ChoiceType] = str,
1752+
var_type: Type[ChoiceType] = str, # type: ignore[assignment]
17421753
choices: Optional[Sequence[ChoiceType]] = None,
17431754
**kwargs: Unpack[_ParameterKwargs],
17441755
):
@@ -1758,17 +1769,17 @@ def serialize(self, x):
17581769
return self._sep.join(x)
17591770

17601771

1761-
class OptionalChoiceParameter(OptionalParameterMixin, ChoiceParameter[ChoiceType]):
1772+
class OptionalChoiceParameter(OptionalParameterMixin[ChoiceType], ChoiceParameter[ChoiceType]): # type: ignore[misc]
17621773
"""Class to parse optional choice parameters."""
17631774

17641775
def __init__(
17651776
self,
17661777
default: Union[Optional[ChoiceType], _NoValueType] = _no_value,
1767-
var_type: Type[ChoiceType] = str,
1778+
var_type: Type[ChoiceType] = str, # type: ignore[assignment]
17681779
choices: Optional[Sequence[ChoiceType]] = None,
17691780
**kwargs: Unpack[_ParameterKwargs],
17701781
):
1771-
super().__init__(default=default, var_type=var_type, choices=choices, **kwargs) # type: ignore[arg-type]
1782+
ChoiceParameter.__init__(self, default=default, var_type=var_type, choices=choices, **kwargs) # type: ignore[arg-type, misc]
17721783
self.expected_type = self._var_type
17731784

17741785

@@ -1831,7 +1842,7 @@ def normalize(self, x):
18311842
return path
18321843

18331844

1834-
class OptionalPathParameter(OptionalParameter, PathParameter):
1845+
class OptionalPathParameter(OptionalParameter, PathParameter): # type: ignore[misc]
18351846
"""Class to parse optional path parameters."""
18361847

1837-
expected_type = (str, Path) # type: ignore
1848+
expected_type = (str, Path) # type: ignore[assignment]

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ module = [
191191
"luigi.contrib.sqla",
192192
"luigi.interface",
193193
"luigi.notifications",
194-
"luigi.parameter",
195194
"luigi.tools.range",
196195
"luigi.worker",
197196
]

0 commit comments

Comments
 (0)