# Copyright (c) 2003-2024 by Mike Jarvis
#
# TreeCorr is free software: redistribution and use in source and binary forms,
# with or without modification, are permitted provided that the following
# conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions, and the disclaimer given in the accompanying LICENSE
# file.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the disclaimer given in the documentation
# and/or other materials provided with the distribution.
"""
.. module:: util
"""
import numpy as np
import os
import coord
import functools
import inspect
import warnings
from . import _treecorr
from . import Rperp_alias
from .writer import AsciiWriter, FitsWriter, HdfWriter
from .reader import AsciiReader, FitsReader, HdfReader
# Module-level cap on the number of OpenMP threads.  None means no limit.
# It is written by set_max_omp_threads and read by set_omp_threads.
max_omp_threads = None


def set_max_omp_threads(num_threads, logger=None):
    """Set the maximum allowed number of OpenMP threads to use in the C++ layer in any
    further TreeCorr functions

    This only records the limit in the module-level ``max_omp_threads`` variable.

    :param num_threads: The target maximum number of threads to allow.  None means no limit.
    :param logger:      If desired, a logger object for logging any warnings here.
                        Currently not used by this function. (default: None)
    """
    global max_omp_threads
    max_omp_threads = num_threads
def set_omp_threads(num_threads, logger=None):
    """Set the number of OpenMP threads to use in the C++ layer.

    A value of None, zero or a negative number means "automatic": the number of
    CPUs reported by ``multiprocessing.cpu_count()`` is used.  The result is then
    capped at the module-level ``max_omp_threads`` if that is not None.

    :param num_threads: The target number of threads to use
    :param logger:      If desired, a logger object for logging any warnings here. (default: None)

    :returns:           The number of threads OpenMP reports that it will use.  Typically this
                        matches the input, but OpenMP reserves the right not to comply with
                        the requested number of threads.
    """
    input_num_threads = num_threads  # Save the input value.

    # If num_threads is auto, get it from cpu_count
    if num_threads is None or num_threads <= 0:
        import multiprocessing
        num_threads = multiprocessing.cpu_count()
        if logger:
            logger.debug('multiprocessing.cpu_count() = %d',num_threads)

    # Max at max_omp_threads, if set.
    if max_omp_threads is not None and num_threads > max_omp_threads:
        num_threads = max_omp_threads
        if logger:
            logger.debug('max_omp_threads = %d',max_omp_threads)

    # Tell OpenMP to use this many threads
    if logger:
        logger.debug('Telling OpenMP to use %d threads',num_threads)

    # Same workaround as in get_omp_threads: make sure OMP_PROC_BIND has a value
    # before calling into the C++ layer.  A value already set by the user is kept.
    os.environ.setdefault("OMP_PROC_BIND", "false")
    num_threads = _treecorr.SetOMPThreads(num_threads)

    # Report back appropriately.
    if logger:
        logger.debug('OpenMP reports that it will use %d threads',num_threads)
        if num_threads > 1:
            logger.info('Using %d threads.',num_threads)
        elif input_num_threads is not None and input_num_threads != 1:
            # Only warn if the user specifically asked for num_threads != 1.
            logger.warning("Unable to use multiple threads, since OpenMP is not enabled.")

    return num_threads
def get_omp_threads():
    """Get the number of OpenMP threads currently set to be used in the C++ layer.

    As a side effect, the environment variable OMP_PROC_BIND is set to "false"
    if it was not already set (see the comment in the body for why).

    :returns:           The number of threads OpenMP reports that it will use.
    """
    # Some OMP implementations have a bug where if omp_get_max_threads() is called
    # (which is what this function does), it sets something called thread affinity.
    # The upshot of that is that multiprocessing (i.e. not even just omp threading) is
    # confined to a single hardware thread.  Yeah, it's idiotic, but that seems to be
    # the case.  The only solution found by Eli, who looked into it pretty hard, is to
    # set the env variable OMP_PROC_BIND to "false".  This seems to stop the bad behavior.
    # So we do it here always before calling GetOMPThreads.
    # If this breaks someone's valid use of this variable, let us know and we can try to
    # come up with another solution, but without this lots of multiprocessing breaks.
    os.environ.setdefault("OMP_PROC_BIND", "false")
    return _treecorr.GetOMPThreads()
