# Engine-agnostic rasterization API
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, Literal
import geopandas as gpd
import numpy as np
import xarray as xr
from ..raster_index import RasterIndex
from ..utils import get_affine, get_grid_mapping_var
from .utils import XAXIS, YAXIS, clip_to_bbox, is_in_memory, prepare_for_dask
if TYPE_CHECKING:
import dask_geopandas
from affine import Affine
# Names exported by a star-import of this module.
__all__ = ["rasterize", "geometry_mask", "geometry_clip"]
def _get_affine(obj: xr.Dataset | xr.DataArray, *, x_dim: str, y_dim: str) -> Affine:
    """Return the affine transform describing the grid of ``obj``.

    When the index attached to ``x_dim`` is a ``RasterIndex`` its own
    ``transform()`` is used; in every other case the transform is obtained
    from ``get_affine``.
    """
    x_index = obj.xindexes.get(x_dim)
    if not isinstance(x_index, RasterIndex):
        # No RasterIndex on the x dimension: use the generic helper.
        return get_affine(obj, x_dim=x_dim, y_dim=y_dim)
    return x_index.transform()
# Names of the rasterization backends accepted by the ``engine`` arguments in this module.
Engine = Literal["rasterio", "rusterize", "exactextract"]
def _get_engine(engine: Engine | None) -> Engine:
    """Resolve the rasterization engine to use.

    Parameters
    ----------
    engine : {"rasterio", "rusterize", "exactextract"} or None
        Explicitly requested engine, or None to auto-detect.

    Returns
    -------
    str
        The requested engine once its package has been imported successfully,
        or the first importable engine among "rusterize" and "rasterio" when
        ``engine`` is None.

    Raises
    ------
    ValueError
        If ``engine`` is not None and is not one of the supported names.
    ImportError
        If the requested engine's package cannot be imported, or if
        auto-detection finds neither rusterize nor rasterio.
    """
    import importlib

    if engine is not None:
        # Reject unknown names here: the engine-module selectors fall through
        # to rusterize for anything they do not recognise, which would hide typos.
        if engine not in ("rasterio", "rusterize", "exactextract"):
            raise ValueError(
                f"Unknown engine {engine!r}. Must be one of: ['rasterio', 'rusterize', 'exactextract']"
            )
        # Each engine name is also the name of the package that provides it.
        try:
            importlib.import_module(engine)
        except ImportError as e:
            raise ImportError(f"{engine} is not installed. Install it with: pip install {engine}") from e
        return engine

    # Auto-detect: prefer rusterize, fall back to rasterio.
    # exactextract is never auto-selected; it must be requested explicitly.
    for candidate in ("rusterize", "rasterio"):
        try:
            importlib.import_module(candidate)
        except ImportError:
            continue
        return candidate

    raise ImportError(
        "Neither rusterize nor rasterio is installed. "
        "Install one with: pip install rusterize OR pip install rasterio"
    )
def _get_rasterize_funcs(engine: Engine):
    """Return the ``(rasterize_geometries, dask_rasterize_wrapper)`` pair for ``engine``.

    The engine-specific submodule is imported lazily, so only the selected
    backend module is loaded.
    """
    if engine == "exactextract":
        from . import exact as backend
    elif engine == "rasterio":
        from . import rasterio as backend
    else:
        # Every remaining value is served by the rusterize backend.
        from . import rusterize as backend

    in_memory_func = backend.rasterize_geometries
    dask_block_func = backend.dask_rasterize_wrapper
    return (in_memory_func, dask_block_func)
def _get_mask_funcs(engine: Engine):
    """Return the ``(np_geometry_mask, dask_mask_wrapper)`` pair for ``engine``.

    The engine-specific submodule is imported lazily, so only the selected
    backend module is loaded.
    """
    if engine == "exactextract":
        from . import exact as backend
    elif engine == "rasterio":
        from . import rasterio as backend
    else:
        # Every remaining value is served by the rusterize backend.
        from . import rusterize as backend

    in_memory_func = backend.np_geometry_mask
    dask_block_func = backend.dask_mask_wrapper
    return (in_memory_func, dask_block_func)
