Source code for treecorr.util

# Copyright (c) 2003-2024 by Mike Jarvis
#
# TreeCorr is free software: redistribution and use in source and binary forms,
# with or without modification, are permitted provided that the following
# conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions, and the disclaimer given in the accompanying LICENSE
#    file.
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions, and the disclaimer given in the documentation
#    and/or other materials provided with the distribution.

"""
.. module:: util
"""

import numpy as np
import os
import coord
import functools
import inspect
import warnings

from . import _treecorr
from . import Rperp_alias
from .writer import AsciiWriter, FitsWriter, HdfWriter
from .reader import AsciiReader, FitsReader, HdfReader

max_omp_threads=None

def set_max_omp_threads(num_threads, logger=None):
    """Record an upper bound on the number of OpenMP threads TreeCorr may use.

    The value is stored in the module-level ``max_omp_threads`` variable, which
    `set_omp_threads` consults to cap any later request.

    :param num_threads: The target maximum number of threads to allow.  None means no limit.
    :param logger:      If desired, a logger object for logging any warnings here.
                        (default: None)  It is accepted but not used by this function.
    """
    # Rebind the module-level setting rather than creating a local variable.
    global max_omp_threads
    max_omp_threads = num_threads
def set_omp_threads(num_threads, logger=None):
    """Set the number of OpenMP threads to use in the C++ layer.

    A request of None (or any value <= 0) means "use multiprocessing.cpu_count()".
    If the module-level ``max_omp_threads`` is not None, the request is capped at
    that value before being handed to the C++ layer.

    :param num_threads: The target number of threads to use
    :param logger:      If desired, a logger object for logging any warnings here.
                        (default: None)

    :returns:           The number of threads OpenMP reports that it will use.  Typically this
                        matches the input, but OpenMP reserves the right not to comply with
                        the requested number of threads.
    """
    # Keep the caller's original request; it decides whether to warn at the end.
    requested = num_threads

    # "Auto" request: fall back to the machine's CPU count.
    if num_threads is None or num_threads <= 0:
        import multiprocessing
        num_threads = multiprocessing.cpu_count()
        if logger:
            logger.debug('multiprocessing.cpu_count() = %d',num_threads)

    # Apply the global cap, if one has been set.
    limit = max_omp_threads
    if limit is not None and num_threads > limit:
        num_threads = limit
        if logger:
            logger.debug('max_omp_threads = %d',limit)

    if logger:
        logger.debug('Telling OpenMP to use %d threads',num_threads)

    # Same OMP_PROC_BIND workaround that get_omp_threads applies: only set the
    # variable when the user has not already chosen a value for it.
    os.environ.setdefault("OMP_PROC_BIND", "false")
    num_threads = _treecorr.SetOMPThreads(num_threads)

    # Everything below is reporting only.
    if not logger:
        return num_threads

    logger.debug('OpenMP reports that it will use %d threads',num_threads)
    if num_threads > 1:
        logger.info('Using %d threads.',num_threads)
    elif requested is not None and requested != 1:
        # Only warn if the user specifically asked for num_threads != 1.
        logger.warning("Unable to use multiple threads, since OpenMP is not enabled.")
    return num_threads
def get_omp_threads():
    """Get the number of OpenMP threads currently set to be used in the C++ layer.

    :returns: The number of threads OpenMP reports that it will use.
    """
    # Some OMP implementations have a bug: calling omp_get_max_threads() (which is
    # what GetOMPThreads does) sets thread affinity, after which multiprocessing
    # (not just OMP threading) is confined to a single hardware thread.  The only
    # solution found (by Eli, who looked into it pretty hard) is to set the
    # environment variable OMP_PROC_BIND to "false", which seems to stop the bad
    # behavior.  So always do that before calling GetOMPThreads, but never
    # override a value the user has already set.
    # If this breaks someone's valid use of this variable, let us know and we can
    # try to come up with another solution, but without this lots of
    # multiprocessing breaks.
    os.environ.setdefault("OMP_PROC_BIND", "false")
    return _treecorr.GetOMPThreads()
def parse_file_type(file_type, file_name, output=False, logger=None):
    """Parse the file_type from the file_name if necessary

    :param file_type:   The input file_type.  If None, then parse from file_name's extension.
    :param file_name:   The filename to use for parsing if necessary.
    :param output:      Limit to output file types (FITS/HDF/ASCII)? (default: False)
    :param logger:      A logger if desired. (default: None)

    :returns: The parsed file_type, converted to upper case (so a Parquet file is
              reported as 'PARQUET').
    """
    if file_type is None:
        ext = os.path.splitext(file_name)[1].lower()
        if ext.startswith('.fit'):
            file_type = 'FITS'
        elif ext.startswith(('.hdf', '.h5')):
            file_type = 'HDF'
        elif not output and ext.startswith('.par'):
            # Parquet is only recognized for input files.
            file_type = 'Parquet'
        else:
            file_type = 'ASCII'
        if logger:
            logger.info(" file_type assumed to be %s from the file name.",file_type)
    return file_type.upper()