def parse_file_type(file_type, file_name, output=False, logger=None):
    """Return the upper-cased file type, guessing it from the file name if not given.

    When ``file_type`` is None the extension of ``file_name`` is examined
    (case-insensitively): ``.fit*`` means FITS, ``.hdf*`` or ``.h5*`` means HDF,
    ``.par*`` means Parquet (only when ``output`` is False), and anything else
    means ASCII.

    :param file_type:   The input file_type.  If None, then parse from file_name's extension.
    :param file_name:   The filename to use for parsing if necessary.
    :param output:      Limit to output file types (FITS/HDF/ASCII)? (default: False)
    :param logger:      A logger if desired. (default: None)

    :returns: The parsed file_type, converted to upper case.
    """
    # An explicit type wins; it is only normalized to upper case.
    if file_type is not None:
        return file_type.upper()

    suffix = os.path.splitext(file_name)[1].lower()
    if suffix.startswith('.fit'):
        guessed = 'FITS'
    elif suffix.startswith(('.hdf', '.h5')):
        guessed = 'HDF'
    elif suffix.startswith('.par') and not output:
        guessed = 'Parquet'
    else:
        guessed = 'ASCII'

    if logger:
        logger.info(" file_type assumed to be %s from the file name.",guessed)
    return guessed.upper()
def make_writer(file_name, precision=4, file_type=None, logger=None):
    """Build the writer object appropriate for the given output file.

    The file type is resolved with parse_file_type (restricted to output types),
    so it is taken from ``file_type`` if given and otherwise from the extension
    of ``file_name``.

    :param file_name:   The name of the file to write.
    :param precision:   Passed to the ASCII writer only. (default: 4)
    :param file_type:   The file type to use, or None to infer it. (default: None)
    :param logger:      A logger if desired. (default: None)

    :returns: A FitsWriter, HdfWriter or AsciiWriter instance.
    :raises ValueError: If the resolved file type is not FITS, HDF or ASCII.
    """
    kind = parse_file_type(file_type, file_name, output=True, logger=logger)
    if kind == 'FITS':
        return FitsWriter(file_name, logger=logger)
    if kind == 'HDF':
        return HdfWriter(file_name, logger=logger)
    if kind == 'ASCII':
        return AsciiWriter(file_name, precision=precision, logger=logger)
    raise ValueError("Invalid file_type %s"%kind)
def make_reader(file_name, file_type=None, logger=None):
    """Build the reader object appropriate for the given input file.

    The file type is resolved with parse_file_type, so it is taken from
    ``file_type`` if given and otherwise from the extension of ``file_name``.

    :param file_name:   The name of the file to read.
    :param file_type:   The file type to use, or None to infer it. (default: None)
    :param logger:      A logger if desired. (default: None)

    :returns: A FitsReader, HdfReader or AsciiReader instance.
    :raises ValueError: If the resolved file type is not FITS, HDF or ASCII.
    """
    kind = parse_file_type(file_type, file_name, output=False, logger=logger)
    if kind == 'FITS':
        return FitsReader(file_name, logger=logger)
    if kind == 'HDF':
        return HdfReader(file_name, logger=logger)
    if kind == 'ASCII':
        return AsciiReader(file_name, logger=logger)
    raise ValueError("Invalid file_type %s"%kind)