def _normalize_merge_alg(merge_alg: str, engine: Engine) -> Any:
    """Translate a ``merge_alg`` name into the value the chosen engine expects.

    - rasterio: "replace" / "add" become the matching ``MergeAlg`` members;
      any other name raises ``ValueError``.
    - exactextract: "replace" / "add" are returned unchanged; any other name
      raises ``ValueError``.
    - rusterize: "replace" becomes "last" and "add" becomes "sum"; every other
      name (e.g. "first", "min", "max", "count", "any") is passed through
      untouched.
    """
    if engine == "exactextract":
        allowed = ("replace", "add")
        if merge_alg in allowed:
            return merge_alg
        raise ValueError(
            f"Invalid merge_alg {merge_alg!r} for exactextract. Must be one of: {list(allowed)}"
        )

    if engine == "rasterio":
        from rasterio.features import MergeAlg

        by_name = {"replace": MergeAlg.replace, "add": MergeAlg.add}
        if merge_alg in by_name:
            return by_name[merge_alg]
        raise ValueError(
            f"Invalid merge_alg {merge_alg!r} for rasterio. Must be one of: {list(by_name)}"
        )

    # rusterize: map the common names onto its native ones, leave the rest alone.
    rusterize_names = {"replace": "last", "add": "sum"}
    return rusterize_names.get(merge_alg, merge_alg)
def replace_values(array: np.ndarray, to, *, from_=0) -> np.ndarray:
    """Replace fill values and undo the +1 offset used during dask rasterization.

    Parameters
    ----------
    array : np.ndarray
        Block of rasterized codes. It is not modified.
    to : scalar
        Value written wherever ``array`` equals ``from_``.
    from_ : scalar, default 0
        Temporary fill value to look for.

    Returns
    -------
    np.ndarray
        New array with the dtype of ``array``: cells equal to ``from_`` are
        set to ``to`` and every other cell is decremented by one.
    """
    # Work on a copy: this function is applied through dask ``map_blocks``,
    # and mutating the incoming block would alter the upstream task's result.
    result = array.copy()
    is_fill = result == from_
    result[~is_fill] -= 1
    result[is_fill] = to
    return result