def make_writer(file_name, precision=4, file_type=None, logger=None):
    """Factory function to make a writer instance of the correct type.

    :param file_name:   The name of the file to write.
    :param precision:   The precision passed to the ASCII writer; not used for the
                        other file types. (default: 4)
    :param file_type:   The file type to use.  If None, it is parsed from file_name.
                        (default: None)
    :param logger:      A logger if desired. (default: None)

    :returns: A FitsWriter, HdfWriter or AsciiWriter instance.
    :raises ValueError: if the file type is not one of FITS, HDF, ASCII.
    """
    # Figure out which file type to use.
    file_type = parse_file_type(file_type, file_name, output=True, logger=logger)
    if file_type == 'FITS':
        writer = FitsWriter(file_name, logger=logger)
    elif file_type == 'HDF':
        writer = HdfWriter(file_name, logger=logger)
    elif file_type == 'ASCII':
        writer = AsciiWriter(file_name, precision=precision, logger=logger)
    else:
        raise ValueError("Invalid file_type %s"%file_type)
    return writer


def make_reader(file_name, file_type=None, logger=None):
    """Factory function to make a reader instance of the correct type.

    :param file_name:   The name of the file to read.
    :param file_type:   The file type to use.  If None, it is parsed from file_name.
                        (default: None)
    :param logger:      A logger if desired. (default: None)

    :returns: A FitsReader, HdfReader or AsciiReader instance.
    :raises ValueError: if the file type is not one of FITS, HDF, ASCII.
    """
    # Figure out which file type to use.
    file_type = parse_file_type(file_type, file_name, output=False, logger=logger)
    if file_type == 'FITS':
        reader = FitsReader(file_name, logger=logger)
    elif file_type == 'HDF':
        reader = HdfReader(file_name, logger=logger)
    elif file_type == 'ASCII':
        reader = AsciiReader(file_name, logger=logger)
    else:
        raise ValueError("Invalid file_type %s"%file_type)
    return reader