class LRU_Cache(object):
    """ Simplified Least Recently Used Cache.

    Mostly stolen from http://code.activestate.com/recipes/577970-simplified-lru-cache/,
    but added a method for dynamic resizing.  The least recently used cached item is
    overwritten on a cache miss.

    Note: This has additional functionality beyond what functools.lru_cache provides.

    1. The ability to resize the maxsize non-destructively.
    2. The key is only on the args, not kwargs, so a logger can be provided as a kwarg
       without triggering a cache miss.

    A cache of size zero is allowed; it simply calls the function every time.

    :param user_function:   A python function to cache.
    :param maxsize:         Maximum number of inputs to cache.  [Default: 1024]

    :raises ValueError:     If maxsize is negative.

    Usage
    -----
    >>> def slow_function(*args):  # A slow-to-evaluate python function
    >>>    ...
    >>>
    >>> v1 = slow_function(*k1)  # Calling function is slow
    >>> v1 = slow_function(*k1)  # Calling again with same args is still slow
    >>> cache = LRU_Cache(slow_function)
    >>> v1 = cache(*k1)  # Returns slow_function(*k1), slowly the first time
    >>> v1 = cache(*k1)  # Returns slow_function(*k1) again, but fast this time.

    Methods
    -------
    >>> cache.resize(maxsize) # Resize the cache, either upwards or downwards.  Upwards resizing
                              # is non-destructive.  Downwards resizing will remove the least
                              # recently used items first.
    """
    def __init__(self, user_function, maxsize=1024):
        if maxsize < 0:
            # Same check (and message) as resize().
            raise ValueError("Invalid maxsize")
        # Link layout: [PREV, NEXT, KEY, RESULT]
        self.root = [None, None, None, None]
        self.user_function = user_function
        self.cache = {}
        self._build_ring(maxsize)

    def _build_ring(self, maxsize):
        """Rebuild the circular list as the root plus ``maxsize`` empty links.

        ``self.cache`` must be empty on entry.  Each empty link is registered in
        the cache under a unique placeholder key so that ``len(self.cache)``
        always equals the cache size.
        """
        root = self.root
        # Close the ring on the root first.  This keeps root[1] valid even when
        # maxsize == 0; otherwise __call__ would step onto None on the first miss.
        root[:] = [root, root, None, None]
        last = root
        for _ in range(maxsize):
            key = object()
            link = [last, root, key, None]
            self.cache[key] = link
            last[1] = link
            last = link
        root[0] = last
        self.count = 0

    def __call__(self, *key, **kwargs):
        # Only the positional args form the cache key; kwargs are passed through
        # to the function on a miss but are otherwise ignored.
        link = self.cache.get(key)
        if link is not None:
            # Cache hit: move link to last position
            link_prev, link_next, _, result = link
            link_prev[1] = link_next
            link_next[0] = link_prev
            last = self.root[0]
            last[1] = self.root[0] = link
            link[0] = last
            link[1] = self.root
            return result
        # Cache miss: evaluate and insert new key/value at root, then increment root
        #             so that just-evaluated value is in last position.
        result = self.user_function(*key, **kwargs)
        self.root[2] = key
        self.root[3] = result
        oldroot = self.root
        self.root = self.root[1]
        oldkey = self.root[2]
        self.root[2] = None
        self.root[3] = None
        # Insert before delete: for a zero-size cache oldroot is the new root and
        # oldkey == key, so the entry just added is removed again.
        self.cache[key] = oldroot
        del self.cache[oldkey]
        if self.count < self.size: self.count += 1
        return result

    def values(self):
        """Lists all items stored in the cache

        Entries whose stored result is None (including unused slots) are omitted.
        """
        return [link[3] for link in self.cache.values() if link[3] is not None]

    @property
    def last_value(self):
        """Return the most recently used value"""
        return self.root[0][3]

    def resize(self, maxsize):
        """ Resize the cache.  Increasing the size of the cache is non-destructive, i.e.,
        previously cached inputs remain in the cache.  Decreasing the size of the cache will
        necessarily remove items from the cache if the cache is already filled.  Items are removed
        in least recently used order.

        :param maxsize: The new maximum number of inputs to cache.

        :raises ValueError: If maxsize is negative.
        """
        oldsize = len(self.cache)
        if maxsize == oldsize:
            return
        if maxsize < 0:
            raise ValueError("Invalid maxsize")
        if maxsize < oldsize:
            for _ in range(oldsize - maxsize):
                # Delete root.next
                current_next_link = self.root[1]
                new_next_link = self.root[1] = self.root[1][1]
                new_next_link[0] = self.root
                del self.cache[current_next_link[2]]
            self.count = min(self.count, maxsize)
        else:  # maxsize > oldsize:
            for _ in range(maxsize - oldsize):
                # Insert between root and root.next
                key = object()
                self.cache[key] = link = [self.root, self.root[1], key, None]
                self.root[1][0] = link
                self.root[1] = link

    def clear(self):
        """ Clear all items from the cache.

        The size of the cache is unchanged.
        """
        maxsize = len(self.cache)
        self.cache.clear()
        # Rebuilding from the root drops every reference to the old links and
        # hence to the results they held.
        self._build_ring(maxsize)

    @property
    def size(self):
        """The current maximum number of entries (filled or not)."""
        return len(self.cache)