[docs]
def rasterize(
    obj: xr.Dataset | xr.DataArray,
    geometries: gpd.GeoDataFrame | dask_geopandas.GeoDataFrame,
    *,
    engine: Engine | None = None,
    xdim: str = "x",
    ydim: str = "y",
    all_touched: bool = False,
    merge_alg: str = "replace",
    geoms_rechunk_size: int | None = None,
    clip: bool = False,
    **engine_kwargs,
) -> xr.DataArray:
    """
    Burn geometries onto the grid of ``obj``, with dask support.

    The result is a 2D DataArray of integer codes marking the cells covered
    by the provided geometries.

    Parameters
    ----------
    obj : xr.Dataset or xr.DataArray
        Xarray object that defines the target grid.
    geometries : GeoDataFrame
        A geopandas or dask_geopandas GeoDataFrame.
    engine : {"rasterio", "rusterize", "exactextract"} or None
        Rasterization engine. None auto-detects from the installed packages
        (rusterize is preferred, rasterio is the fallback). "exactextract" is
        never auto-selected and must be requested explicitly.
    xdim : str
        Name of the "x" dimension on ``obj``.
    ydim : str
        Name of the "y" dimension on ``obj``.
    all_touched : bool
        If True, burn every pixel touched by a geometry. If False, burn only
        pixels whose center lies within the geometry.
        Note: not supported by the rusterize or exactextract engines.
    merge_alg : str
        How overlapping geometries are combined.

        - "replace": later geometries overwrite earlier ones
        - "add": values are summed where geometries overlap

        The rusterize engine additionally accepts "first", "min", "max",
        "count" and "any".
    geoms_rechunk_size : int or None
        Size to rechunk the geometry array to *after* conversion from the
        dataframe.
    clip : bool
        If True, clip the raster to the bounding box of the geometries.
        Ignored for dask-geopandas geometries.
    **engine_kwargs
        Extra keyword arguments forwarded to the engine.
        For rasterio: ``env`` (rasterio.Env for GDAL configuration).

    Returns
    -------
    DataArray
        2D DataArray named "rasterized" with the geometries burned in as
        integer codes.

    Raises
    ------
    ValueError
        If ``xdim`` or ``ydim`` is not a dimension of ``obj``, or if the dask
        code path receives a ``merge_alg`` other than "replace" or "add".

    Notes
    -----
    Engines may disagree slightly at pixel boundaries because they test
    geometry/pixel intersection differently:

    - **rasterio**: uses GDAL's rasterization. With ``all_touched=False``
      (default) only pixels whose center falls within the geometry are
      burned; with ``all_touched=True`` any intersecting pixel is burned.
      Requires GDAL.
    - **rusterize**: a Rust-based rasterization engine. Does not require GDAL.
    - **exactextract**: computes precise sub-pixel coverage. Any pixel with
      non-zero coverage is burned, which matches rasterio's
      ``all_touched=True``. Does not require GDAL. Passing
      ``all_touched=True`` raises NotImplementedError because that is already
      its default behavior.

    See Also
    --------
    rasterio.features.rasterize
    rusterize.rusterize
    exactextract.exact_extract
    """
    if not (xdim in obj.dims and ydim in obj.dims):
        raise ValueError(f"Received {xdim=!r}, {ydim=!r} but obj.dims={tuple(obj.dims)}")

    backend = _get_engine(engine)
    if clip:
        obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)

    affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
    backend_merge_alg = _normalize_merge_alg(merge_alg, backend)
    burn_in_memory, burn_dask_block = _get_rasterize_funcs(backend)
    # dict(...) rather than a literal: a clashing key in engine_kwargs must raise.
    burn_kwargs = dict(
        all_touched=all_touched,
        merge_alg=backend_merge_alg,
        affine=affine,
        **engine_kwargs,
    )

    if not is_in_memory(obj=obj, geometries=geometries):
        from dask.array import from_array, map_blocks

        block_args, raster_chunks, lazy_geoms = prepare_for_dask(
            obj,
            geometries,
            xdim=xdim,
            ydim=ydim,
            geoms_rechunk_size=geoms_rechunk_size,
        )
        # Use the array size: len() on a dask GeoDataFrame would compute.
        n_geoms = lazy_geoms.size
        # On the dask path 0 is the temporary fill value; it is swapped out below.
        code_dtype = np.min_scalar_type(n_geoms)
        # Index of the first geometry in each geometry chunk, shifted by 1
        # so that 0 stays reserved for the fill value.
        geoms_per_block = lazy_geoms.chunks[0]
        first_codes = np.cumsum(np.array([0, *geoms_per_block[:-1]])) + 1
        lazy_offsets = from_array(first_codes, chunks=1)

        stacked = map_blocks(
            burn_dask_block,
            *block_args,
            lazy_offsets[:, np.newaxis, np.newaxis],
            chunks=(
                (1,) * lazy_geoms.numblocks[0],
                raster_chunks[YAXIS],
                raster_chunks[XAXIS],
            ),
            meta=np.array([], dtype=code_dtype),
            fill=0,  # good identity value for both sum & replace.
            **burn_kwargs,
            dtype_=code_dtype,
        )

        # Collapse the per-geometry-chunk layers into a single 2D raster.
        if merge_alg == "add":
            burned = stacked.sum(axis=0)
        elif merge_alg == "replace":
            burned = stacked.max(axis=0)
        else:
            raise ValueError(f"Invalid merge_alg {merge_alg!r}. Must be one of: ['replace', 'add']")

        # Swap the temporary 0 fill for n_geoms and shift the other codes back by 1.
        burned = burned.map_blocks(partial(replace_values, to=n_geoms))
    else:
        geom_list = geometries.to_numpy().squeeze(axis=1).tolist()
        n_geoms = len(geometries)
        burned = burn_in_memory(
            geom_list,
            shape=(obj.sizes[ydim], obj.sizes[xdim]),
            offset=0,
            dtype=np.min_scalar_type(n_geoms),
            fill=n_geoms,
            **burn_kwargs,
        )

    coord_vars: dict = {dim: obj.coords[dim] for dim in (xdim, ydim)}
    grid_mapping = get_grid_mapping_var(obj)
    if grid_mapping is not None:
        coord_vars[grid_mapping.name] = grid_mapping

    grid_indexes = {dim: obj.xindexes[dim] for dim in (xdim, ydim)}
    return xr.DataArray(
        data=burned,
        dims=(ydim, xdim),
        coords=xr.Coordinates(coords=coord_vars, indexes=grid_indexes),
        name="rasterized",
    )
