Source code for decode.fit
# public API: only ``baseline`` is exported by a star-import of this module
__all__ = ["baseline"]
# standard library
from typing import Any, Optional, Union
# dependencies
import numpy as np
import xarray as xr
from numpy.typing import NDArray
from sklearn import linear_model
from . import load
[docs]
def baseline(
    dems: xr.DataArray,
    /,
    *,
    order: int = 0,
    model: str = "LinearRegression",
    weight: Optional[Union[NDArray[np.float64], float]] = None,
    **options: Any,
) -> xr.DataArray:
    """Fit baseline by polynomial and atmospheric models.

    For every time sample the spectrum along the channel axis is fit by a
    linear combination of basis functions: basis 0 is ``dtau/dpwv`` at the
    MKID frequencies and bases ``1 .. order + 1`` are powers ``0 .. order``
    of the mean-subtracted frequency. Each basis is scaled to unit
    Euclidean norm before fitting.

    Args:
        dems: DEMS DataArray to be fit. It is used as a two-dimensional
            ``(time, chan)`` array with a ``d2_mkid_frequency`` coordinate.
        order: Maximum order of the polynomial model.
        model: Name of the model class in ``sklearn.linear_model``.
        weight: One-dimensional weight along channel axis.
            If it is a scalar, then ``(dtau/dpwv)^weight`` will be used.
            It is only for ``'LinearRegression'`` or ``'Ridge'`` models
            and is ignored for any other model.
        options: Optional arguments used for the model initialization.
            ``fit_intercept`` defaults to ``False`` unless given here.

    Returns:
        DataArray shaped like ``dems`` whose values are the fitted
        atmospheric component (coefficient 0 times basis 0). All bases and
        coefficients are attached as ``basis_{i}`` (along ``chan``) and
        ``coeff_{i}`` (along ``time``) coordinates.

    """
    freq = dems.d2_mkid_frequency.values
    slope = dtau_dpwv(freq).values
    n_freq, n_poly = len(freq), order + 1

    # create data to be fit: column 0 is atmospheric, the rest polynomial
    X = np.zeros([n_freq, n_poly + 1])
    X[:, 0] = slope

    for exp in range(n_poly):
        X[:, exp + 1] = (freq - freq.mean()) ** exp

    X /= np.linalg.norm(X, axis=0)
    # samples of the regression are channels, targets are time samples
    y = dems.values.T

    if weight is None:
        sample_weight = np.ones_like(freq)
    elif np.ndim(weight) == 0:
        # any scalar (float, int, NumPy scalar) is an exponent of the slope
        sample_weight = slope ** float(weight)
    else:
        sample_weight = np.asarray(weight)

    # fit model to data
    options = {"fit_intercept": False, **options}
    # keep the class name in ``model``: rebinding it to the instance would
    # make the name check below always fail and drop the weights silently
    estimator = getattr(linear_model, model)(**options)

    if model in ("LinearRegression", "Ridge"):
        estimator.fit(X, y, sample_weight=sample_weight)
    else:
        estimator.fit(X, y)

    coeff: NDArray[np.float64] = estimator.coef_

    # create baseline
    # NOTE(review): only the atmospheric term (basis 0) is summed into the
    # values; polynomial terms are exposed via coordinates only -- confirm
    # that this is the intended definition of the returned baseline.
    fitted = xr.zeros_like(dems)
    fitted += np.outer(coeff[:, 0], X[:, 0])

    for exp in range(n_poly + 1):
        fitted.coords[f"basis_{exp}"] = "chan", X[:, exp]
        fitted.coords[f"coeff_{exp}"] = "time", coeff[:, exp]

    return fitted
def dtau_dpwv(freq: NDArray[np.float64]) -> xr.DataArray:
    """Calculate dtau/dpwv as a function of frequency.

    The ``tau`` data returned by ``load.atm(type="tau")`` are linearly
    interpolated onto ``freq`` and then fit along ``pwv`` by the straight
    line ``a * pwv + b``; the slope ``a`` is returned as dtau/dpwv.

    Args:
        freq: Frequency in units of Hz.

    Returns:
        DataArray that stores dtau/dpwv (the fitted slope ``a``, with the
        ``param`` coordinate dropped).

    """

    def line(x, a, b):
        # parameter names matter: the slope is selected below as param="a"
        return a * x + b

    tau = load.atm(type="tau").interp(freq=freq, method="linear")
    fit = tau.curvefit("pwv", line)
    return fit["curvefit_coefficients"].sel(param="a", drop=True)