def parse_metric(metric, coords, coords2=None, coords3=None):
    """
    Check that a metric is usable with the given coordinate systems.

    ``coords2 is None`` is taken to mean an auto-correlation.  For cross-correlations,
    the coordinate systems must all match, with two exceptions: for 'Rlens', spherical
    coords2/coords3 are accepted as if they were '3d'; for 'Arc', a mix of 'spherical'
    and '3d' is accepted and the effective coords becomes 'spherical'.

    :param metric:  The name of the metric.
    :param coords:  The coordinate system of the first catalog ('flat', 'spherical' or '3d').
    :param coords2: The coordinate system of the second catalog, if any. (default: None)
    :param coords3: The coordinate system of the third catalog, if any. (default: None)

    :returns: The tuple (coords, metric), where coords is the effective coordinate system.
    :raises ValueError: If the metric or coords is not recognized, or the combination
                        is not allowed.
    """
    auto = coords2 is None

    if not auto:
        # Special Rlens doesn't care about the distance to the sources, so spherical is fine
        # for cat2, cat3 in that case.
        if metric == 'Rlens':
            if coords2 == 'spherical':
                coords2 = '3d'
            if coords3 == 'spherical':
                coords3 = '3d'

        if metric == 'Arc':
            given = (coords, coords2, coords3)
            # If all coords are 3d, then leave it 3d, but if any are spherical,
            # then convert to spherical.
            if not all(c in (None, '3d') for c in given):
                if any(c not in (None, 'spherical', '3d') for c in given):
                    raise ValueError(
                        "Arc metric is only valid for catalogs with spherical positions.")
                if 'spherical' not in given:  # pragma: no cover
                    # This is impossible now, but here in case we add additional coordinates.
                    raise ValueError(
                        "Cannot correlate catalogs with different coordinate systems.")
                coords = 'spherical'
        elif coords2 != coords or (coords3 is not None and coords3 != coords):
            raise ValueError("Cannot correlate catalogs with different coordinate systems.")

    # Validation of the (possibly adjusted) coords and the metric, in a fixed order.
    if coords not in ('flat', 'spherical', '3d'):
        raise ValueError("Invalid coords %s"%coords)
    known_metrics = ('Euclidean', 'Rperp', 'OldRperp', 'FisherRperp', 'Rlens', 'Arc', 'Periodic')
    if metric not in known_metrics:
        raise ValueError("Invalid metric %s"%metric)

    is_3d = (coords == '3d')
    if metric in ('Rperp', 'OldRperp', 'FisherRperp') and not is_3d:
        raise ValueError("%s metric is only valid for catalogs with 3d positions."%metric)
    if metric == 'Rlens':
        if auto:
            raise ValueError("Rlens metric is only valid for cross correlations.")
        if not is_3d:
            raise ValueError("Rlens metric is only valid for catalogs with 3d positions.")
    if metric == 'Arc' and coords not in ('spherical', '3d'):
        raise ValueError("Arc metric is only valid for catalogs with spherical positions.")

    return coords, metric
def coord_enum(coords):
    """Return the C++-layer enum for the given string value of coords.

    :param coords:  One of 'flat', 'spherical' or '3d'.

    :returns: The matching attribute of the _treecorr module (Flat, Sphere or ThreeD).
    :raises ValueError: If coords is not one of the recognized names.
    """
    # (string name, attribute of _treecorr) pairs.
    table = (
        ('flat', 'Flat'),
        ('spherical', 'Sphere'),
        ('3d', 'ThreeD'),
    )
    for label, attr in table:
        if coords == label:
            return getattr(_treecorr, attr)
    raise ValueError("Invalid coords %s"%coords)
