# Source code for endf.function

# SPDX-FileCopyrightText: 2023-2025 OpenMC contributors and Paul Romano
# SPDX-License-Identifier: MIT

from collections.abc import Iterable
from math import exp, log

import numpy as np

from .data import EV_PER_MEV


class Tabulated1D:
    """A one-dimensional tabulated function.

    This class mirrors the TAB1 type from the ENDF-6 format. A tabulated
    function is specified by tabulated (x,y) pairs along with interpolation
    rules that determine the values between tabulated pairs.

    Once an object has been created, it can be used as though it were an
    actual function, e.g.:

    >>> f = Tabulated1D([0, 10], [4, 5])
    >>> [f(xi) for xi in numpy.linspace(0, 10, 5)]
    [4.0, 4.25, 4.5, 4.75, 5.0]

    Supported interpolation codes are 1 (histogram), 2 (y linear in x),
    3 (y linear in ln(x)), 4 (ln(y) linear in x) and 5 (ln(y) linear in
    ln(x)). Interpolation region ``k`` spans the (1-based) tabulated points
    ``breakpoints[k-1]`` through ``breakpoints[k]``; region 0 starts at the
    first tabulated point.

    Parameters
    ----------
    x : Iterable of float
        Independent variable
    y : Iterable of float
        Dependent variable
    breakpoints : Iterable of int
        Breakpoints for interpolation regions. If either this or
        ``interpolation`` is None, a single linear-linear region covering
        all points is used.
    interpolation : Iterable of int
        Interpolation scheme identification number, e.g., 3 means y is linear
        in ln(x).

    Attributes
    ----------
    x : Iterable of float
        Independent variable
    y : Iterable of float
        Dependent variable
    breakpoints : Iterable of int
        Breakpoints for interpolation regions
    interpolation : Iterable of int
        Interpolation scheme identification number, e.g., 3 means y is linear
        in ln(x).
    n_regions : int
        Number of interpolation regions
    n_pairs : int
        Number of tabulated (x,y) pairs

    """

    def __init__(self, x, y, breakpoints=None, interpolation=None):
        if breakpoints is None or interpolation is None:
            # Single linear-linear interpolation region by default
            self.breakpoints = np.array([len(x)])
            self.interpolation = np.array([2])
        else:
            self.breakpoints = np.asarray(breakpoints, dtype=int)
            self.interpolation = np.asarray(interpolation, dtype=int)

        self.x = np.asarray(x)
        self.y = np.asarray(y)

    def __repr__(self):
        return f"<Tabulated1D: {self.x.size} points, {self.breakpoints.size} regions>"

    def __call__(self, x):
        """Evaluate the tabulated function.

        Parameters
        ----------
        x : float or Iterable of float
            Point(s) at which to evaluate the function

        Returns
        -------
        float or numpy.ndarray
            Interpolated value(s). A scalar argument outside the tabulated
            range is clamped to the first/last tabulated y value. For an
            iterable argument, points outside the tabulated range evaluate to
            zero, except points that lie beyond an endpoint by only a
            floating-point-sized amount, which get that endpoint's y value.

        """
        # Check if input is scalar
        if not isinstance(x, Iterable):
            return self._interpolate_scalar(x)

        x = np.asarray(x)
        # The output array inherits x's dtype; integer input would otherwise
        # silently truncate every interpolated value.
        if not np.issubdtype(x.dtype, np.inexact):
            x = x.astype(float)

        # Create output array
        y = np.zeros_like(x)

        # Get indices for interpolation (index of the low edge of each bin)
        idx = np.searchsorted(self.x, x, side='right') - 1

        # Loop over interpolation regions
        for k in range(len(self.breakpoints)):
            # Get indices for the beginning and ending of this region
            i_begin = self.breakpoints[k-1] - 1 if k > 0 else 0
            i_end = self.breakpoints[k] - 1

            # Figure out which idx values lie within this region
            contained = (idx >= i_begin) & (idx < i_end)

            xk = x[contained]                 # x values in this region
            xi = self.x[idx[contained]]       # low edge of corresponding bins
            xi1 = self.x[idx[contained] + 1]  # high edge of corresponding bins
            yi = self.y[idx[contained]]
            yi1 = self.y[idx[contained] + 1]

            if self.interpolation[k] == 1:
                # Histogram
                y[contained] = yi

            elif self.interpolation[k] == 2:
                # Linear-linear
                y[contained] = yi + (xk - xi)/(xi1 - xi)*(yi1 - yi)

            elif self.interpolation[k] == 3:
                # Linear-log
                y[contained] = yi + np.log(xk/xi)/np.log(xi1/xi)*(yi1 - yi)

            elif self.interpolation[k] == 4:
                # Log-linear
                y[contained] = yi*np.exp((xk - xi)/(xi1 - xi)*np.log(yi1/yi))

            elif self.interpolation[k] == 5:
                # Log-log
                y[contained] = (yi*np.exp(np.log(xk/xi)/np.log(xi1/xi)
                                *np.log(yi1/yi)))

        # In some cases, x values might be outside the tabulated region due
        # only to precision, so we check if they're close and set them equal
        # if so. Only points at or beyond an endpoint are snapped: isclose's
        # default relative tolerance would otherwise overwrite genuinely
        # interpolated values near the ends of the table.
        x_lo = self.x[0]
        x_hi = self.x[-1]
        y[(x <= x_lo) & np.isclose(x, x_lo, atol=1e-14)] = self.y[0]
        y[(x >= x_hi) & np.isclose(x, x_hi, atol=1e-14)] = self.y[-1]

        return y

    def _interpolate_scalar(self, x):
        """Interpolate at a single scalar point.

        Values at or beyond either end of the table are clamped to the first
        or last tabulated y value.

        NOTE(review): an interpolation code other than 1-5 falls through every
        branch, so this returns None for such regions.
        """
        if x <= self._x[0]:
            return self._y[0]
        elif x >= self._x[-1]:
            return self._y[-1]

        # Get the index for interpolation
        idx = np.searchsorted(self._x, x, side='right') - 1

        # Find the interpolation region containing bin idx; after the break,
        # p holds that region's interpolation code
        for b, p in zip(self.breakpoints, self.interpolation):
            if idx < b - 1:
                break

        xi = self._x[idx]       # low edge of the corresponding bin
        xi1 = self._x[idx + 1]  # high edge of the corresponding bin
        yi = self._y[idx]
        yi1 = self._y[idx + 1]

        if p == 1:
            # Histogram
            return yi

        elif p == 2:
            # Linear-linear
            return yi + (x - xi)/(xi1 - xi)*(yi1 - yi)

        elif p == 3:
            # Linear-log
            return yi + log(x/xi)/log(xi1/xi)*(yi1 - yi)

        elif p == 4:
            # Log-linear
            return yi*exp((x - xi)/(xi1 - xi)*log(yi1/yi))

        elif p == 5:
            # Log-log
            return yi*exp(log(x/xi)/log(xi1/xi)*log(yi1/yi))

    def __len__(self):
        return len(self.x)

    @property
    def x(self):
        return self._x

    @property
    def y(self):
        return self._y

    @property
    def breakpoints(self):
        return self._breakpoints

    @property
    def interpolation(self):
        return self._interpolation

    @property
    def n_pairs(self):
        return len(self.x)

    @property
    def n_regions(self):
        return len(self.breakpoints)

    @x.setter
    def x(self, x):
        self._x = x

    @y.setter
    def y(self, y):
        self._y = y

    @breakpoints.setter
    def breakpoints(self, breakpoints):
        self._breakpoints = breakpoints

    @interpolation.setter
    def interpolation(self, interpolation):
        self._interpolation = interpolation

    def integral(self):
        """Integral of the tabulated function over its tabulated range.

        Each bin is integrated exactly according to its interpolation law.
        Zero-width bins (repeated abscissae marking a discontinuity)
        contribute nothing. Bins in a region with an unrecognized
        interpolation code also contribute zero.

        Returns
        -------
        numpy.ndarray
            Array of same length as the tabulated data that represents partial
            integrals from the bottom of the range to each tabulated point.

        """
        # Integral over each bin [x_i, x_{i+1}]
        partial_sum = np.zeros(len(self.x) - 1)

        i_low = 0
        for k in range(len(self.breakpoints)):
            # Bins i_low, ..., i_high - 1 belong to this interpolation region
            i_high = self.breakpoints[k] - 1

            # Get x values and bounding (x,y) pairs
            x0 = self.x[i_low:i_high]
            x1 = self.x[i_low + 1:i_high + 1]
            y0 = self.y[i_low:i_high]
            y1 = self.y[i_low + 1:i_high + 1]

            area = self._bin_integrals(self.interpolation[k], x0, x1, y0, y1)
            if area is not None:
                partial_sum[i_low:i_high] = area

            i_low = i_high

        return np.concatenate(([0.], np.cumsum(partial_sum)))

    @staticmethod
    def _bin_integrals(interp, x0, x1, y0, y1):
        """Exact integral of each bin [x0, x1] under interpolation law `interp`.

        Returns None for an unrecognized interpolation code.
        """
        dx = x1 - x0
        # Zero-width bins and flat bins make the closed forms below evaluate
        # 0/0 or inf*0; silence those warnings and patch the affected bins.
        with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
            if interp == 1:
                # Histogram
                area = y0*dx

            elif interp == 2:
                # Linear-linear: the trapezoid rule is exact
                area = 0.5*(y0 + y1)*dx

            elif interp == 3:
                # Linear-log: y = y0 + m*ln(x/x0)
                logx = np.log(x1/x0)
                m = (y1 - y0)/logx
                area = y0*dx + m*(x1*(logx - 1) + x0)

            elif interp == 4:
                # Log-linear: y = y0*exp(m*(x - x0))
                m = np.log(y1/y0)/dx
                area = np.where(y1 == y0, y0*dx,
                                y0/m*(np.exp(m*dx) - 1))

            elif interp == 5:
                # Log-log: y = y0*(x/x0)**m
                m = np.log(y1/y0)/np.log(x1/x0)
                area = y0/((m + 1)*x0**m)*(x1**(m + 1) - x0**(m + 1))
                # m == -1 means y ~ 1/x, whose integral is logarithmic
                area = np.where(m == -1, y0*x0*np.log(x1/x0), area)
                area = np.where(y1 == y0, y0*dx, area)

            else:
                return None

        return np.where(dx == 0, 0.0, area)

    @classmethod
    def from_ace(cls, ace, idx=0, convert_units=True):
        """Create a Tabulated1D object from an ACE table.

        Parameters
        ----------
        ace : openmc.data.ace.Table
            An ACE table; only its ``xss`` array is read
        idx : int
            Offset to read from in XSS array (default of zero)
        convert_units : bool
            If the abscissa represents energy, indicate whether to convert MeV
            to eV.

        Returns
        -------
        Tabulated1D
            Tabulated data object (an instance of ``cls``, so subclasses
            calling this constructor get their own type back)

        """
        # Get number of regions and pairs
        n_regions = int(ace.xss[idx])
        n_pairs = int(ace.xss[idx + 1 + 2*n_regions])

        # Get interpolation information
        idx += 1
        if n_regions > 0:
            breakpoints = ace.xss[idx:idx + n_regions].astype(int)
            interpolation = ace.xss[idx + n_regions:idx + 2*n_regions].astype(int)
        else:
            # 0 regions implies linear-linear interpolation by default
            breakpoints = np.array([n_pairs])
            interpolation = np.array([2])

        # Get (x,y) pairs
        idx += 2*n_regions + 1
        x = ace.xss[idx:idx + n_pairs].copy()
        y = ace.xss[idx + n_pairs:idx + 2*n_pairs].copy()

        if convert_units:
            x *= EV_PER_MEV

        return cls(x, y, breakpoints, interpolation)
class Tabulated2D:
    """Interpolation metadata for a two-dimensional function.

    This placeholder only records the interpolation information of a
    two-dimensional function and performs no evaluation of its own. It is
    expected to be removed or extended once GNDS-like data containers are
    adopted.

    Parameters
    ----------
    breakpoints : Iterable of int
        Breakpoints for interpolation regions
    interpolation : Iterable of int
        Interpolation scheme identification number, e.g., 3 means y is linear
        in ln(x).

    Attributes
    ----------
    breakpoints : Iterable of int
        The breakpoints passed to the constructor, stored as given
    interpolation : Iterable of int
        The interpolation codes passed to the constructor, stored as given

    """

    def __init__(self, breakpoints, interpolation):
        # Both arguments are kept by reference; no copying or validation.
        self.breakpoints, self.interpolation = breakpoints, interpolation