Source code for adam_core.ray_cluster

import logging
import warnings
from typing import Optional

import ray

logger = logging.getLogger(__name__)

_JAX_FORK_RUNTIMEWARNING_RE = r"os\.fork\(\) was called\..*JAX is multithreaded.*"


def _silence_jax_fork_runtimewarning() -> None:
    warnings.filterwarnings(
        "ignore",
        message=_JAX_FORK_RUNTIMEWARNING_RE,
        category=RuntimeWarning,
    )


[docs] def initialize_use_ray( num_cpus: Optional[int] = None, object_store_bytes: Optional[int] = None, **kwargs ) -> bool: """ Ensures we use existing local cluster, or starts new one with desired resources """ use_ray = False if num_cpus is None or num_cpus > 1: _silence_jax_fork_runtimewarning() # Default Ray configuration for this codebase. # # - We don't need the dashboard in library usage (and it can bring in extra # background services). kwargs.setdefault("include_dashboard", False) # Initialize ray if not ray.is_initialized(): logger.info("Ray is not initialized. Initializing...") # For some reason, ray does not seem to automatically # find existing local clusters without `address="auto"` # but it will fail if we use auto and there is no existing cluster. # So we wrap it in a try/except, using an existing cluster if we can # Otherwise starting fresh. try: logger.info("Attempting to connect to existing ray cluster...") ray.init(address="auto", **kwargs) except ConnectionError: logger.info("Could not connect to existing ray cluster.") logger.info( f"Attempting ray with {num_cpus} cpus and {object_store_bytes} bytes." ) ray.init( num_cpus=num_cpus, object_store_memory=object_store_bytes, **kwargs, ) logger.info(f"Ray Resources: {ray.cluster_resources()}") use_ray = True return use_ray