[docs]
def geometry_mask(
    obj: xr.Dataset | xr.DataArray,
    geometries: gpd.GeoDataFrame | dask_geopandas.GeoDataFrame,
    *,
    engine: Engine | None = None,
    xdim: str = "x",
    ydim: str = "y",
    all_touched: bool = False,
    invert: bool = False,
    geoms_rechunk_size: int | None = None,
    clip: bool = False,
    **engine_kwargs,
) -> xr.DataArray:
    """
    Build a boolean mask from geometries, with dask support.

    Parameters
    ----------
    obj : xr.DataArray or xr.Dataset
        Xarray object that defines the grid.
    geometries : GeoDataFrame or DaskGeoDataFrame
        Geometries used for masking.
    engine : {"rasterio", "rusterize", "exactextract"} or None
        Rasterization engine. None auto-detects from the installed packages
        (rusterize is preferred, rasterio is the fallback). "exactextract" is
        never auto-selected and must be requested explicitly.
    xdim : str
        Name of the "x" dimension on ``obj``.
    ydim : str
        Name of the "y" dimension on ``obj``.
    all_touched : bool
        If True, every pixel touched by a geometry is included in the mask.
        Note: not supported by the rusterize or exactextract engines.
    invert : bool
        If True, pixels inside geometries are True (unmasked).
        If False (default), pixels inside geometries are False (masked).
    geoms_rechunk_size : int or None
        Chunksize for the geometry dimension of the output.
    clip : bool
        If True, clip the raster to the bounding box of the geometries.
        Ignored for dask-geopandas geometries.
    **engine_kwargs
        Extra keyword arguments forwarded to the engine.
        For rasterio: ``env`` (rasterio.Env for GDAL configuration).

    Returns
    -------
    DataArray
        2D boolean DataArray named "mask".

    Raises
    ------
    ValueError
        If ``xdim`` or ``ydim`` is not a dimension of ``obj``.

    Notes
    -----
    See :func:`rasterize` for details on engine differences. The exactextract
    engine produces results equivalent to rasterio's ``all_touched=True``.

    See Also
    --------
    rasterize
    rasterio.features.geometry_mask
    """
    if not (xdim in obj.dims and ydim in obj.dims):
        raise ValueError(f"Received {xdim=!r}, {ydim=!r} but obj.dims={tuple(obj.dims)}")

    backend = _get_engine(engine)
    if clip:
        obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)

    affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
    mask_in_memory, mask_dask_block = _get_mask_funcs(backend)
    # dict(...) rather than a literal: a clashing key in engine_kwargs must raise.
    mask_kwargs = dict(
        all_touched=all_touched,
        affine=affine,
        **engine_kwargs,
    )

    if not is_in_memory(obj=obj, geometries=geometries):
        from dask.array import map_blocks

        block_args, raster_chunks, lazy_geoms = prepare_for_dask(
            obj,
            geometries,
            xdim=xdim,
            ydim=ydim,
            geoms_rechunk_size=geoms_rechunk_size,
        )
        stacked = map_blocks(
            mask_dask_block,
            *block_args,
            chunks=(
                (1,) * lazy_geoms.numblocks[0],
                raster_chunks[YAXIS],
                raster_chunks[XAXIS],
            ),
            meta=np.array([], dtype=bool),
            **mask_kwargs,
        )
        # Reduce over the geometry-chunk axis first; ``invert`` is not passed
        # to the per-block wrapper, so it is applied once to the combined mask.
        mask = stacked.all(axis=0)
        if invert:
            mask = ~mask
    else:
        geom_list = geometries.to_numpy().squeeze(axis=1).tolist()
        mask = mask_in_memory(
            geom_list,
            shape=(obj.sizes[ydim], obj.sizes[xdim]),
            invert=invert,
            **mask_kwargs,
        )

    coord_vars: dict = {dim: obj.coords[dim] for dim in (xdim, ydim)}
    grid_mapping = get_grid_mapping_var(obj)
    if grid_mapping is not None:
        coord_vars[grid_mapping.name] = grid_mapping

    grid_indexes = {dim: obj.xindexes[dim] for dim in (xdim, ydim)}
    return xr.DataArray(
        data=mask,
        dims=(ydim, xdim),
        coords=xr.Coordinates(coords=coord_vars, indexes=grid_indexes),
        name="mask",
    )
