Coverage for src / dataknobs_data / backends / postgres_mixins.py: 44%

62 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 15:45 -0700

1"""Shared mixins for PostgreSQL database backends. 

2 

3These mixins provide common functionality for both sync and async PostgreSQL implementations, 

4reducing code duplication and ensuring consistent behavior. 

5""" 

6 

import logging
from typing import Any, NoReturn

from ..records import Record
from .vector_config_mixin import VectorConfigMixin

12 

13logger = logging.getLogger(__name__) 

14 

15 

class PostgresBaseConfig(VectorConfigMixin):
    """Configuration parsing and attribute setup shared by PostgreSQL backends."""

    def _parse_postgres_config(self, config: dict[str, Any]) -> tuple[str, str, dict]:
        """Split a backend config into table name, schema name and the remainder.

        The caller's dictionary is never mutated: a shallow copy is taken first,
        and a falsy ``config`` is treated as an empty dict. The copy is passed to
        ``_parse_vector_config`` before any keys are removed from it.

        The table name is taken from ``table``, falling back to ``table_name``
        and then ``"records"``. The schema name is taken from ``schema``, then
        ``schema_name``, then ``"public"``. All four of those keys, plus
        ``vector_enabled`` and ``vector_metric``, are stripped from the returned
        connection config.

        Args:
            config: Configuration dictionary.

        Returns:
            Tuple of (table_name, schema_name, connection_config).
        """
        remaining = config.copy() if config else {}

        # Let the vector mixin inspect the full config first.
        self._parse_vector_config(remaining)

        # Both alias keys are always consumed; the short form wins when both exist.
        table_alias = remaining.pop("table_name", "records")
        table_name = remaining.pop("table", table_alias)
        schema_alias = remaining.pop("schema_name", "public")
        schema_name = remaining.pop("schema", schema_alias)

        # Vector settings were handled by _parse_vector_config above, so they
        # must not leak into the connection parameters.
        for vector_key in ("vector_enabled", "vector_metric"):
            remaining.pop(vector_key, None)

        return table_name, schema_name, remaining

    def _init_postgres_attributes(self, table_name: str, schema_name: str) -> None:
        """Store table/schema names, mark the backend disconnected, reset vector state.

        Args:
            table_name: Name of the database table.
            schema_name: Name of the database schema.
        """
        self.table_name, self.schema_name = table_name, schema_name
        self._connected = False

        # Vector bookkeeping is owned by VectorConfigMixin.
        self._init_vector_state()

56 

57 

class PostgresTableManager:
    """Shared table management SQL and logic."""

    @staticmethod
    def get_create_table_sql(schema_name: str, table_name: str) -> str:
        """Get SQL for creating the records table with indexes.

        The generated script creates ``schema_name.table_name`` (if missing)
        with ``id``/``data``/``metadata``/timestamp columns, plus GIN indexes
        on the ``data`` and ``metadata`` JSONB columns.

        NOTE: both names are interpolated as *unquoted* SQL identifiers. They
        must therefore come from trusted configuration and be valid PostgreSQL
        identifiers. PostgreSQL folds unquoted identifiers to lower case.

        Args:
            schema_name: Database schema name.
            table_name: Database table name.

        Returns:
            SQL string for table creation.
        """
        return f"""
        CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (
            id TEXT PRIMARY KEY,
            data JSONB NOT NULL,
            metadata JSONB,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );

        CREATE INDEX IF NOT EXISTS idx_{table_name}_data
        ON {schema_name}.{table_name} USING GIN (data);

        CREATE INDEX IF NOT EXISTS idx_{table_name}_metadata
        ON {schema_name}.{table_name} USING GIN (metadata);
        """

    @staticmethod
    def get_table_exists_sql(schema_name: str, table_name: str) -> str:
        """Get SQL to check if table exists.

        The names are embedded as SQL string literals, so any single quote
        they contain is doubled (standard SQL escaping). Without this, a quote
        in a name would break the statement or allow injection.

        NOTE(review): the comparison against ``information_schema.tables`` is
        case-sensitive. ``get_create_table_sql`` uses unquoted identifiers,
        which PostgreSQL folds to lower case, so mixed-case names may not be
        found here. Confirm that callers pass lower-case names.

        Args:
            schema_name: Database schema name.
            table_name: Database table name.

        Returns:
            SQL string to check table existence.
        """
        schema_literal = schema_name.replace("'", "''")
        table_literal = table_name.replace("'", "''")
        return f"""
        SELECT EXISTS (
            SELECT FROM information_schema.tables
            WHERE table_schema = '{schema_literal}'
            AND table_name = '{table_literal}'
        )
        """

