# Source code for sloop.interp (extracted from the Sphinx HTML documentation)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Interpolation stuff
"""
import os
import re
import logging
from enum import IntEnum
import numpy as np
import pandas as pd
import xarray as xr
import xesmf
import xoa.cf
from xoa.misc import XEnumMeta
from .cf import (rename_from_standard_name,
                 get_lon_coord_name, get_lat_coord_name)
from .filters import erode_coast
from .io import xr_open_concat
from .times import period_vortex_to_pandas


LOGGER = logging.getLogger(__name__)


class InterpFlag(IntEnum, metaclass=XEnumMeta):
    """How a given output time step of :func:`interp_time` was obtained"""

    #: Taken directly from the source dataset (exact time match)
    PICKED = 1
    #: Linearly interpolated between two source time steps
    LINEAR = 2
    #: Extrapolated by persistency outside the source time range
    PERSISTENT = 3
def interp_time(ds, dates, preserve_encoding=False, keep_flag=False,
                diurnal_vars=None, persistency=None):
    """Interpolate a dataset in time

    A single interpolation per time step, with persistency outside the
    dataset time range.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to interpolate.
    dates : single or list of dates
        Interpolation dates.
    preserve_encoding : bool
        Copy the netcdf encoding (e.g. compression settings) of the input
        variables onto the interpolated variables.
    keep_flag : bool
        Add a ``"flag"`` variable to the output telling, for each time
        step, how it was computed (see :class:`InterpFlag`).
    diurnal_vars : list, optional
        Names of variables with a diurnal cycle: when extrapolating by
        persistency, their values are copied from the already-computed
        time step located exactly 24 hours inside the dataset time range,
        when such a step exists.
    persistency : str, optional
        Pattern like ``"mean_24h"``: persistent values are then the mean
        over the given period at the edge of the dataset time range,
        instead of the single first/last time step.

    Returns
    -------
    xarray.Dataset
    """
    # Locate the time coordinate with the CF specs
    cfs = xoa.cf.get_cf_specs()
    time_name = cfs.search_coord(ds, cf_name="time", get="cf_name")
    time = ds.indexes[time_name]
    begindate_nc = time[0]
    enddate_nc = time[-1]
    dates = pd.to_datetime(dates)

    # Single-step datasets used for persistency before/after the time range
    ds_first = ds.isel({time_name: slice(0, 1)})
    ds_last = ds.isel({time_name: slice(-1, None)})

    # Per-date interpolation flag
    flag = xr.DataArray(
        np.zeros(len(dates), dtype="i"), coords=[dates], dims=[time_name],
        name='flag')

    # Loop on target times
    dss = []
    for date in dates:

        # Persistency: the date is outside the dataset time range
        if date < begindate_nc or date > enddate_nc:

            # Start from the first or last available time step
            this_ds = (ds_first if date < begindate_nc else ds_last)

            # New time
            this_ds = this_ds.assign_coords({time_name: [date]})
            this_ds[time_name].attrs.update(ds[time_name].attrs)

            # Optionally persist the mean over a period at the range edge
            if persistency:
                m = re.search(
                    r"mean_(?P<term>.*)(?P<units>.)", persistency, re.DOTALL)
                if m:
                    ds_tr = ds.copy()
                    timedelta = pd.to_timedelta(int(m["term"]), m["units"])
                    if date > enddate_nc:
                        date_target = enddate_nc - timedelta
                        ds_tr = ds_tr.where(
                            ds_tr[time_name] >= date_target).mean(
                                dim=time_name, keepdims=True)
                    elif date < begindate_nc:
                        date_target = begindate_nc + timedelta
                        ds_tr = ds_tr.where(
                            ds_tr[time_name] <= date_target).mean(
                                dim=time_name, keepdims=True)
                    for vname in ds_tr:
                        this_ds[vname] = this_ds[vname].copy(
                            data=ds_tr[vname].data)

            # Diurnal variables are copied from the step 24h inside the range
            if diurnal_vars:
                if date < begindate_nc:
                    time_target = date + pd.Timedelta("24h")
                elif date > enddate_nc:
                    time_target = date - pd.Timedelta("24h")
                for dss_target in dss:
                    if dss_target[time_name].values[0] == time_target:
                        for vname in this_ds:
                            if vname in diurnal_vars:
                                this_ds[vname] = this_ds[vname].copy(
                                    data=dss_target[vname].data)

            flag.loc[date] = InterpFlag.PERSISTENT

        # Interpolation: the date is inside the dataset time range
        else:

            # Pick this date directly if available
            if date in time:
                this_ds = ds.sel({time_name: [date]})
                flag.loc[date] = InterpFlag.PICKED

            # Interpolate because not available
            else:

                # Load a two-time steps dataset.
                # Index.get_loc(..., method=...) was removed in pandas 2.0:
                # use get_indexer instead, which behaves identically here
                # (the date is always inside the index range) and works on
                # all pandas versions.
                it = ds.indexes[time_name].get_indexer(
                    [date], method='ffill')[0]
                tmp_ds = ds.isel({time_name: slice(it, it + 2)}).chunk(
                    {time_name: (2,)})

                # Interpolate
                this_ds = tmp_ds.interp({time_name: [date]}).load()
                flag.loc[date] = InterpFlag.LINEAR

                # Preserve netcdf compression
                if preserve_encoding:
                    for vname, var in tmp_ds.items():
                        this_ds[vname].encoding.update(var.encoding)

            # New time
            this_ds = this_ds.assign_coords({time_name: [date]})
            this_ds[time_name].attrs.update(ds[time_name].attrs)

        # Append to list
        dss.append(this_ds)
        del this_ds

    # Concatenate
    dso = xr.concat(
        dss, time_name, compat='no_conflicts', coords='different',
        data_vars='all')
    flag.attrs['flags'] = str(InterpFlag)
    if keep_flag:
        dso["flag"] = flag

    # Attributes and encoding
    dso[time_name].attrs.update(ds[time_name].attrs)
    dso[time_name].encoding.update(ds[time_name].encoding)
    if preserve_encoding:
        for vname, var in ds.items():
            dso[vname].encoding.update(var.encoding)
    dso = dso.transpose(time_name, ...)
    return dso
def nc_interp_time(ncfiles, dates, ncout, preproc=None, postproc=None):
    """Interpolate netcdf files to a single netcdf file at given dates

    Parameters
    ----------
    ncfiles: str, list
        Single or list of netcdf files
    dates: list
        Output dates
    ncout: str
        Netcdf file with optional date patterns
    preproc: callable
        Callable that operates on the input datasets after reading it
    postproc: callable
        Callable that operates on the output dataset before saving it

    Return
    ------
    str
        Output netcdf file
    """
    # Read and concatenate all input files into a single dataset
    dsi = xr_open_concat(ncfiles)

    # Optional user hook applied to the input dataset
    if preproc:
        dsi = preproc(dsi)

    # Normalize the target dates to a list of pandas timestamps
    if not isinstance(dates, list):
        dates = [dates]
    dates = [pd.to_datetime(date) for date in dates]
    date0 = dates[0]
    date1 = dates[-1]
    LOGGER.debug("Interpolating to a single file: "
                 f"from {date0:%Y-%m-%dT%H:%M} to {date1:%Y-%m-%dT%H:%M}")

    # Time interpolation
    dso = interp_time(dsi, dates)

    # Optional user hook applied to the output dataset
    if postproc is not None:
        dso = postproc(dso)

    # Write netcdf.
    # NOTE: the output path is formatted with **locals(), so patterns in
    # ``ncout`` may reference any local name here (e.g. "{date0:%Y%m%d}");
    # renaming locals in this function would silently break such patterns.
    ncfile = ncout.format(**locals())
    dso.to_netcdf(ncfile)
    LOGGER.info(f"Created interpolation file: {ncfile}")
    return ncfile
def nc_interp_at_freq_to_daily_nc(
        ncfiles, freq, ncfmt, preproc=None, postproc=None):
    """Interpolate netcdf files to daily netcdf files at a given frequency

    Parameters
    ----------
    ncfiles: str, list
        Single or list of netcdf files
    freq: str
        Output frequency
    ncfmt: str
        Netcdf file path format with date patterns
    preproc: callable
        Callable that operates on the input datasets after reading it
    postproc: callable
        Callable that operates on the output dataset before saving it

    Return
    ------
    list
        Effective list of output netcdf files
    """
    # Input
    dsi = xr_open_concat(ncfiles)

    # Preproc
    if preproc:
        dsi = preproc(dsi)

    # Output dates: round the input time range to the output frequency,
    # then split it into days
    cfs = xoa.cf.get_cf_specs()
    time = cfs.coords.get(dsi, "time")
    freq = period_vortex_to_pandas(freq)
    date0 = pd.to_datetime(time[0].values).ceil(freq)
    date1 = pd.to_datetime(time[-1].values).ceil(freq)
    day0 = date0.floor("1D")
    day1 = date1.floor("1D")
    LOGGER.info("Interpolating to daily files: "
                f"from {day0:%Y-%m-%dT%H:%M} to {day1:%Y-%m-%dT%H:%M} "
                f"at frequency {freq}")
    days = pd.date_range(day0, day1, freq='D')

    # Loop on dates
    ncfiles = []
    for iday in range(len(days)):

        # Date
        date = days[iday]
        LOGGER.debug(f"Date: {date}")
        # NOTE(review): this assumes the time dimension is literally named
        # "time", while the name is resolved dynamically elsewhere — confirm.
        t = time.sel(time=slice(date, None))

        # Hours of the day
        if len(t) == 1:
            times = [date]
        else:
            # Left-closed range of output times within the day.
            # pandas 2.0 removed the ``closed`` keyword of date_range, so
            # the right bound is excluded with an explicit mask, which
            # reproduces closed='left' on every pandas version.
            times = pd.date_range(
                date, date + pd.to_timedelta("24 hours"), freq=freq)
            times = times[times < date + pd.to_timedelta("24 hours")]

        # Interpolate
        dso = interp_time(dsi, times, preserve_encoding=True)

        # Callback postprocessing
        if postproc is not None:
            dso = postproc(dso)
        date = pd.to_datetime(date).strftime("%Y%m%d")

        # Save. NOTE: the path is formatted with **locals(), so patterns in
        # ``ncfmt`` may reference local names such as "{date}"; renaming
        # locals in this function would silently break such patterns.
        ncfile = ncfmt.format(**locals())
        ncfiles.append(ncfile)
        dso.to_netcdf(ncfile)
        LOGGER.info(f"Created daily interpolation file: {ncfile}")
        del dso

    return ncfiles
def nc_interp_at_freq_to_nc(
        ncfiles, freq, ncout, begindate=None, preproc=None, postproc=None):
    """Interpolate netcdf files to a single netcdf file at a given frequency

    Parameters
    ----------
    ncfiles: str, list
        Single or list of netcdf files
    freq: str
        Output frequency
    ncout: str
        Netcdf file with optional date patterns
    begindate: date
        Start from this date
    preproc: callable
        Callable that operates on the input datasets after reading it
    postproc: callable
        Callable that operates on the output dataset before saving it

    Return
    ------
    str
        Output netcdf file
    """
    # Input
    dsi = xr_open_concat(ncfiles)

    # Preproc
    if preproc:
        dsi = preproc(dsi)

    # Output dates: round the input time range inward to the output frequency
    cfs = xoa.cf.get_cf_specs()
    time = cfs.coords.get(dsi, "time").values
    freq = period_vortex_to_pandas(freq)
    date0 = pd.to_datetime(begindate or time[0]).ceil(freq)
    date1 = pd.to_datetime(time[-1]).floor(freq)
    LOGGER.debug("Interpolating to a single file: "
                 f"from {date0:%Y-%m-%dT%H:%M} to {date1:%Y-%m-%dT%H:%M} "
                 f"at frequency {freq}")
    dates = pd.date_range(date0, date1, freq=freq)

    # Interpolate
    dso = interp_time(dsi, dates)
    # Downcast everything to single precision, presumably to reduce the
    # output file size — note this also affects integer variables
    dso = dso.astype("float32")

    # Callback postprocessing
    if postproc is not None:
        dso = postproc(dso)

    # Write netcdf.
    # NOTE: the output path is formatted with **locals(), so patterns in
    # ``ncout`` may reference any local name here (e.g. "{date0:%Y%m%d}");
    # renaming locals in this function would silently break such patterns.
    ncfile = ncout.format(**locals())
    dso.to_netcdf(ncfile)
    LOGGER.info(f"Created interpolation file: {ncfile}")
    return ncfile
class Regridder(object):
    """Regridder based on the `xesmf` library

    The :class:`xesmf.Regridder` (and thus its interpolation weights) is
    built once at init time so it can be reused for several variables.
    """

    def __init__(self, dsi, dso, regridder=None, filename=None,
                 reuse_weights=True, **kwregrid):
        # dsi: input (source) dataset/grid.
        # dso: output (target) dataset/grid.
        # regridder: pre-built xesmf.Regridder to reuse instead of
        #   computing one here.
        # filename: optional weights file; weights are only reused when
        #   this file actually exists.
        # kwregrid: extra keywords forwarded to xesmf.Regridder; the
        #   regridding method defaults to "bilinear".
        kwregrid.setdefault("method", "bilinear")
        # NOTE(review): both renames below use the coordinate names of the
        # *output* dataset (get_*_coord_name(dso)), including for the input
        # dataset — presumably to give both grids common names before
        # handing them to xesmf, but confirm this is intentional.
        self._dsi = rename_from_standard_name(
            dsi, {'longitude': get_lon_coord_name(dso),
                  'latitude': get_lat_coord_name(dso)})
        # Mapping filled by the rename below, used in regrid() to restore
        # the original output coordinate names
        self._dso_mapping = {}
        self._dso = rename_from_standard_name(
            dso, {'longitude': get_lon_coord_name(dso),
                  'latitude': get_lat_coord_name(dso)},
            mapping=self._dso_mapping)
        if regridder is not None:
            self.regridder = regridder
        else:
            # Never try to reuse weights from a missing file
            if not filename or not os.path.exists(filename):
                reuse_weights = False
            self.regridder = xesmf.Regridder(
                self._dsi, self._dso, filename=filename,
                reuse_weights=reuse_weights, **kwregrid)

    def regrid(self, dai, erode=None):
        """ Regrid a variable

        Parameters
        ----------
        dai : xarray.DataArray
            Single variable to regrid.
        erode : optional
            If true, apply :func:`erode_coast` to the variable first.

        Returns
        -------
        xarray.Dataset
            The target dataset merged with the regridded variable.
        """
        # Erode coast
        if erode:
            dai = erode_coast(dai)
        # Appropriate coord names: rename to the "lon"/"lat" names
        # expected by xesmf
        dai = rename_from_standard_name(
            dai, {'longitude': 'lon', 'latitude': 'lat'})
        # Regrid, then restore the target dataset coordinate names
        dao = self.regridder(dai).rename(self._dso_mapping)
        return xr.merge([self._dso, dao])

    # Calling the instance regrids directly
    __call__ = regrid

    def clean(self):
        """ Remove the weights file """
        # NOTE(review): Regridder.clean_weight_file() was removed in
        # xesmf >= 0.4 — confirm the pinned xesmf version supports it.
        if hasattr(self, 'regridder'):
            self.regridder.clean_weight_file()