def metric_enum(metric):
    """Return the C++-layer enum for the given string value of metric.

    'Rperp' is not an enum of its own: it is resolved by looking up whatever metric
    name the package-level ``Rperp_alias`` holds.

    :param metric:  The name of the metric.

    :returns: The matching attribute of the _treecorr module.
    :raises ValueError: If metric is not one of the recognized names.
    """
    if metric == 'Rperp':
        return metric_enum(Rperp_alias)

    # (string name, attribute of _treecorr) pairs.  Note FisherRperp maps to Rperp.
    table = (
        ('Euclidean', 'Euclidean'),
        ('FisherRperp', 'Rperp'),
        ('OldRperp', 'OldRperp'),
        ('Rlens', 'Rlens'),
        ('Arc', 'Arc'),
        ('Periodic', 'Periodic'),
    )
    for label, attr in table:
        if metric == label:
            return getattr(_treecorr, attr)
    raise ValueError("Invalid metric %s"%metric)
def _require_kwargs(kwargs, *names):
    """Raise TypeError for the first of ``names`` that is not a key of ``kwargs``."""
    for name in names:
        if name not in kwargs:
            raise TypeError("Missing required argument %s"%name)


def _pop_required(kwargs, name):
    """Remove and return ``kwargs[name]``, raising TypeError if it is absent."""
    _require_kwargs(kwargs, name)
    return kwargs.pop(name)


