# Source code for sloop.cf

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CF utilities
"""
import re
import warnings

import numpy as np
import xarray as xr

# Search specifications used to identify coordinates and variables.
# Each entry may define:
# - 'name': candidate names, tested in order
# - 'units': accepted values of the ``units`` attribute
# - 'standard_name': accepted values of the ``standard_name`` attribute
# - 'axis': expected value of the ``axis`` attribute
CF_SPECS = {
    'lon': {
        'name': ['lon', 'longitude', 'nav_lon', 'LON'],
        'units': ['degrees_east', 'degree_east', 'degree_E', 'degrees_E',
                  'degreeE', 'degreesE'],
        'standard_name': ['longitude'],
        'axis': 'X',
        },
    'lat': {
        'name': ['lat', 'latitude', 'nav_lat', 'LAT'],
        'units': ['degrees_north', 'degree_north', 'degree_N', 'degrees_N',
                  'degreeN', 'degreesN'],
        'standard_name': ['latitude'],
        'axis': 'Y',
        },
    'depth': {
        'name': ['depth', 'deptht'],
        'units': ['m'],
        'standard_name': ['depth'],
        'axis': 'Z',
        },
    'temp': {
        'name': ['temp', 'votemper'],
        'units': ['degrees_celsius', 'degrees_C'],
        'standard_name': ['sea_water_temperature',
                          'sea_water_potential_temperature'],
        },
    'saln': {
        # 'vosaline' is the salinity counterpart of 'votemper'; listing
        # 'votemper' here made salinity lookups return the temperature
        'name': ['saln', 'vosaline'],
        'units': ['PSU'],
        'standard_name': ['sea_water_salinity'],
        },
    'ssh': {
        'name': ['ssh', 'xe', 'sossheig'],
        'units': ['m'],
        'standard_name': ['sea_surface_height_above_sea_level',
                          'sea_surface_height_above_geoid'],
        },
    }

# Matcher for strings of the form ``<prefix>_at_<c>_location``:
# group 1 is the prefix and group 2 the single word character ``<c>``.
# It returns None for any string that does not end this way.
_SN_AT_LOC_REGEX = re.compile(r'^(\w+)(?:_at_(\w)_location)$')
RE_SN_AT_LOC = _SN_AT_LOC_REGEX.match


def get_dim_name_from_axis(ds, axis):
    """
    Get a dimension name from its axis attribute

    Coordinates of `ds` are scanned in order and the name of the first one
    whose ``axis`` attribute equals the requested axis is returned.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Object whose ``coords`` are scanned.
    axis : str
        X, Y, Z or T. It is converted to upper case before comparison.

    Returns
    -------
    str

    Raises
    ------
    AssertionError
        If `axis` is not one of X, Y, Z or T.
    AttributeError
        If no coordinate has the requested ``axis`` attribute.
    """
    axis = axis.upper()
    # Test membership in a tuple: a substring test against "XYZT" would
    # wrongly accept values such as "" or "XY".
    # AssertionError is raised explicitly so that the exception type is the
    # historical one while the check is not stripped under ``python -O``.
    if axis not in ("X", "Y", "Z", "T"):
        raise AssertionError("Invalid value for axis attribute")
    for name, coord in ds.coords.items():
        if hasattr(coord, 'axis') and coord.axis == axis:
            return name
    raise AttributeError('Dimension from axis not found: '+axis)
def get_x_dim_name(ds):
    """
    Get the X dimension name in a `xarray.Dataset` or a `xarray.DataArray`

    This delegates to :func:`get_dim_name_from_axis` with the ``X`` axis.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of X dimension
    """
    x_dim_name = get_dim_name_from_axis(ds, axis='X')
    return x_dim_name
def get_y_dim_name(ds):
    """
    Get the Y dimension name in a `xarray.Dataset` or a `xarray.DataArray`

    This delegates to :func:`get_dim_name_from_axis` with the ``Y`` axis.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of Y dimension
    """
    y_dim_name = get_dim_name_from_axis(ds, axis='Y')
    return y_dim_name
def get_t_dim_name(ds):
    """
    Get the time dimension name in a `xarray.Dataset` or a `xarray.DataArray`

    This delegates to :func:`get_dim_name_from_axis` with the ``T`` axis.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of time dimension
    """
    t_dim_name = get_dim_name_from_axis(ds, axis='T')
    return t_dim_name
[docs]def get_name_from_specs(ds, specs, check=None, first=True): ''' Get the name of a `xarray.DataArray` according to search specifications Parameters ---------- ds: xarray.DataArray or xarray.Dataset specs: dict check: str, list of str Attributes to check first: bool Returns ------- str: Name of the xarray.DataArray ''' # Targets ds_names = [] if hasattr(ds, 'coords'): ds_names.extend(list(ds.coords)) if hasattr(ds, 'variables'): ds_names.extend(list(ds.variables)) if isinstance(check, str): check = [check] # Search in names first if 'name' in specs and (check is None or 'name' in check): for specs_name in specs['name']: if specs_name in ds_names: return specs_name # Search in units attributes if ('units' in specs and 'axis' in specs and (check is None or 'units' in check)): for ds_name in ds_names: if hasattr(ds[ds_name], 'units'): for specs_units in specs['units']: if ds[ds_name].units == specs_units: return ds_name # Search in standard_name attributes if ('standard_name' in specs and (check is None or 'standard_name' in check)): for ds_name in ds_names: if hasattr(ds[ds_name], 'standard_name'): for specs_sn in specs['standard_name']: if ds[ds_name].standard_name == specs_sn: return ds_name m = RE_SN_AT_LOC(ds[ds_name].standard_name) if m and m.group(1) == specs_sn: return ds_name # Search in axis attributes if 'axis' in specs and (check is None or 'axis' in check): for ds_name in ds_names: if (hasattr(ds[ds_name], 'axis') and ds[ds_name].axis.lower() == specs['axis'].lower()): return ds_name
def get_lon_coord_name(ds):
    """
    Get the longitude name in a `xarray.Dataset` or a `xarray.DataArray`

    The search uses the ``lon`` entry of :data:`CF_SPECS`.
    A warning is emitted when nothing is found.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str or None: Name of the xarray.DataArray, or None if not found
    """
    name = get_name_from_specs(ds, CF_SPECS['lon'])
    if name is None:
        warnings.warn('longitude not found in dataset')
        return None
    return name
def get_lat_coord_name(ds):
    """
    Get the latitude name in a `xarray.Dataset` or a `xarray.DataArray`

    The search uses the ``lat`` entry of :data:`CF_SPECS`.
    A warning is emitted when nothing is found.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str or None: Name of the xarray.DataArray, or None if not found
    """
    name = get_name_from_specs(ds, CF_SPECS['lat'])
    if name is None:
        warnings.warn('latitude not found in dataset')
        return None
    return name
def get_time_coord_name(ds):
    """
    Get the time coordinate name from its dtype

    The first coordinate whose dtype character code is ``'M'`` is selected.
    A warning is emitted when there is no such coordinate.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str or None: Name of the xarray.DataArray, or None if not found
    """
    coords = ds.coords
    time_names = (name for name in coords if coords[name].dtype.char == 'M')
    for time_name in time_names:
        return time_name
    warnings.warn('time not found in dataset')
def rename_from_standard_name(ds, standard_name_to_name, mapping=None):
    """
    Rename coords and variables according to their standard_name attribute

    For each requested standard name, the array whose ``standard_name``
    attribute is either this standard name or this standard name followed
    by an underscore and a suffix is renamed.
    The coordinates are searched for a `xarray.DataArray`, and the
    variables otherwise.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
    standard_name_to_name : dict
        Key are standard_names and values are names.
    mapping : None or dict
        When a dict, it is filled in place with the new name as key and the
        old name as value for each renaming that is performed.

    Returns
    -------
    xarray.Dataset or xarray.DataArray
        The renamed object.

    Raises
    ------
    AssertionError
        When several arrays match the same standard name.
    """
    source = ds.coords if isinstance(ds, xr.DataArray) else ds.variables
    keys = list(source.keys())

    for standard_name, name in standard_name_to_name.items():

        # Names of the arrays that match the current standard name
        matcher = re.compile(r'^{}(_.+)?$'.format(standard_name)).match
        names = [
            vname for vname in keys
            if (hasattr(ds[vname], 'standard_name') and
                matcher(ds[vname].standard_name))]
        if not names:
            continue
        assert len(names) <= 1, ('No more than one variable must have this'
                                 ' standard_name: '+standard_name)

        # Nothing to do when it already has the right name
        vname = names[0]
        if vname == name:
            continue

        # Rename and keep the list of keys in sync for the next iterations
        ds = ds.rename({vname: name})
        keys.remove(vname)
        keys.append(name)
        if isinstance(mapping, dict):
            mapping[name] = vname

    return ds
# with warnings.catch_warnings():
#     warnings.simplefilter(
#         "ignore", xr.core.extensions.AccessorRegistrationWarning)
#
#     @xr.register_dataset_accessor('cf')
#     @xr.register_dataarray_accessor('cf')
#     class CFAccessor(object):
#
#         def __init__(self, ds):
#             self.ds = ds
#
#         def __getitem__(self, cf_name):
#             return self.ds[get_name_from_specs(self.ds, CF_SPECS[cf_name])]
#
#         def get_lon(self):
#             return self['lon']
#
#         lon = property(fget=get_lon, doc='Longitude')
#
#         def get_lat(self):
#             return self['lat']
#
#         lat = property(fget=get_lat, doc='Latitude')
#
#         def get_depth(self):
#             return self['depth']
#
#         depth = property(fget=get_depth, doc='Depth')
#
#         def get_time(self):
#             return self.ds[get_time_coord_name(self.ds)]
#
#         time = property(fget=get_time, doc='Time')