Coverage for railway / core / type_check.py: 67%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 00:06 +0900

1"""Type checking utilities for pipeline strict mode.""" 

2 

import inspect
import types
from typing import Any, Callable, Union, get_args, get_origin, get_type_hints

4 

5 

6def check_type_compatibility(value: Any, expected_type: type) -> bool: 

7 """ 

8 Check if value is compatible with expected type. 

9 

10 Args: 

11 value: The value to check. 

12 expected_type: The expected type. 

13 

14 Returns: 

15 True if compatible, False otherwise. 

16 """ 

17 if expected_type is Any: 17 ↛ 18line 17 didn't jump to line 18 because the condition on line 17 was never true

18 return True 

19 

20 # Handle None 

21 if value is None: 

22 origin = get_origin(expected_type) 

23 if origin is Union: 23 ↛ 26line 23 didn't jump to line 26 because the condition on line 23 was always true

24 args = get_args(expected_type) 

25 return type(None) in args 

26 return expected_type is type(None) 

27 

28 # Get origin for generic types 

29 origin = get_origin(expected_type) 

30 

31 # Handle Union types (including Optional) 

32 if origin is Union: 

33 args = get_args(expected_type) 

34 return any(check_type_compatibility(value, arg) for arg in args) 

35 

36 # Handle generic types (List, Dict, etc.) 

37 if origin is not None: 

38 # Basic check against origin type 

39 if not isinstance(value, origin): 39 ↛ 40line 39 didn't jump to line 40 because the condition on line 39 was never true

40 return False 

41 return True 

42 

43 # Simple isinstance check 

44 try: 

45 return isinstance(value, expected_type) 

46 except TypeError: 

47 # Some types can't be used with isinstance 

48 return True 

49 

50 

51def get_function_input_type(func: Callable) -> type | None: 

52 """ 

53 Get the input type of a function's first parameter. 

54 

55 Args: 

56 func: The function to inspect. 

57 

58 Returns: 

59 The type of the first parameter, or Any if not specified. 

60 """ 

61 # Get original function if wrapped 

62 original = getattr(func, "_original_func", func) 

63 

64 try: 

65 hints = get_type_hints(original) 

66 # Get first parameter's type (excluding 'return') 

67 params = [k for k in hints.keys() if k != "return"] 

68 if params: 68 ↛ 72line 68 didn't jump to line 72 because the condition on line 68 was always true

69 return hints[params[0]] 

70 except Exception: 

71 pass 

72 return Any 

73 

74 

75def get_function_output_type(func: Callable) -> type | None: 

76 """ 

77 Get the return type of a function. 

78 

79 Args: 

80 func: The function to inspect. 

81 

82 Returns: 

83 The return type, or Any if not specified. 

84 """ 

85 # Get original function if wrapped 

86 original = getattr(func, "_original_func", func) 

87 

88 try: 

89 hints = get_type_hints(original) 

90 return hints.get("return", Any) 

91 except Exception: 

92 return Any 

93 

94 

def format_type_error(
    step_num: int,
    step_name: str,
    expected_type: type,
    actual_type: type,
    actual_value: Any,
) -> str:
    """
    Build the message describing a type mismatch between pipeline steps.

    Args:
        step_num: The step number in the pipeline.
        step_name: The name of the step.
        expected_type: The type the step expected to receive.
        actual_type: The type that was actually received.
        actual_value: The offending value; only the first 50 characters
            of its repr are included in the message.

    Returns:
        Formatted error message.
    """
    expected_label = _type_name(expected_type)
    actual_label = _type_name(actual_type)
    # Keep the message short even for very large values
    value_preview = repr(actual_value)[:50]

    location = f"Pipeline type mismatch at step {step_num} ({step_name}): "
    detail = f"expected {expected_label}, got {actual_label} (value: {value_preview})"
    return location + detail

120 

121 

def _type_name(t: type) -> str:
    """
    Get a readable name for a type.

    Plain classes are rendered by their ``__name__`` (e.g. ``int``).
    Parameterized types (``list[int]``, ``Optional[int]``, ``int | None``)
    are rendered in full so the type arguments are not lost; using
    ``__name__`` on those would typically yield only the origin's name.

    Args:
        t: The type to get name for.

    Returns:
        Human-readable type name.
    """
    # get_origin() is non-None for parameterized generics and unions;
    # those must go through str() to keep their arguments.
    if get_origin(t) is None and hasattr(t, "__name__"):
        return t.__name__
    # Drop the noisy module prefix from typing constructs
    return str(t).replace("typing.", "")

134 return str(t)