Coverage for src/dataknobs_data/vector/optimizations.py: 0% of 212 statements
coverage.py v7.13.0, created at 2025-12-26 15:45 -0700
1"""Vector store optimization and performance enhancements."""
3from __future__ import annotations
5import asyncio
6import logging
7from collections import deque
8from dataclasses import dataclass
9from threading import Lock
10from typing import Any, TYPE_CHECKING
12import numpy as np
14if TYPE_CHECKING:
15 from collections.abc import Callable
16 from .types import DistanceMetric
19logger = logging.getLogger(__name__)


@dataclass
class BatchConfig:
    """Configuration for batch operations."""

    size: int = 100
    max_queue_size: int = 10000
    flush_interval: float = 1.0  # seconds
    parallel_workers: int = 4
    retry_on_failure: bool = True
    max_retries: int = 3


@dataclass
class ConnectionPoolConfig:
    """Configuration for connection pooling."""

    min_connections: int = 1
    max_connections: int = 10
    connection_timeout: float = 30.0
    idle_timeout: float = 300.0
    recycle_timeout: float = 3600.0
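

# Illustrative sketch, not part of the original module: construct the two
# configs with selected defaults overridden; field names match the
# dataclasses above.
def _example_configs() -> tuple[BatchConfig, ConnectionPoolConfig]:
    batch_cfg = BatchConfig(size=500, flush_interval=0.5, parallel_workers=8)
    pool_cfg = ConnectionPoolConfig(max_connections=25, connection_timeout=10.0)
    return batch_cfg, pool_cfg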


class BatchProcessor:
    """Handles batch processing of vector operations."""

    def __init__(self, config: BatchConfig | None = None):
        """Initialize the batch processor.

        Args:
            config: Batch configuration
        """
        self.config = config or BatchConfig()
        self.queue: deque = deque(maxlen=self.config.max_queue_size)
        self.lock = Lock()
        self.processing = False
        self._flush_task: asyncio.Task | None = None

    async def add(self, item: Any, callback: Callable | None = None) -> None:
        """Add an item to the batch queue.

        Args:
            item: Item to process
            callback: Optional callback when item is processed
        """
        should_flush = False
        with self.lock:
            self.queue.append((item, callback))
            # Check if we should flush
            if len(self.queue) >= self.config.size:
                should_flush = True

        # Flush outside of the lock to avoid deadlock
        if should_flush:
            await self.flush()

    async def flush(self) -> int:
        """Process up to one batch of queued items.

        Returns:
            Number of items processed
        """
        items_to_process = []

        with self.lock:
            # Take at most one batch from the queue
            batch_size = min(len(self.queue), self.config.size)
            for _ in range(batch_size):
                if self.queue:
                    items_to_process.append(self.queue.popleft())

        if not items_to_process:
            return 0

        # Process items in parallel if configured
        if self.config.parallel_workers > 1:
            return await self._process_parallel(items_to_process)
        else:
            return await self._process_sequential(items_to_process)

    async def _process_sequential(self, items: list[tuple]) -> int:
        """Process items sequentially.

        Args:
            items: List of (item, callback) tuples

        Returns:
            Number of items processed
        """
        processed = 0
        for item, callback in items:
            try:
                if callback:
                    if asyncio.iscoroutinefunction(callback):
                        await callback(item)
                    else:
                        callback(item)
                processed += 1
            except Exception as e:
                logger.error(f"Error processing item: {e}")
                if self.config.retry_on_failure:
                    # Re-queue for retry on the next flush; note that
                    # max_retries is not tracked per item here
                    with self.lock:
                        self.queue.append((item, callback))

        return processed

    async def _process_parallel(self, items: list[tuple]) -> int:
        """Process items concurrently on the event loop.

        Args:
            items: List of (item, callback) tuples

        Returns:
            Number of items processed
        """
        # Split items into chunks, one task per chunk
        chunk_size = len(items) // self.config.parallel_workers
        if chunk_size == 0:
            chunk_size = 1

        chunks = [
            items[i:i + chunk_size]
            for i in range(0, len(items), chunk_size)
        ]

        # Process chunks concurrently
        tasks = [
            self._process_sequential(chunk)
            for chunk in chunks
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Count successfully processed items, ignoring raised exceptions
        processed = sum(
            r for r in results
            if isinstance(r, int)
        )

        return processed

    async def start_auto_flush(self) -> None:
        """Start automatic flushing at intervals."""
        if self._flush_task is None or self._flush_task.done():
            self._flush_task = asyncio.create_task(self._auto_flush_loop())

    async def stop_auto_flush(self) -> None:
        """Stop automatic flushing."""
        if self._flush_task and not self._flush_task.done():
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

    async def _auto_flush_loop(self) -> None:
        """Background task for automatic flushing."""
        while True:
            try:
                await asyncio.sleep(self.config.flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in auto-flush: {e}")
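

# Illustrative usage sketch, not part of the original module: queue items
# with a plain callable as callback, let the size threshold trigger one
# flush, and let the auto-flush loop drain the remainder.
async def _example_batch_processor() -> None:
    processor = BatchProcessor(BatchConfig(size=2, flush_interval=0.1))
    seen: list[str] = []
    await processor.start_auto_flush()
    for name in ("a", "b", "c"):
        await processor.add(name, callback=seen.append)  # flushes at size 2
    await asyncio.sleep(0.3)  # give the auto-flush loop time to drain "c"
    await processor.stop_auto_flush()
    assert sorted(seen) == ["a", "b", "c"]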


class VectorOptimizer:
    """Optimizes vector operations for better performance."""

    @staticmethod
    def optimize_batch_size(
        num_vectors: int,
        vector_dim: int,
        available_memory: int = 1024 * 1024 * 1024  # 1GB default
    ) -> int:
        """Calculate optimal batch size based on available resources.

        Args:
            num_vectors: Total number of vectors
            vector_dim: Dimension of each vector
            available_memory: Available memory in bytes

        Returns:
            Optimal batch size
        """
        # Estimate memory per vector (float32 = 4 bytes)
        bytes_per_vector = vector_dim * 4

        # Add overhead for metadata and indexing (estimate 50% overhead)
        bytes_per_vector = int(bytes_per_vector * 1.5)

        # Calculate max vectors that fit in memory
        max_batch = available_memory // bytes_per_vector

        # Apply reasonable limits
        min_batch = 10
        max_reasonable = 10000

        optimal = min(max_batch, max_reasonable, num_vectors)
        optimal = max(optimal, min_batch)

        return optimal

    @staticmethod
    def select_index_type(
        num_vectors: int,
        vector_dim: int,
        metric: DistanceMetric
    ) -> dict[str, Any]:
        """Select optimal index type based on dataset characteristics.

        Args:
            num_vectors: Number of vectors
            vector_dim: Vector dimensions
            metric: Distance metric

        Returns:
            Index configuration
        """
        config: dict[str, Any] = {"metric": metric}

        # Small datasets: use flat index for exact search
        if num_vectors < 10000:
            config["type"] = "flat"
            return config

        # Medium datasets: use IVF
        if num_vectors < 1000000:
            # Calculate optimal number of clusters
            nlist = int(np.sqrt(num_vectors))
            nlist = min(max(nlist, 100), 4096)

            config["type"] = "ivfflat"
            config["nlist"] = nlist
            config["nprobe"] = min(nlist // 10, 64)
            return config

        # Large datasets: use HNSW
        config["type"] = "hnsw"
        config["m"] = 16  # Number of connections per node
        config["ef_construction"] = 200
        config["ef_search"] = 50

        return config

    @staticmethod
    def optimize_search_params(
        index_type: str,
        recall_target: float = 0.95
    ) -> dict[str, Any]:
        """Optimize search parameters for target recall.

        Args:
            index_type: Type of index
            recall_target: Target recall rate (0-1)

        Returns:
            Optimized search parameters
        """
        params: dict[str, Any] = {}

        if index_type == "flat":
            # Flat index is always exact
            return params

        elif index_type == "ivfflat":
            # Adjust nprobe based on recall target
            if recall_target >= 0.99:
                params["nprobe"] = 128
            elif recall_target >= 0.95:
                params["nprobe"] = 64
            elif recall_target >= 0.90:
                params["nprobe"] = 32
            else:
                params["nprobe"] = 16

        elif index_type == "hnsw":
            # Adjust ef_search based on recall target
            if recall_target >= 0.99:
                params["ef_search"] = 200
            elif recall_target >= 0.95:
                params["ef_search"] = 100
            elif recall_target >= 0.90:
                params["ef_search"] = 50
            else:
                params["ef_search"] = 32

        return params
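

# Illustrative sketch, not part of the original module: chain the three
# helpers for a hypothetical corpus of 500k 768-dim embeddings. The string
# "cosine" stands in for a DistanceMetric value (the annotation is not
# enforced at runtime).
def _example_vector_optimizer() -> dict[str, Any]:
    batch = VectorOptimizer.optimize_batch_size(num_vectors=500_000, vector_dim=768)
    index_cfg = VectorOptimizer.select_index_type(500_000, 768, metric="cosine")
    search = VectorOptimizer.optimize_search_params(index_cfg["type"], recall_target=0.95)
    # For this corpus: batch_size=10000, an "ivfflat" index with nprobe=64
    return {"batch_size": batch, **index_cfg, **search}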


class ConnectionPool:
    """Manages a pool of connections for vector stores."""

    def __init__(self,
                 factory: Callable,
                 config: ConnectionPoolConfig | None = None):
        """Initialize the connection pool.

        Args:
            factory: Function to create new connections
            config: Pool configuration
        """
        self.factory = factory
        self.config = config or ConnectionPoolConfig()
        self.available: deque = deque()
        self.in_use: set = set()
        self.lock = Lock()
        self._closed = False

    async def acquire(self) -> Any:
        """Acquire a connection from the pool.

        Returns:
            A connection object
        """
        if self._closed:
            raise RuntimeError("Connection pool is closed")

        with self.lock:
            # Try to get an available connection
            while self.available:
                conn = self.available.popleft()
                # TODO: Check if connection is still valid
                self.in_use.add(conn)
                return conn

            # Create new connection if under limit
            if len(self.in_use) < self.config.max_connections:
                conn = await self.factory()
                self.in_use.add(conn)
                return conn

        # Wait for a connection to become available, bounded by the
        # configured connection_timeout rather than a hardcoded retry count
        attempts = max(1, int(self.config.connection_timeout / 0.1))
        for _ in range(attempts):
            await asyncio.sleep(0.1)
            with self.lock:
                if self.available:
                    conn = self.available.popleft()
                    self.in_use.add(conn)
                    return conn

        raise TimeoutError("Could not acquire connection from pool")

    async def release(self, conn: Any) -> None:
        """Release a connection back to the pool.

        Args:
            conn: Connection to release
        """
        with self.lock:
            if conn in self.in_use:
                self.in_use.remove(conn)
                if not self._closed:
                    self.available.append(conn)

    async def close(self) -> None:
        """Close all connections in the pool."""
        self._closed = True

        with self.lock:
            # Collect all connections and empty the pool
            all_conns = list(self.available) + list(self.in_use)
            self.available.clear()
            self.in_use.clear()

        # Close connections (if they have a close method)
        for conn in all_conns:
            if hasattr(conn, 'close'):
                try:
                    if asyncio.iscoroutinefunction(conn.close):
                        await conn.close()
                    else:
                        conn.close()
                except Exception as e:
                    logger.error(f"Error closing connection: {e}")
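

# Illustrative usage sketch, not part of the original module: the factory
# below stands in for a real async client constructor; acquire/release
# bracket each use, and close() shuts the pool down.
async def _example_connection_pool() -> None:
    async def make_conn() -> object:
        return object()  # placeholder connection

    pool = ConnectionPool(make_conn, ConnectionPoolConfig(max_connections=2))
    conn = await pool.acquire()
    try:
        pass  # use the connection here
    finally:
        await pool.release(conn)
    await pool.close()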


class QueryOptimizer:
    """Optimizes vector queries for better performance."""

    @staticmethod
    def should_use_index(
        num_vectors: int,
        k: int,
        filter_selectivity: float = 1.0
    ) -> bool:
        """Determine if an index should be used for a query.

        Args:
            num_vectors: Total number of vectors
            k: Number of results to return
            filter_selectivity: Estimated filter selectivity (0-1)

        Returns:
            True if the index should be used
        """
        # Guard against an empty store (and division by zero)
        if num_vectors <= 0:
            return False

        # If we're retrieving most vectors, a scan might be faster
        if k / num_vectors > 0.1:
            return False

        # If the filter is very selective, scan the filtered results
        if filter_selectivity < 0.01:
            return False

        # Otherwise use the index
        return True

    @staticmethod
    def optimize_reranking(
        initial_k: int,
        final_k: int,
        rerank_factor: float = 3.0
    ) -> int:
        """Calculate the optimal number of candidates for reranking.

        Args:
            initial_k: Initial number of results
            final_k: Final number of results after reranking
            rerank_factor: Multiplier for candidates

        Returns:
            Optimal number of candidates
        """
        candidates = int(final_k * rerank_factor)

        # Apply reasonable limits
        min_candidates = final_k * 2
        max_candidates = min(initial_k, final_k * 10)

        candidates = max(candidates, min_candidates)
        candidates = min(candidates, max_candidates)

        return candidates
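

# Illustrative sketch, not part of the original module: decide whether the
# index pays off for a selective query, then widen the candidate set that
# feeds a reranker returning 10 results.
def _example_query_optimizer() -> None:
    use_index = QueryOptimizer.should_use_index(num_vectors=1_000_000, k=10)
    candidates = QueryOptimizer.optimize_reranking(initial_k=100, final_k=10)
    assert use_index and candidates == 30  # final_k * rerank_factor, within limits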


# Export main classes
__all__ = [
    "BatchConfig",
    "BatchProcessor",
    "ConnectionPool",
    "ConnectionPoolConfig",
    "QueryOptimizer",
    "VectorOptimizer",
]