class LRU_Cache(object):
    """ Simplified Least Recently Used Cache.

    Mostly stolen from http://code.activestate.com/recipes/577970-simplified-lru-cache/,
    but added a method for dynamic resizing.  The least recently used cached item is
    overwritten on a cache miss.

    Note: This has additional functionality beyond what functools.lru_cache provides.

    1. The ability to resize the maxsize non-destructively.
    2. The key is only on the args, not kwargs, so a logger can be provided as a kwarg
       without triggering a cache miss.

    The positional arguments are used as a dict key, so they must be hashable.
    A maxsize of 0 is allowed; such a cache simply evaluates the function on every call.

    :param user_function:   A python function to cache.
    :param maxsize:         Maximum number of inputs to cache.  [Default: 1024]

    Usage
    -----

        >>> def slow_function(*args) # A slow-to-evaluate python function
        >>>    ...
        >>>
        >>> v1 = slow_function(*k1)  # Calling function is slow
        >>> v1 = slow_function(*k1)  # Calling again with same args is still slow
        >>> cache = treecorr.util.LRU_Cache(slow_function)
        >>> v1 = cache(*k1)  # Returns slow_function(*k1), slowly the first time
        >>> v1 = cache(*k1)  # Returns slow_function(*k1) again, but fast this time.

    Methods
    -------

        >>> cache.resize(maxsize) # Resize the cache, either upwards or downwards.  Upwards resizing
        >>>                       # is non-destructive.  Downwards resizing will remove the least
        >>>                       # recently used items first.
    """
    def __init__(self, user_function, maxsize=1024):
        # The cache is a circular doubly-linked list of links plus a dict from key to link.
        # Link layout:     [PREV, NEXT, KEY, RESULT]
        # self.root is an empty sentinel link; root[0] is the most recently used entry
        # and root[1] is the least recently used one.
        self.root = [None, None, None, None]
        self.user_function = user_function
        self.cache = {}

        # Pre-populate with maxsize empty links, each keyed by a unique placeholder object.
        last = self.root
        for _ in range(maxsize):
            key = object()
            self.cache[key] = last[1] = last = [last, self.root, key, None]
        # Close the ring.  For maxsize > 0 the last link already points at root, but for
        # maxsize == 0 this makes root its own successor, which __call__ and resize need.
        last[1] = self.root
        self.root[0] = last
        self.count = 0

    def __call__(self, *key, **kwargs):
        """Return user_function(*key, **kwargs), using the cached result when the same
        positional args have been seen recently.  kwargs are not part of the cache key.
        """
        link = self.cache.get(key)
        if link is not None:
            # Cache hit: move link to last position
            link_prev, link_next, _, result = link
            link_prev[1] = link_next
            link_next[0] = link_prev
            last = self.root[0]
            last[1] = self.root[0] = link
            link[0] = last
            link[1] = self.root
            return result
        # Cache miss: evaluate and insert new key/value at root, then increment root
        #             so that just-evaluated value is in last position.
        result = self.user_function(*key, **kwargs)
        self.root[2] = key
        self.root[3] = result
        oldroot = self.root
        self.root = self.root[1]
        # The new root was the least recently used entry; evict whatever it held.
        oldkey = self.root[2]
        self.root[2] = None
        self.root[3] = None
        self.cache[key] = oldroot
        del self.cache[oldkey]
        if self.count < self.size:
            self.count += 1
        return result

    def values(self):
        """Lists all items stored in the cache"""
        # Unfilled placeholder links (and any cached None results) are skipped.
        return [link[3] for link in self.cache.values() if link[3] is not None]

    @property
    def last_value(self):
        """Return the most recently used value"""
        return self.root[0][3]

    def resize(self, maxsize):
        """ Resize the cache.

        Increasing the size of the cache is non-destructive, i.e., previously cached inputs remain
        in the cache.  Decreasing the size of the cache will necessarily remove items from the
        cache if the cache is already filled.  Items are removed in least recently used order.

        :param maxsize: The new maximum number of inputs to cache.

        :raises ValueError: if maxsize is negative.
        """
        oldsize = len(self.cache)
        if maxsize == oldsize:
            return
        if maxsize < 0:
            raise ValueError("Invalid maxsize")
        if maxsize < oldsize:
            for _ in range(oldsize - maxsize):
                # Delete root.next
                current_next_link = self.root[1]
                new_next_link = self.root[1] = self.root[1][1]
                new_next_link[0] = self.root
                del self.cache[current_next_link[2]]
            self.count = min(self.count, maxsize)
        else:  #  maxsize > oldsize:
            for _ in range(maxsize - oldsize):
                # Insert between root and root.next
                key = object()
                self.cache[key] = link = [self.root, self.root[1], key, None]
                self.root[1][0] = link
                self.root[1] = link

    def clear(self):
        """ Clear all items from the cache.
        """
        maxsize = len(self.cache)
        self.cache.clear()
        # Rebuild the ring from the current root with fresh, empty links.
        last = self.root
        for _ in range(maxsize):
            last[3] = None  # Sever pointer to any existing result.
            key = object()
            self.cache[key] = last[1] = last = [last, self.root, key, None]
        self.root[0] = last
        self.count = 0

    @property
    def size(self):
        """The current maximum number of inputs the cache can hold."""
        return len(self.cache)


def parse_metric(metric, coords, coords2=None, coords3=None):
    """
    Validate a string metric against the coordinate systems of the catalogs.

    :param metric:      The name of the metric.
    :param coords:      The coords string of the first catalog ('flat', 'spherical' or '3d').
    :param coords2:     The coords string of the second catalog, or None for an
                        auto-correlation. (default: None)
    :param coords3:     The coords string of the third catalog, if any. (default: None)

    :returns: The tuple (coords, metric) of strings to use.  coords may differ from the
              input for the Arc metric, where mixed spherical/3d catalogs are treated
              as spherical.
    :raises ValueError: if the metric or coords are invalid or mutually incompatible.
    """
    auto = coords2 is None
    if not auto:
        # Special Rlens doesn't care about the distance to the sources, so spherical is fine
        # for cat2, cat3 in that case.
        if metric == 'Rlens':
            if coords2 == 'spherical':
                coords2 = '3d'
            if coords3 == 'spherical':
                coords3 = '3d'

        all_coords = [coords, coords2, coords3]
        if metric == 'Arc':
            # If all coords are 3d, then leave it 3d, but if any are spherical,
            # then convert to spherical.
            if all(c in [None, '3d'] for c in all_coords):
                # Leave coords as '3d'
                pass
            elif any(c not in [None, 'spherical', '3d'] for c in all_coords):
                raise ValueError("Arc metric is only valid for catalogs with spherical positions.")
            elif any(c == 'spherical' for c in all_coords):  # pragma: no branch
                # Switch to spherical
                coords = 'spherical'
            else:  # pragma: no cover
                # This is impossible now, but here in case we add additional coordinates.
                raise ValueError("Cannot correlate catalogs with different coordinate systems.")
        elif coords2 != coords or (coords3 is not None and coords3 != coords):
            raise ValueError("Cannot correlate catalogs with different coordinate systems.")

    if coords not in ['flat', 'spherical', '3d']:
        raise ValueError("Invalid coords %s"%coords)
    if metric not in ['Euclidean', 'Rperp', 'OldRperp', 'FisherRperp', 'Rlens', 'Arc', 'Periodic']:
        raise ValueError("Invalid metric %s"%metric)
    if metric in ['Rperp', 'OldRperp', 'FisherRperp'] and coords != '3d':
        raise ValueError("%s metric is only valid for catalogs with 3d positions."%metric)
    if metric == 'Rlens' and auto:
        raise ValueError("Rlens metric is only valid for cross correlations.")
    if metric == 'Rlens' and coords != '3d':
        raise ValueError("Rlens metric is only valid for catalogs with 3d positions.")
    if metric == 'Arc' and coords not in ['spherical', '3d']:
        raise ValueError("Arc metric is only valid for catalogs with spherical positions.")

    return coords, metric