def parse_xyzsep(args, kwargs, _coords):
    """Parse the different options for passing a coordinate and separation.

    Recognized entries are popped from ``kwargs`` (the dict is modified in place);
    anything left over at the end is reported as an error.

    The allowed parameters are:

    1. If _coords == Flat:

        :param x:       The x coordinate of the location for which to count nearby points.
        :param y:       The y coordinate of the location for which to count nearby points.
        :param sep:     The separation distance

    2. If _coords == ThreeD:

       Either

        :param x:       The x coordinate of the location for which to count nearby points.
        :param y:       The y coordinate of the location for which to count nearby points.
        :param z:       The z coordinate of the location for which to count nearby points.
        :param sep:     The separation distance

       Or

        :param ra:      The right ascension of the location for which to count nearby points.
        :param dec:     The declination of the location for which to count nearby points.
        :param r:       The distance to the location for which to count nearby points.
        :param sep:     The separation distance

    3. If _coords == Sphere:

        :param ra:      The right ascension of the location for which to count nearby points.
        :param dec:     The declination of the location for which to count nearby points.
        :param sep:     The separation distance as an angle

    For all angle parameters (ra, dec, sep), this quantity may be a coord.Angle instance, or
    units may be provided as ra_units, dec_units or sep_units respectively.

    Finally, in cases where ra, dec are allowed, a coord.CelestialCoord instance may be
    provided as the first argument.

    :returns: The effective (x, y, z, sep) as a tuple.
    :raises TypeError: If a required argument is missing, too many positional arguments
                       are given, or unrecognized kwargs remain.
    """
    nargs = len(args)
    radec = False  # Becomes True when the position was given as (ra, dec).

    if _coords == _treecorr.Flat:
        if nargs == 0:
            # Check all three before removing any of them.
            _require_kwargs(kwargs, 'x', 'y', 'sep')
            x = kwargs.pop('x')
            y = kwargs.pop('y')
            sep = kwargs.pop('sep')
        elif nargs == 1:
            raise TypeError("x,y should be given as either args or kwargs, not mixed.")
        elif nargs == 2:
            _require_kwargs(kwargs, 'sep')
            x, y = args
            sep = kwargs.pop('sep')
        elif nargs == 3:
            x, y, sep = args
        else:
            raise TypeError("Too many positional args")
        z = 0

    elif _coords == _treecorr.ThreeD:
        if nargs == 0:
            if 'x' in kwargs:
                _require_kwargs(kwargs, 'y', 'z')
                x = kwargs.pop('x')
                y = kwargs.pop('y')
                z = kwargs.pop('z')
            else:
                _require_kwargs(kwargs, 'ra', 'dec')
                ra = kwargs.pop('ra')
                dec = kwargs.pop('dec')
                radec = True
                r = _pop_required(kwargs, 'r')
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 1:
            if not isinstance(args[0], coord.CelestialCoord):
                raise TypeError("Invalid unnamed argument %r"%args[0])
            ra, dec = args[0].ra, args[0].dec
            radec = True
            r = _pop_required(kwargs, 'r')
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 2:
            if isinstance(args[0], coord.CelestialCoord):
                ra, dec = args[0].ra, args[0].dec
                radec = True
                r = args[1]
            else:
                ra, dec = args
                radec = True
                r = _pop_required(kwargs, 'r')
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 3:
            if isinstance(args[0], coord.CelestialCoord):
                ra, dec = args[0].ra, args[0].dec
                radec = True
                r = args[1]
                sep = args[2]
            else:
                # Three bare values are (ra, dec, r) if the first is an Angle or if
                # angle units were supplied; otherwise they are (x, y, z).
                if (isinstance(args[0], coord.Angle)
                        or 'ra_units' in kwargs or 'dec_units' in kwargs):
                    ra, dec, r = args
                    radec = True
                else:
                    x, y, z = args
                sep = _pop_required(kwargs, 'sep')
        elif nargs == 4:
            # Same disambiguation as for three values, with sep as the fourth.
            if (isinstance(args[0], coord.Angle)
                    or 'ra_units' in kwargs or 'dec_units' in kwargs):
                ra, dec, r, sep = args
                radec = True
            else:
                x, y, z, sep = args
        else:
            raise TypeError("Too many positional args")

    else:  # Sphere
        if nargs == 0:
            _require_kwargs(kwargs, 'ra', 'dec')
            ra = kwargs.pop('ra')
            dec = kwargs.pop('dec')
            radec = True
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 1:
            if not isinstance(args[0], coord.CelestialCoord):
                raise TypeError("Invalid unnamed argument %r"%args[0])
            ra, dec = args[0].ra, args[0].dec
            radec = True
            sep = _pop_required(kwargs, 'sep')
        elif nargs == 2:
            if isinstance(args[0], coord.CelestialCoord):
                ra, dec = args[0].ra, args[0].dec
                radec = True
                sep = args[1]
            else:
                ra, dec = args
                radec = True
                sep = _pop_required(kwargs, 'sep')
        elif nargs == 3:
            ra, dec, sep = args
            radec = True
        else:
            raise TypeError("Too many positional args")

        if not isinstance(sep, coord.Angle):
            sep_units = _pop_required(kwargs, 'sep_units')
            sep = sep * coord.AngleUnit.from_name(sep_units)
        # We actually want the chord distance for this angle.
        sep = 2. * np.sin(sep/2.)

    if radec:
        if not isinstance(ra, coord.Angle):
            ra_units = _pop_required(kwargs, 'ra_units')
            ra = ra * coord.AngleUnit.from_name(ra_units)
        if not isinstance(dec, coord.Angle):
            dec_units = _pop_required(kwargs, 'dec_units')
            dec = dec * coord.AngleUnit.from_name(dec_units)
        x, y, z = coord.CelestialCoord(ra, dec).get_xyz()
        if _coords == _treecorr.ThreeD:
            # Scale the (x, y, z) obtained from (ra, dec) by the distance r.
            x *= r
            y *= r
            z *= r

    # Everything recognized has been popped, so anything left is an error.
    if kwargs:
        raise TypeError("Invalid kwargs: %s"%(kwargs))

    return float(x), float(y), float(z), float(sep)
class lazy_property(object):
    """
    This decorator will act similarly to @property, but will be efficient for multiple access
    to values that require some significant calculation.

    It works by replacing the attribute with the computed value, so after the first access,
    the property (an attribute of the class) is superseded by the new attribute of the instance.

    The descriptor takes on the name and docstring of the decorated function, so
    accessing it on the class shows the function's documentation.

    Usage::

        @lazy_property
        def slow_function_to_be_used_as_a_property(self):
            x =  ...  # Some slow calculation.
            return x

    Based on an answer from http://stackoverflow.com/a/6849299
    This implementation taken from GalSim utilities.py
    """
    def __init__(self, fget):
        # Copy __name__, __doc__, etc. from the wrapped function onto the descriptor.
        # Done first so that the two attributes set below always take precedence.
        functools.update_wrapper(self, fget)
        self.fget = fget
        self.func_name = fget.__name__

    def __get__(self, obj, cls):
        # Accessed on the class rather than an instance: return the descriptor itself.
        if obj is None:
            return self
        value = self.fget(obj)
        # This class defines no __set__, so the instance attribute stored here
        # shadows the descriptor on every later access.
        setattr(obj, self.func_name, value)
        return value