106 

107 

class PostgresVectorSupport:
    """Helpers for detecting vector fields on records and tracking their sizes."""

    def _has_vector_fields(self, record: Record) -> bool:
        """Report whether any field of ``record`` is a ``VectorField``.

        Args:
            record: Record to inspect.

        Returns:
            True as soon as one vector field is found, otherwise False.
        """
        from ..fields import VectorField

        for candidate in record.fields.values():
            if isinstance(candidate, VectorField):
                return True
        return False

    def _extract_vector_dimensions(self, record: Record) -> dict[str, int]:
        """Collect ``{field name: dimensions}`` for the record's vector fields.

        Vector fields whose ``dimensions`` is falsy are skipped.

        Args:
            record: Record that may contain vector fields.

        Returns:
            Mapping of field name to its ``dimensions`` value.
        """
        from ..fields import VectorField

        return {
            field_name: vector.dimensions
            for field_name, vector in record.fields.items()
            if isinstance(vector, VectorField) and vector.dimensions
        }

    def _update_vector_dimensions(self, record: Record) -> None:
        """Merge the record's vector dimensions into ``self._vector_dimensions``.

        Does nothing when the instance has no ``_vector_dimensions`` attribute.

        Args:
            record: Record that may contain vector fields.
        """
        if not hasattr(self, '_vector_dimensions'):
            return
        self._vector_dimensions.update(self._extract_vector_dimensions(record))

149 

150 

class PostgresErrorHandler:
    """Shared error handling logic for PostgreSQL operations.

    The ``handle_*`` helpers log the failure and then always raise. They are
    annotated ``NoReturn`` so type checkers know that code following a call
    to them is unreachable.
    """

    @staticmethod
    def handle_connection_error(e: Exception) -> NoReturn:
        """Log a connection failure and re-raise it as ``RuntimeError``.

        Args:
            e: The exception that occurred.

        Raises:
            RuntimeError: Always, with a user-friendly message. The original
                exception is attached as ``__cause__``.
        """
        # Lazy %-style args: the message is only formatted if the record is emitted.
        logger.error("PostgreSQL connection error: %s", e)
        raise RuntimeError(f"Database connection failed: {e}") from e

    @staticmethod
    def handle_query_error(e: Exception, operation: str) -> NoReturn:
        """Log a query failure and re-raise it as ``RuntimeError``.

        Args:
            e: The exception that occurred.
            operation: The operation that failed (e.g., "create", "update").

        Raises:
            RuntimeError: Always, with a user-friendly message. The original
                exception is attached as ``__cause__``.
        """
        logger.error("PostgreSQL %s error: %s", operation, e)
        raise RuntimeError(f"Database {operation} failed: {e}") from e

    @staticmethod
    def log_operation(operation: str, details: str = "") -> None:
        """Log a database operation at DEBUG level.

        Args:
            operation: The operation being performed.
            details: Additional details; omitted from the message when empty.
        """
        # Lazy args avoid building the string when DEBUG logging is disabled.
        if details:
            logger.debug("PostgreSQL %s: %s", operation, details)
        else:
            logger.debug("PostgreSQL %s", operation)

193 

194 

class PostgresConnectionValidator:
    """Guards that refuse to proceed unless the backend reports a live connection."""

    def _check_connection(self) -> None:
        """Ensure the instance's ``_connected`` flag is truthy.

        A missing ``_connected`` attribute counts as not connected.

        Raises:
            RuntimeError: If the backend is not connected.
        """
        if getattr(self, '_connected', False):
            return
        raise RuntimeError("Database not connected. Call connect() first.")

    def _check_async_connection(self) -> None:
        """Ensure both the ``_connected`` flag and the ``_pool`` attribute are truthy.

        Missing attributes count as falsy. ``_pool`` is only looked up once
        ``_connected`` has passed.

        Raises:
            RuntimeError: If not connected or the pool is not initialized.
        """
        if getattr(self, '_connected', False) and getattr(self, '_pool', None):
            return
        raise RuntimeError("Database not connected. Call connect() first.")