[docs]
def geometry_clip(
    obj: xr.Dataset | xr.DataArray,
    geometries: gpd.GeoDataFrame | dask_geopandas.GeoDataFrame,
    *,
    engine: Engine | None = None,
    xdim: str = "x",
    ydim: str = "y",
    all_touched: bool = False,
    invert: bool = False,
    geoms_rechunk_size: int | None = None,
    clip: bool = True,
    **engine_kwargs,
) -> xr.DataArray:
    """
    Dask-aware geometry clipping.

    Clips an xarray object to geometries by masking values outside the
    geometries.

    Parameters
    ----------
    obj : xr.DataArray or xr.Dataset
        Xarray object to clip.
    geometries : GeoDataFrame or DaskGeoDataFrame
        Geometries used for clipping.
    engine : {"rasterio", "rusterize", "exactextract"} or None
        Rasterization engine to use. If None, auto-detects based on installed
        packages (prefers rusterize if available, falls back to rasterio).
        Note: "exactextract" must be explicitly requested and is not
        auto-selected.
    xdim : str
        Name of the "x" dimension on ``obj``.
    ydim : str
        Name of the "y" dimension on ``obj``.
    all_touched : bool
        If True, all pixels touched by geometries will be included.
        Note: Not supported by rusterize or exactextract engines.
    invert : bool
        If True, preserve values outside the geometry (invert the clip).
        If False (default), preserve values inside the geometry.
    geoms_rechunk_size : int or None
        Chunksize for geometry dimension of the output.
    clip : bool
        If True (default), clip raster to the bounding box of the geometries
        before masking. Ignored for dask-geopandas geometries.
    **engine_kwargs
        Additional keyword arguments passed to the engine.
        For rasterio: ``env`` (rasterio.Env for GDAL configuration).

    Returns
    -------
    DataArray or Dataset
        The result of ``obj.where(mask)``: values where the mask is False are
        replaced by NaN.
        NOTE(review): the return annotation says ``xr.DataArray``, but the
        value comes from ``obj.where`` and so has the type of ``obj``.

    Raises
    ------
    ValueError
        If ``xdim`` or ``ydim`` is not a dimension of ``obj``.

    Notes
    -----
    See :func:`rasterize` for details on engine differences. The exactextract
    engine produces results equivalent to rasterio's ``all_touched=True``, while
    rusterize may produce slightly different results at pixel boundaries.

    See Also
    --------
    rasterize
    geometry_mask
    rasterio.features.geometry_mask
    """
    # Validate before clipping, as rasterize and geometry_mask do, so a wrong
    # dimension name gives the same ValueError instead of a clip_to_bbox error.
    if xdim not in obj.dims or ydim not in obj.dims:
        raise ValueError(f"Received {xdim=!r}, {ydim=!r} but obj.dims={tuple(obj.dims)}")

    if clip:
        obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)

    mask = geometry_mask(
        obj,
        geometries,
        engine=engine,
        all_touched=all_touched,
        invert=not invert,  # rioxarray clip convention -> geometry_mask convention
        xdim=xdim,
        ydim=ydim,
        geoms_rechunk_size=geoms_rechunk_size,
        clip=False,  # obj was already clipped above when requested
        **engine_kwargs,
    )
    return obj.where(mask)