# Source code for mpdaf.sdetect.segmap

"""
Copyright (c) 2010-2018 CNRS / Centre de Recherche Astrophysique de Lyon
Copyright (c)      2019 Simon Conseil <simon.conseil@univ-lyon1.fr>

All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors
   may be used to endorse or promote products derived from this software
   without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import astropy.units as u
import logging
import numpy as np
from collections import defaultdict
from os.path import exists
from scipy import ndimage as ndi

from ..obj import Image, moffat_image
from ..sdetect import Catalog
from ..tools import isiter, progressbar

__all__ = ('Segmap', 'create_masks_from_segmap')


class Segmap:
    """
    Handle segmentation maps, where pixel values are sources ids.
    """

    def __init__(self, file_or_image, cut_header_after='D001VER'):
        """Wrap a segmentation map in an `mpdaf.obj.Image`.

        Parameters
        ----------
        file_or_image : str, `mpdaf.obj.Image` or numpy.ndarray
            Filename of the segmap, an existing Image (used as-is, not
            copied), or a data array (wrapped without copy and without mask).
        cut_header_after : str or None
            If this keyword is present in the data or primary header, that
            card and all following cards are removed from the header.
            A falsy value disables the truncation.

        Raises
        ------
        TypeError
            If ``file_or_image`` is of an unsupported type.
        """
        if isinstance(file_or_image, str):
            self.img = Image(file_or_image)
        elif isinstance(file_or_image, Image):
            self.img = file_or_image
        elif isinstance(file_or_image, np.ndarray):
            self.img = Image(data=file_or_image, copy=False,
                             mask=np.ma.nomask)
        else:
            raise TypeError('unknown input')

        if cut_header_after:
            if cut_header_after in self.img.data_header:
                idx = self.img.data_header.index(cut_header_after)
                self.img.data_header = self.img.data_header[:idx]
            if cut_header_after in self.img.primary_header:
                idx = self.img.primary_header.index(cut_header_after)
                self.img.primary_header = self.img.primary_header[:idx]

    def copy(self):
        """Return a copy of the segmap whose image has no mask."""
        im = self.__class__(self.img.copy())
        # The mask must be removed on the copied *image*: a ``_mask``
        # attribute on the Segmap object itself is never read.
        im.img._mask = np.ma.nomask
        return im

    def get_mask(self, value, dtype=np.uint8, dilate=None, inverse=False,
                 struct=None, regrid_to=None, outname=None):
        """Return a binary mask of the pixels equal to ``value``.

        Parameters
        ----------
        value : int
            Segmap value to select.
        dtype : numpy dtype
            Dtype of the returned mask data.
        dilate : int, optional
            Number of binary dilation iterations applied before regridding.
        inverse : bool
            If True, the pixels *different* from ``value`` are selected,
            dilated and regridded, and the result is then inverted. The
            returned mask is therefore set on the ``value`` pixels that are
            not reached by the dilation of the other pixels.
        struct : array-like, optional
            Structuring element for the dilation.
        regrid_to : `mpdaf.obj.Image`, optional
            Image onto whose grid the mask is regridded (nearest neighbour).
        outname : str, optional
            If given, the mask is also written to this file.

        Returns
        -------
        `mpdaf.obj.Image`
            The mask image, without a mask of its own.
        """
        if inverse:
            data = (self.img._data != value)
        else:
            data = (self.img._data == value)

        if dilate:
            data = dilate_mask(data, niter=dilate, struct=struct)

        im = Image.new_from_obj(self.img, data)
        if regrid_to:
            im = regrid_to_image(im, regrid_to, inplace=True, order=0,
                                 antialias=False)
            np.around(im._data, out=im._data)

        im._data = im._data.astype(dtype)
        im._mask = np.ma.nomask

        if inverse:
            np.logical_not(im._data, out=im._data)

        if outname:
            im.write(outname, savemask='none')

        return im

    def get_source_mask(self, iden, center, size, minsize=None, dilate=None,
                        dtype=np.uint8, struct=None, unit_center=u.deg,
                        unit_size=u.arcsec, regrid_to=None, outname=None):
        """Extract the binary mask of one or several sources in a cutout.

        Parameters
        ----------
        iden : int or iterable of int
            Segmap id(s); several ids are combined with a logical OR.
        center : (float, float)
            Cutout center as (dec, ra), in ``unit_center``.
        size : float or tuple
            Cutout size, in ``unit_size``.
        minsize : float, optional
            Minimum cutout size passed to ``Image.subimage``; defaults to
            ``size``.
        dilate : int, optional
            Number of binary dilation iterations.
        dtype : numpy dtype
            Dtype of the returned mask data.
        struct : array-like, optional
            Structuring element for the dilation.
        unit_center, unit_size : `astropy.units.Unit`
            Units of ``center`` and ``size``.
        regrid_to : `mpdaf.obj.Image`, optional
            Image onto whose grid the cutout mask is regridded.
        outname : str, optional
            If given, the mask is written to this file and None is returned.

        Returns
        -------
        `mpdaf.obj.Image` or None
            The mask image, or None when ``outname`` is given.
        """
        if minsize is None:
            minsize = size

        im = self.img.subimage(center, size, minsize=minsize,
                               unit_center=unit_center, unit_size=unit_size)

        if isiter(iden):
            # combine the masks for multiple ids
            data = np.logical_or.reduce([(im._data == i) for i in iden])
        else:
            data = (im._data == iden)

        if dilate:
            data = dilate_mask(data, niter=dilate, struct=struct)

        if regrid_to:
            other = regrid_to.subimage(center, size, minsize=0.,
                                       unit_center=unit_center,
                                       unit_size=unit_size)
            im._data = data.astype(float)
            im = regrid_to_image(im, other, size=size, order=0, inplace=True,
                                 antialias=False)
            data = np.around(im._data, out=im._data)

        im._data = data.astype(dtype)
        im._mask = np.ma.nomask

        logger = logging.getLogger(__name__)
        logger.debug('source %s (%.5f, %.5f), extract mask (%d masked pixels)',
                     iden, center[1], center[0], np.count_nonzero(im._data))

        if outname:
            im.write(outname, savemask='none')
        else:
            return im

    def align_with_image(self, other, inplace=False, truncate=False,
                         margin=0):
        """Rotate and truncate the segmap to match 'other'.

        Parameters
        ----------
        other : `mpdaf.obj.Image`
            Reference image.
        inplace : bool
            If True, modify this segmap, otherwise work on a copy.
        truncate : bool
            If True, crop the segmap to the sky footprint of ``other``.
        margin : int
            Number of pixels by which the footprint of ``other`` is shrunk
            on each side before truncation.

        Returns
        -------
        `Segmap`
            The aligned segmap, with integer data.
        """
        out = self if inplace else self.copy()
        rot = other.wcs.get_rot() - self.img.wcs.get_rot()
        if np.abs(rot) > 1.e-3:
            out.img = self.img.rotate(rot, reshape=True, regrid=True,
                                      flux=False, order=0, inplace=inplace)

        if truncate:
            # Corners of ``other``; with margin=0 the box extends one pixel
            # beyond each edge of ``other``.
            y0 = margin - 1
            y1 = other.shape[0] - margin
            x0 = margin - 1
            x1 = other.shape[1] - margin
            pixsky = other.wcs.pix2sky([[y0, x0], [y1, x0], [y0, x1],
                                        [y1, x1]], unit=u.deg)
            pixcrd = out.img.wcs.sky2pix(pixsky)
            ymin, xmin = pixcrd.min(axis=0)
            ymax, xmax = pixcrd.max(axis=0)
            out.img.truncate(ymin, ymax, xmin, xmax, mask=False, unit=None,
                             inplace=True)

        out.img._data = np.around(out.img._data).astype(int)
        # FIXME: temporary workaround to make sure that the data_header is
        # up-to-date when pickling the segmap. This should be detected
        # directly in MPDAF.
        out.img.data_header = out.img.get_wcs_header()
        return out

    def cmap(self, background_color='#000000'):
        """matplotlib colormap with random colors.
        (taken from photutils' segmentation map class)"""
        # The number of colors must be an int: a segmap stored with float
        # data would otherwise make the random color generation fail.
        return get_cmap(int(self.img.data.max()) + 1,
                        background_color=background_color)
def dilate_mask(data, thres=0.5, niter=1, struct=None):
    """Binarize ``data`` and apply a binary dilation to it.

    Parameters
    ----------
    data : array-like or numpy.ma.MaskedArray
        2D input. Masked values are replaced by 0. Unless its maximum is 0 or
        1, the data are divided by their maximum before thresholding.
    thres : float
        Threshold applied to the (normalized) data; pixels strictly above
        it are set before the dilation.
    niter : int
        Number of dilation iterations.
    struct : array-like, optional
        Structuring element; defaults to the 2D connectivity-1 (cross) one.

    Returns
    -------
    numpy.ndarray of bool
        The dilated binary mask.
    """
    if struct is None:
        struct = ndi.generate_binary_structure(2, 1)
    if isinstance(data, np.ma.MaskedArray):
        data = data.filled(0)
    data = np.asarray(data)
    maxval = data.max()
    # Normalize out of place: an in-place ``/=`` would modify the caller's
    # array and raises for integer or all-False boolean arrays. A maximum
    # of 0 is skipped to avoid dividing by zero (nothing is above thres).
    if maxval not in (0, 1):
        data = data / maxval
    data = data > thres
    return ndi.binary_dilation(data, structure=struct, iterations=niter)


def get_cmap(ncolors, background_color='#000000'):
    """Return a ListedColormap with ``ncolors`` reproducible random colors.

    The first color is replaced by ``background_color`` unless it is None.
    """
    from matplotlib import colors
    prng = np.random.RandomState(42)
    h = prng.uniform(low=0.0, high=1.0, size=ncolors)
    s = prng.uniform(low=0.2, high=0.7, size=ncolors)
    v = prng.uniform(low=0.5, high=1.0, size=ncolors)
    hsv = np.dstack((h, s, v))  # shape (1, ncolors, 3)
    # Take the first row instead of np.squeeze, which would also drop the
    # color axis when ncolors == 1 and produce an invalid color list.
    rgb = colors.hsv_to_rgb(hsv)[0]
    cmap = colors.ListedColormap(rgb)
    if background_color is not None:
        cmap.colors[0] = colors.hex2color(background_color)
    return cmap


def regrid_to_image(im, other, order=1, inplace=False, antialias=True,
                    size=None, unit_size=u.arcsec, **kwargs):
    """Regrid ``im`` onto the pixel grid of ``other``.

    The pixel [0, 0] of ``other`` is used as reference position and the
    axis increments of ``other`` are used as the new increments. If ``size``
    is given (in ``unit_size``), it defines the output shape; otherwise the
    shape of ``other`` is used. Extra ``kwargs`` are accepted but ignored.

    NOTE: ``im``'s data is converted to float in place, even when
    ``inplace`` is False.
    """
    im.data = im.data.astype(float)
    refpos = other.wcs.pix2sky([0, 0])[0]
    if size is not None:
        newdim = size / other.wcs.get_step(unit=unit_size)
    else:
        newdim = other.shape
    inc = other.wcs.get_axis_increments(unit=unit_size)
    im = im.regrid(newdim, refpos, [0, 0], inc, order=order,
                   unit_inc=unit_size, inplace=inplace, antialias=antialias)
    return im


def struct_from_moffat_fwhm(wcs, fwhm, psf_threshold=0.5, beta=2.5):
    """Compute a structuring element for the dilatation, to simulate
    a convolution with a psf.

    A peak-normalized Moffat image (``fwhm`` in arcsec, index ``beta``) is
    computed on ``wcs``; the returned boolean array is True where the PSF
    is at least ``psf_threshold``, cropped to that region.
    """
    # image size will be twice the full-width, to account for
    # psf_threshold < 0.5
    size = int(round(fwhm / wcs.get_step(u.arcsec)[0])) * 2 + 1
    psf = moffat_image(fwhm=(fwhm, fwhm), n=beta, peak=True,
                       wcs=wcs[:size, :size])

    # remove useless zeros on the edges.
    psf.mask_selection(psf._data < psf_threshold)
    psf.crop()
    # the structuring element must have odd dimensions to be centered
    assert tuple(np.array(psf.shape) % 2) == (1, 1)
    return ~psf.mask


def _get_psf_convolution_params(convolve_fwhm, segmap, psf_threshold):
    """Return ``(dilate_iterations, struct)`` for a PSF "convolution".

    With a non-zero ``convolve_fwhm`` this is one dilation with a Moffat
    structuring element built on the segmap WCS; otherwise ``(0, None)``.
    """
    if convolve_fwhm:
        # compute a structuring element for the dilatation, to simulate
        # a convolution with a psf, but faster.
        dilateit = 1
        struct = struct_from_moffat_fwhm(segmap.img.wcs, convolve_fwhm,
                                         psf_threshold=psf_threshold)
    else:
        dilateit = 0
        struct = None
    return dilateit, struct
def create_masks_from_segmap(
        segmap, catalog, ref_image, n_jobs=1, skip_existing=True,
        masksky_name='mask-sky.fits', maskobj_name='mask-source-%05d.fits',
        idname='ID', raname='RA', decname='DEC', margin=0, mask_size=(20, 20),
        convolve_fwhm=0, psf_threshold=0.5, verbose=0):
    """Create binary masks from a segmentation map.

    For each source from the catalog, extract the segmap region, align with
    ref_image and regrid to the resolution of ref_image.

    Parameters
    ----------
    segmap : str or `mpdaf.obj.Image`
        The segmentation map.
    catalog : str or `mpdaf.sdetect.Catalog` or `astropy.table.Table`
        The catalog with source IDs and positions.
    ref_image : str or `mpdaf.obj.Image`
        The reference image, with which the segmap is aligned.
    n_jobs : int
        Number of parallel processes (for joblib).
    skip_existing : bool
        If True, skip sources for which the mask file exists.
    masksky_name : str or callable
        The filename for the sky mask (a callable is called without
        arguments).
    maskobj_name : str or callable
        The filename for the source masks, with a format string that will be
        substituted with the ID, e.g. ``%05d`` (a callable receives the ID).
    idname, raname, decname : str
        Name of the 'id', 'ra' and 'dec' columns.
    margin : float
        Margin used for the segmap alignment (pixels).
    mask_size : tuple
        Size of the source masks (arcsec).
    convolve_fwhm : float
        FWHM for the PSF convolution (arcsec).
    psf_threshold : float
        Threshold applied to the PSF to get a binary image.
    verbose : int
        Verbosity level for joblib.Parallel.

    Returns
    -------
    dict
        ``{'computed': [...], 'skipped': [...]}``, the IDs of the sources
        whose mask was computed and of those skipped because the file
        already existed.

    """
    from joblib import delayed, Parallel

    logger = logging.getLogger(__name__)

    if isinstance(ref_image, str):
        ref_image = Image(ref_image)
    if isinstance(catalog, str):
        catalog = Catalog.read(catalog)
    if not isinstance(segmap, Segmap):
        segmap = Segmap(segmap)

    logger.info('Aligning segmap with reference image')
    segm = segmap.align_with_image(ref_image, truncate=True, margin=margin)

    dilateit, struct = _get_psf_convolution_params(convolve_fwhm, segm,
                                                   psf_threshold)

    # create sky mask: value 0 pixels not reached by the dilated sources
    masksky = masksky_name() if callable(masksky_name) else masksky_name
    if exists(masksky) and skip_existing:
        logger.debug('sky mask exists, skipping')
    else:
        logger.debug('creating sky mask')
        segm.get_mask(0, inverse=True, dilate=dilateit, struct=struct,
                      regrid_to=ref_image, outname=masksky)

    # extract source masks
    minsize = 0.
    to_compute = []
    stats = defaultdict(list)

    for row in catalog:
        id_ = int(row[idname])  # need int, not np.int64
        source_path = (maskobj_name(id_) if callable(maskobj_name)
                       else maskobj_name % id_)
        if skip_existing and exists(source_path):
            stats['skipped'].append(id_)
        else:
            center = (row[decname], row[raname])
            stats['computed'].append(id_)
            to_compute.append(delayed(segm.get_source_mask)(
                id_, center, mask_size, minsize=minsize, struct=struct,
                dilate=dilateit, outname=source_path, regrid_to=ref_image))

    if stats['skipped']:
        logger.info('skipping %d sources with existing masks',
                    len(stats['skipped']))

    # FIXME: check which value to use for max_nbytes
    if to_compute:
        logger.info('computing masks for %d sources', len(to_compute))
        Parallel(n_jobs=n_jobs, verbose=verbose)(progressbar(to_compute))
    else:
        logger.info('nothing to compute')

    return {'computed': stats['computed'], 'skipped': stats['skipped']}