# Copyright (c) 2016 by Mike Jarvis and the other collaborators on GitHub at
# https://github.com/rmjarvis/Piff All rights reserved.
#
# Piff is free software: Redistribution and use in source and binary forms
# with or without modification, are permitted provided that the following
# conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the disclaimer given in the accompanying LICENSE
# file.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the disclaimer given in the documentation
# and/or other materials provided with the distribution.
"""
.. module:: knn_interp
"""
import numpy as np
import galsim
from .interp import Interp
from .star import Star, StarFit
class kNNInterp(Interp):
    """
    An interpolator that uses sklearn KNeighborsRegressor to interpolate a
    single surface

    :param keys:        A list of star attributes to interpolate from [default: ('u', 'v')]
    :param n_neighbors: Number of neighbors used for interpolation. [default: 15]
    :param weights:     Weight function used in prediction. Possible values are 'uniform',
                        'distance', and a callable function which accepts an array of
                        distances and returns an array of the same shape containing the
                        weights. [default: 'uniform']
    :param algorithm:   Algorithm used to compute nearest neighbors. Possible values are
                        'ball_tree', 'kd_tree', 'brute', and 'auto', which tries to
                        determine the best choice. [default: 'auto']
    :param p:           Power parameter of distance metric. p=2 is default euclidean
                        distance, p=1 is manhattan. [default: 2]
    :param logger:      A logger object for logging debug info. Currently unused by the
                        constructor. [default: None]
    """

    def __init__(self, keys=('u', 'v'), n_neighbors=15, weights='uniform', algorithm='auto',
                 p=2, logger=None):
        # Arguments forwarded verbatim to KNeighborsRegressor.
        self.knr_kwargs = {
            'n_neighbors': n_neighbors,
            'weights': weights,
            'algorithm': algorithm,
            'p': p,
        }
        # Full set of constructor arguments: the keys plus the regressor settings.
        self.kwargs = {
            'keys': keys,
        }
        self.kwargs.update(self.knr_kwargs)

        self.keys = keys

        # Imported here rather than at module level, so sklearn is only needed when a
        # kNNInterp is actually constructed.
        from sklearn.neighbors import KNeighborsRegressor
        self.knn = KNeighborsRegressor(**self.knr_kwargs)

    @property
    def property_names(self):
        """List of properties used by this interpolant.
        """
        return self.keys

    def _fit(self, locations, targets, logger=None):
        """Update the Neighbors Regressor with data

        The fitted arrays are also stored as self.locations and self.targets, which is what
        _finish_write serializes.

        :param locations:   The locations for interpolating. (n_samples, n_features).
                            (In sklearn parlance, this is 'X'.)
        :param targets:     The target values. (n_samples, n_targets).
                            (In sklearn parlance, this is 'y'.)
        :param logger:      A logger object for logging debug info. [default: None]
        """
        logger = galsim.config.LoggerWrapper(logger)
        self.knn.fit(locations, targets)
        self.locations = locations
        logger.debug('locations updated to shape: %s', self.locations.shape)
        self.targets = targets
        logger.debug('targets updated to shape: %s', self.targets.shape)

    def _predict(self, locations, logger=None):
        """Predict from knn.

        :param locations:   The locations for interpolating. (n_samples, n_features).
                            In sklearn parlance, this is 'X'
        :param logger:      A logger object for logging debug info. [default: None]

        :returns:   Regressed parameters y (n_samples, n_targets)
        """
        logger = galsim.config.LoggerWrapper(logger)
        regression = self.knn.predict(locations)
        logger.debug('Regression shape: %s', regression.shape)
        return regression

    def getProperties(self, star, logger=None):
        """Extract the appropriate properties to use as the independent variables for the
        interpolation.

        Take self.keys from star.data

        :param star:    A Star instance from which to extract the properties to use.
        :param logger:  A logger object for logging debug info. Unused here.
                        [default: None]

        :returns:       A np vector of these properties.
        """
        return np.array([star.data[key] for key in self.keys])

    def initialize(self, stars, logger=None):
        """Initialize both the interpolator to some state prefatory to any solve iterations
        and initialize the stars for use with this interpolator.

        This interpolator needs no preparation, so the input stars are returned unchanged.

        :param stars:   A list of Star instances to interpolate between
        :param logger:  A logger object for logging debug info. [default: None]

        :returns:       The same stars that were passed in.
        """
        return stars

    def solve(self, star_list, logger=None):
        """Solve for the interpolation coefficients given stars and attributes

        :param star_list:   A list of Star instances to interpolate between
        :param logger:      A logger object for logging debug info. [default: None]
        """
        locations = np.array([self.getProperties(star) for star in star_list])
        targets = np.array([star.fit.params for star in star_list])
        # Forward the logger so the debug output in _fit is actually emitted.
        self._fit(locations, targets, logger=logger)

    def interpolate(self, star, logger=None):
        """Perform the interpolation to find the interpolated parameter vector at some
        position. Calls interpolateList because sklearn prefers list input anyways

        :param star:    A Star instance to which one wants to interpolate
        :param logger:  A logger object for logging debug info. [default: None]

        :returns:       a new Star instance with its StarFit member holding the
                        interpolated parameters
        """
        # because of sklearn formatting, call interpolateList and take 0th entry
        return self.interpolateList([star], logger=logger)[0]

    def interpolateList(self, star_list, logger=None):
        """Perform the interpolation for a list of stars.

        :param star_list:   A list of Star instances to which to interpolate.
        :param logger:      A logger object for logging debug info. [default: None]

        :returns:           a list of new Star instances with interpolated parameters
        """
        # An empty input would become a 1-d empty array, which the regressor's predict
        # rejects; there is nothing to interpolate, so return early.
        if len(star_list) == 0:
            return []

        locations = np.array([self.getProperties(star) for star in star_list])
        # Forward the logger so the debug output in _predict is actually emitted.
        targets = self._predict(locations, logger=logger)

        star_list_fitted = []
        for yi, star in zip(targets, star_list):
            # Stars without an existing fit get a fresh StarFit; otherwise the existing
            # fit is asked for a copy carrying the new parameters.
            if star.fit is None:
                fit = StarFit(yi)
            else:
                fit = star.fit.newParams(yi)
            star_list_fitted.append(Star(star.data, fit))
        return star_list_fitted

    def _finish_write(self, fits, extname):
        """Write the solution to a FITS binary table.

        Save the knn params and the locations and targets arrays.  Requires that _fit has
        been called, since it reads self.locations and self.targets.

        :param fits:        An open fitsio.FITS object.
        :param extname:     The base name of the extension with the interp information.
        """
        # A single-row table whose two columns each hold one complete array.
        dtypes = [('LOCATIONS', self.locations.dtype, self.locations.shape),
                  ('TARGETS', self.targets.dtype, self.targets.shape),
                  ]
        data = np.empty(1, dtype=dtypes)
        # assign
        data['LOCATIONS'] = self.locations
        data['TARGETS'] = self.targets

        # write to fits
        fits.write_table(data, extname=extname + '_solution')

    def _finish_read(self, fits, extname):
        """Read the solution from a FITS binary table.

        :param fits:        An open fitsio.FITS object.
        :param extname:     The base name of the extension with the interp information.
        """
        data = fits[extname + '_solution'].read()

        # The table has a single row (see _finish_write), so take row 0 of each column.
        # self.locations and self.targets assigned in _fit
        self._fit(data['LOCATIONS'][0], data['TARGETS'][0])