Coverage for railway / core / type_check.py: 67%
45 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 00:06 +0900
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 00:06 +0900
1"""Type checking utilities for pipeline strict mode."""
3from typing import Any, Callable, Union, get_args, get_origin, get_type_hints
6def check_type_compatibility(value: Any, expected_type: type) -> bool:
7 """
8 Check if value is compatible with expected type.
10 Args:
11 value: The value to check.
12 expected_type: The expected type.
14 Returns:
15 True if compatible, False otherwise.
16 """
17 if expected_type is Any: 17 ↛ 18line 17 didn't jump to line 18 because the condition on line 17 was never true
18 return True
20 # Handle None
21 if value is None:
22 origin = get_origin(expected_type)
23 if origin is Union: 23 ↛ 26line 23 didn't jump to line 26 because the condition on line 23 was always true
24 args = get_args(expected_type)
25 return type(None) in args
26 return expected_type is type(None)
28 # Get origin for generic types
29 origin = get_origin(expected_type)
31 # Handle Union types (including Optional)
32 if origin is Union:
33 args = get_args(expected_type)
34 return any(check_type_compatibility(value, arg) for arg in args)
36 # Handle generic types (List, Dict, etc.)
37 if origin is not None:
38 # Basic check against origin type
39 if not isinstance(value, origin): 39 ↛ 40line 39 didn't jump to line 40 because the condition on line 39 was never true
40 return False
41 return True
43 # Simple isinstance check
44 try:
45 return isinstance(value, expected_type)
46 except TypeError:
47 # Some types can't be used with isinstance
48 return True
51def get_function_input_type(func: Callable) -> type | None:
52 """
53 Get the input type of a function's first parameter.
55 Args:
56 func: The function to inspect.
58 Returns:
59 The type of the first parameter, or Any if not specified.
60 """
61 # Get original function if wrapped
62 original = getattr(func, "_original_func", func)
64 try:
65 hints = get_type_hints(original)
66 # Get first parameter's type (excluding 'return')
67 params = [k for k in hints.keys() if k != "return"]
68 if params: 68 ↛ 72line 68 didn't jump to line 72 because the condition on line 68 was always true
69 return hints[params[0]]
70 except Exception:
71 pass
72 return Any
75def get_function_output_type(func: Callable) -> type | None:
76 """
77 Get the return type of a function.
79 Args:
80 func: The function to inspect.
82 Returns:
83 The return type, or Any if not specified.
84 """
85 # Get original function if wrapped
86 original = getattr(func, "_original_func", func)
88 try:
89 hints = get_type_hints(original)
90 return hints.get("return", Any)
91 except Exception:
92 return Any
def format_type_error(
    step_num: int,
    step_name: str,
    expected_type: type,
    actual_type: type,
    actual_value: Any,
) -> str:
    """
    Build the message reported for a pipeline type mismatch.

    Args:
        step_num: The step number in the pipeline.
        step_name: The name of the step.
        expected_type: The type the step expected to receive.
        actual_type: The type that was actually received.
        actual_value: The offending value; only the first 50 characters of
            its repr are included in the message.

    Returns:
        Formatted error message.
    """
    expected_label = _type_name(expected_type)
    actual_label = _type_name(actual_type)
    # Keep the message short even when the value has a very long repr.
    value_preview = repr(actual_value)[:50]
    return (
        f"Pipeline type mismatch at step {step_num} ({step_name}): "
        f"expected {expected_label}, "
        f"got {actual_label} (value: {value_preview})"
    )
def _type_name(t: type) -> str:
    """
    Get a readable name for a type.

    Plain classes are shown by their ``__name__``. Parameterised generics
    keep their type arguments (``list[int]``, ``Dict[str, int]``) so that a
    mismatch message does not collapse them to the bare container name.

    Args:
        t: The type to get name for.

    Returns:
        Human-readable type name.
    """
    # A parameterised generic may expose only its origin's name through
    # ``__name__``; use the full form and drop the noisy ``typing.`` prefix.
    if get_args(t):
        return str(t).replace("typing.", "")
    if hasattr(t, "__name__"):
        return t.__name__
    return str(t)