def coord_enum(coords):
    """Return the C++-layer enum for the given string value of coords.

    :raises ValueError: if coords is not one of 'flat', 'spherical', '3d'.
    """
    if coords == 'flat':
        return _treecorr.Flat
    elif coords == 'spherical':
        return _treecorr.Sphere
    elif coords == '3d':
        return _treecorr.ThreeD
    else:
        raise ValueError("Invalid coords %s"%coords)


def metric_enum(metric):
    """Return the C++-layer enum for the given string value of metric.

    :raises ValueError: if metric is not a recognized metric name.
    """
    if metric == 'Euclidean':
        return _treecorr.Euclidean
    elif metric == 'Rperp':
        # 'Rperp' is an alias for whichever metric name Rperp_alias holds.
        return metric_enum(Rperp_alias)
    elif metric == 'FisherRperp':
        return _treecorr.Rperp
    elif metric == 'OldRperp':
        return _treecorr.OldRperp
    elif metric == 'Rlens':
        return _treecorr.Rlens
    elif metric == 'Arc':
        return _treecorr.Arc
    elif metric == 'Periodic':
        return _treecorr.Periodic
    else:
        raise ValueError("Invalid metric %s"%metric)


def _pop_required(kwargs, name):
    """Remove and return kwargs[name], raising TypeError if it is missing."""
    if name not in kwargs:
        raise TypeError("Missing required argument %s"%name)
    return kwargs.pop(name)


def _as_angle(value, kwargs, units_name):
    """Return value as a coord.Angle, taking units from kwargs[units_name] if needed."""
    if isinstance(value, coord.Angle):
        return value
    return value * coord.AngleUnit.from_name(_pop_required(kwargs, units_name))


