#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CF utilities
"""
import re
import warnings
import numpy as np
import xarray as xr
# Search specifications keyed by a short generic name.  Each entry may list
# accepted variable names ('name'), 'units' attribute values, 'standard_name'
# attribute values and, for coordinates, the expected 'axis' attribute.
# These dicts are what `get_name_from_specs` expects as its `specs` argument.
CF_SPECS = {
    'lon': {
        'name': ['lon', 'longitude', 'nav_lon', 'LON'],
        'units': ['degrees_east', 'degree_east', 'degree_E', 'degrees_E',
                  'degreeE', 'degreesE'],
        'standard_name': ['longitude'],
        'axis': 'X',
    },
    'lat': {
        'name': ['lat', 'latitude', 'nav_lat', 'LAT'],
        'units': ['degrees_north', 'degree_north', 'degree_N', 'degrees_N',
                  'degreeN', 'degreesN'],
        'standard_name': ['latitude'],
        'axis': 'Y',
    },
    'depth': {
        'name': ['depth', 'deptht'],
        'units': ['m'],
        'standard_name': ['depth'],
        'axis': 'Z',
    },
    'temp': {
        'name': ['temp', 'votemper'],
        'units': ['degrees_celsius', 'degrees_C'],
        'standard_name': ['sea_water_temperature',
                          'sea_water_potential_temperature'],
    },
    'saln': {
        # 'vosaline', not 'votemper': the latter already belongs to 'temp'
        # and would make a salinity lookup return the temperature variable.
        'name': ['saln', 'vosaline'],
        'units': ['PSU'],
        'standard_name': ['sea_water_salinity'],
    },
    'ssh': {
        'name': ['ssh', 'xe', 'sossheig'],
        'units': ['m'],
        'standard_name': ['sea_surface_height_above_sea_level',
                          'sea_surface_height_above_geoid'],
    },
}

# Matcher for standard names of the form ``<base>_at_<c>_location``:
# group 1 is the base standard name, group 2 the single-character location.
RE_SN_AT_LOC = re.compile(r'^(\w+)(?:_at_(\w)_location)$').match
def get_dim_name_from_axis(ds, axis):
    """
    Get a dimension name from its axis attribute in a xarray.Dataset

    Parameters
    ----------
    ds : xarray.Dataset
        Object whose ``coords`` are scanned.
    axis : str
        X, Y, Z or T (case-insensitive).

    Returns
    -------
    str
        Name of the first coordinate whose ``axis`` attribute equals `axis`.

    Raises
    ------
    AssertionError
        If `axis` is not exactly one of X, Y, Z or T.
    AttributeError
        If no coordinate carries the requested ``axis`` attribute.
    """
    axis = axis.upper()
    # Test against the individual letters: a substring test on "XYZT" would
    # also accept "", "XY", "YZT", ...  The error is raised explicitly rather
    # than through `assert` so that the check survives `python -O`, while
    # keeping the AssertionError type that existing callers may rely on.
    if axis not in ("X", "Y", "Z", "T"):
        raise AssertionError("Invalid value for axis attribute")
    for name, coord in ds.coords.items():
        if hasattr(coord, 'axis') and coord.axis == axis:
            return name
    raise AttributeError('Dimension from axis not found: ' + axis)
def get_x_dim_name(ds):
    '''
    Return the name of the X dimension

    Thin wrapper around `get_dim_name_from_axis` with the axis set to ``'X'``.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of X dimension
    '''
    return get_dim_name_from_axis(ds, axis='X')
def get_y_dim_name(ds):
    '''
    Return the name of the Y dimension

    Thin wrapper around `get_dim_name_from_axis` with the axis set to ``'Y'``.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of Y dimension
    '''
    return get_dim_name_from_axis(ds, axis='Y')
def get_t_dim_name(ds):
    '''
    Return the name of the time dimension

    Thin wrapper around `get_dim_name_from_axis` with the axis set to ``'T'``.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of time dimension
    '''
    return get_dim_name_from_axis(ds, axis='T')
def get_name_from_specs(ds, specs, check=None, first=True):
    '''
    Get the name of a `xarray.DataArray` according to search specifications

    The searches are tried in this order and the first hit is returned:
    variable names, ``units`` attributes, ``standard_name`` attributes and
    finally ``axis`` attributes.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset
        Names are collected from ``ds.coords`` and, when present,
        ``ds.variables``.
    specs: dict
        Search specifications with optional keys ``'name'``, ``'units'``,
        ``'standard_name'`` (lists of str) and ``'axis'`` (str).
        The ``units`` search is only performed when ``'axis'`` is also
        present in `specs`.
    check: str, list of str
        Attributes to check among ``'name'``, ``'units'``,
        ``'standard_name'`` and ``'axis'``. All of them when None.
    first: bool
        Currently unused; kept so that existing calls remain valid.

    Returns
    -------
    str or None: Name of the xarray.DataArray, or None when nothing matches
    '''
    # Targets: coordinates first, then variables.  A Dataset lists its
    # coordinates in both, so duplicates are dropped (first-seen order is
    # kept) to avoid scanning the same variable twice.
    candidates = []
    if hasattr(ds, 'coords'):
        candidates.extend(ds.coords)
    if hasattr(ds, 'variables'):
        candidates.extend(ds.variables)
    candidates = list(dict.fromkeys(candidates))

    if isinstance(check, str):
        check = [check]

    def _enabled(key):
        # A search runs when it is specified and not excluded by `check`
        return key in specs and (check is None or key in check)

    # Search in names first: the order of specs['name'] gives the priority
    if _enabled('name'):
        for specs_name in specs['name']:
            if specs_name in candidates:
                return specs_name

    # Search in units attributes, only for specs that also declare an axis
    if _enabled('units') and 'axis' in specs:
        for ds_name in candidates:
            var = ds[ds_name]
            if hasattr(var, 'units') and var.units in specs['units']:
                return ds_name

    # Search in standard_name attributes, either as is or with a trailing
    # "_at_<c>_location" suffix removed
    if _enabled('standard_name'):
        for ds_name in candidates:
            var = ds[ds_name]
            if not hasattr(var, 'standard_name'):
                continue
            standard_name = var.standard_name
            if standard_name in specs['standard_name']:
                return ds_name
            m = RE_SN_AT_LOC(standard_name)
            if m and m.group(1) in specs['standard_name']:
                return ds_name

    # Search in axis attributes (case-insensitive)
    if _enabled('axis'):
        target_axis = specs['axis'].lower()
        for ds_name in candidates:
            var = ds[ds_name]
            if hasattr(var, 'axis') and var.axis.lower() == target_axis:
                return ds_name

    return None
def get_lon_coord_name(ds):
    '''
    Find the longitude name using the ``'lon'`` entry of `CF_SPECS`

    A warning is emitted and None is returned when nothing matches.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of the xarray.DataArray
    '''
    found = get_name_from_specs(ds, CF_SPECS['lon'])
    if found is None:
        warnings.warn('longitude not found in dataset')
    return found
def get_lat_coord_name(ds):
    '''
    Find the latitude name using the ``'lat'`` entry of `CF_SPECS`

    A warning is emitted and None is returned when nothing matches.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of the xarray.DataArray
    '''
    found = get_name_from_specs(ds, CF_SPECS['lat'])
    if found is None:
        warnings.warn('latitude not found in dataset')
    return found
def get_time_coord_name(ds):
    '''
    Find the time coordinate by looking at the coordinate dtypes

    The first coordinate whose dtype character code is ``'M'``
    (numpy datetime64) is selected. A warning is emitted and None is
    returned when there is no such coordinate.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of the xarray.DataArray
    '''
    time_name = next(
        (name for name, coord in ds.coords.items()
         if coord.dtype.char == 'M'),
        None,
    )
    if time_name is None:
        warnings.warn('time not found in dataset')
    return time_name
def rename_from_standard_name(ds, standard_name_to_name, mapping=None):
    """
    Rename coords and variables according to their standard_name attribute

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        For a DataArray only its coordinates are candidates, otherwise all
        the variables are.
    standard_name_to_name : dict
        Key are standard_names and values are names.
        Each key is inserted as is into a regular expression that also
        accepts any ``_<suffix>`` after the standard_name.
    mapping : dict, optional
        When a dict is given, it is updated in place with
        ``{new_name: old_name}`` for each renaming that is performed.

    Returns
    -------
    xarray.Dataset or xarray.DataArray
        Result of the successive calls to ``ds.rename``.

    Raises
    ------
    AssertionError
        If more than one candidate matches a given standard_name.
    """
    if isinstance(ds, xr.DataArray):
        keys = list(ds.coords.keys())
    else:
        keys = list(ds.variables.keys())

    for standard_name, name in standard_name_to_name.items():
        matcher = re.compile(r'^{}(_.+)?$'.format(standard_name)).match

        # Candidates whose standard_name attribute matches this entry
        names = []
        for vname in keys:
            var = ds[vname]
            if hasattr(var, 'standard_name') and matcher(var.standard_name):
                names.append(vname)
        if not names:
            continue

        # Raised explicitly rather than through `assert` so that an
        # ambiguous match is still rejected under `python -O`; the
        # AssertionError type is kept for existing callers.
        if len(names) > 1:
            raise AssertionError('No more than one variable must have this'
                                 ' standard_name: ' + standard_name)

        vname = names[0]
        if vname != name:
            ds = ds.rename({vname: name})
            # Keep the candidate list in sync with the renamed object
            keys.remove(vname)
            keys.append(name)
            # NOTE(review): the mapping is recorded only for effective
            # renamings -- confirm that identity matches must not be stored.
            if isinstance(mapping, dict):
                mapping[name] = vname
    return ds
# with warnings.catch_warnings():
# warnings.simplefilter(
# "ignore", xr.core.extensions.AccessorRegistrationWarning)
# @xr.register_dataset_accessor('cf')
# @xr.register_dataarray_accessor('cf')
# class CFAccessor(object):
# def __init__(self, ds):
# self.ds = ds
# def __getitem__(self, cf_name):
# return self.ds[get_name_from_specs(self.ds, CF_SPECS[cf_name])]
# def get_lon(self):
# return self['lon']
# lon = property(fget=get_lon, doc='Longitude')
# def get_lat(self):
# return self['lat']
# lat = property(fget=get_lat, doc='Latitude')
# def get_depth(self):
# return self['depth']
# depth = property(fget=get_depth, doc='Depth')
# def get_time(self):
# return self.ds[get_time_coord_name(self.ds)]
# time = property(fget=get_time, doc='Time')