def parse_xyzsep(args, kwargs, _coords):
    """Parse the different options for passing a coordinate and separation.

    The allowed parameters are:

    1. If _coords == Flat:

        :param x:       The x coordinate of the location for which to count nearby points.
        :param y:       The y coordinate of the location for which to count nearby points.
        :param sep:     The separation distance

    2. If _coords == ThreeD:

    Either

        :param x:       The x coordinate of the location for which to count nearby points.
        :param y:       The y coordinate of the location for which to count nearby points.
        :param z:       The z coordinate of the location for which to count nearby points.
        :param sep:     The separation distance

    Or

        :param ra:      The right ascension of the location for which to count nearby points.
        :param dec:     The declination of the location for which to count nearby points.
        :param r:       The distance to the location for which to count nearby points.
        :param sep:     The separation distance

    3. If _coords == Sphere:

        :param ra:      The right ascension of the location for which to count nearby points.
        :param dec:     The declination of the location for which to count nearby points.
        :param sep:     The separation distance as an angle

    For all angle parameters (ra, dec, sep), this quantity may be a coord.Angle instance, or
    units may be provided as ra_units, dec_units or sep_units respectively.

    Finally, in cases where ra, dec are allowed, a coord.CelestialCoord instance may be
    provided as the first argument.

    Note: recognized items are popped from ``kwargs``, so the dict passed in is modified.

    :returns: The effective (x, y, z, sep) as a tuple.
    :raises TypeError: if required arguments are missing, or there are too many positional
                       args, or unrecognized kwargs remain.
    """
    radec = False  # Set to True when the position was given as (ra, dec) rather than x,y,z.
    nargs = len(args)

    if _coords == _treecorr.Flat:
        if nargs == 0:
            x = _pop_required(kwargs, 'x')
            y = _pop_required(kwargs, 'y')
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 1:
            raise TypeError("x,y should be given as either args or kwargs, not mixed.")
        elif nargs == 2:
            sep = _pop_required(kwargs, 'sep')
            x,y = args
        elif nargs == 3:
            x,y,sep = args
        else:
            raise TypeError("Too many positional args")
        z = 0

    elif _coords == _treecorr.ThreeD:
        if nargs == 0:
            if 'x' in kwargs:
                x = kwargs.pop('x')
                y = _pop_required(kwargs, 'y')
                z = _pop_required(kwargs, 'z')
            else:
                ra = _pop_required(kwargs, 'ra')
                dec = _pop_required(kwargs, 'dec')
                radec = True
                r = _pop_required(kwargs, 'r')
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 1:
            if not isinstance(args[0], coord.CelestialCoord):
                raise TypeError("Invalid unnamed argument %r"%(args[0],))
            ra, dec = args[0].ra, args[0].dec
            radec = True
            r = _pop_required(kwargs, 'r')
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 2:
            if isinstance(args[0], coord.CelestialCoord):
                ra, dec = args[0].ra, args[0].dec
                r = args[1]
            else:
                ra, dec = args
                r = _pop_required(kwargs, 'r')
            radec = True
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 3:
            if isinstance(args[0], coord.CelestialCoord):
                ra, dec = args[0].ra, args[0].dec
                radec = True
                r, sep = args[1], args[2]
            else:
                # Three bare values are (ra, dec, r) if the first is an Angle or angle
                # units were supplied; otherwise they are (x, y, z).
                if (isinstance(args[0], coord.Angle)
                        or 'ra_units' in kwargs or 'dec_units' in kwargs):
                    ra, dec, r = args
                    radec = True
                else:
                    x, y, z = args
                sep = _pop_required(kwargs, 'sep')
        elif nargs == 4:
            # Same disambiguation as above, with sep as the fourth positional value.
            if (isinstance(args[0], coord.Angle)
                    or 'ra_units' in kwargs or 'dec_units' in kwargs):
                ra, dec, r, sep = args
                radec = True
            else:
                x, y, z, sep = args
        else:
            raise TypeError("Too many positional args")

    else:  # Sphere
        if nargs == 0:
            ra = _pop_required(kwargs, 'ra')
            dec = _pop_required(kwargs, 'dec')
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 1:
            if not isinstance(args[0], coord.CelestialCoord):
                raise TypeError("Invalid unnamed argument %r"%(args[0],))
            ra, dec = args[0].ra, args[0].dec
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 2:
            if isinstance(args[0], coord.CelestialCoord):
                ra, dec = args[0].ra, args[0].dec
                sep = args[1]
            else:
                ra, dec = args
                sep = _pop_required(kwargs, 'sep')
        elif nargs == 3:
            ra, dec, sep = args
        else:
            raise TypeError("Too many positional args")
        radec = True
        sep = _as_angle(sep, kwargs, 'sep_units')
        # We actually want the chord distance for this angle.
        sep = 2. * np.sin(sep/2.)

    if radec:
        ra = _as_angle(ra, kwargs, 'ra_units')
        dec = _as_angle(dec, kwargs, 'dec_units')
        x,y,z = coord.CelestialCoord(ra, dec).get_xyz()
        if _coords == _treecorr.ThreeD:
            # Scale the unit-sphere position out to the given distance.
            x *= r
            y *= r
            z *= r

    # Anything still left in kwargs was not recognized.
    if len(kwargs) > 0:
        raise TypeError("Invalid kwargs: %s"%(kwargs))

    return float(x), float(y), float(z), float(sep)


class lazy_property(object):
    """
    This decorator will act similarly to @property, but will be efficient for multiple access
    to values that require some significant calculation.

    It works by replacing the attribute with the computed value, so after the first access,
    the property (an attribute of the class) is superseded by the new attribute of the instance.

    Usage::

        @lazy_property
        def slow_function_to_be_used_as_a_property(self):
            x =  ...  # Some slow calculation.
            return x

    Based on an answer from http://stackoverflow.com/a/6849299
    This implementation taken from GalSim utilities.py
    """
    def __init__(self, fget):
        self.fget = fget
        self.func_name = fget.__name__

    def __get__(self, obj, cls):
        # Accessed on the class itself: return the descriptor.
        if obj is None:
            return self
        value = self.fget(obj)
        # Store the result on the instance under the same name.  Since this class defines
        # no __set__, the instance attribute takes precedence on later lookups.
        setattr(obj, self.func_name, value)
        return value