# SPDX-FileCopyrightText: 2023-2025 OpenMC contributors and Paul Romano
# SPDX-License-Identifier: MIT
from collections.abc import Iterable
from math import exp, log
import numpy as np
from .data import EV_PER_MEV
class Tabulated1D:
    """A one-dimensional tabulated function.

    This class mirrors the TAB1 type from the ENDF-6 format. A tabulated
    function is specified by tabulated (x,y) pairs along with interpolation
    rules that determine the values between tabulated pairs.

    Once an object has been created, it can be used as though it were an actual
    function, e.g.:

    >>> f = Tabulated1D([0, 10], [4, 5])
    >>> [f(xi) for xi in numpy.linspace(0, 10, 5)]
    [4.0, 4.25, 4.5, 4.75, 5.0]

    Parameters
    ----------
    x : Iterable of float
        Independent variable
    y : Iterable of float
        Dependent variable
    breakpoints : Iterable of int
        Breakpoints for interpolation regions
    interpolation : Iterable of int
        Interpolation scheme identification number, e.g., 3 means y is linear in
        ln(x).

    Attributes
    ----------
    x : Iterable of float
        Independent variable
    y : Iterable of float
        Dependent variable
    breakpoints : Iterable of int
        Breakpoints for interpolation regions
    interpolation : Iterable of int
        Interpolation scheme identification number, e.g., 3 means y is linear in
        ln(x).
    n_regions : int
        Number of interpolation regions
    n_pairs : int
        Number of tabulated (x,y) pairs

    """

    def __init__(self, x, y, breakpoints=None, interpolation=None):
        self.x = np.asarray(x)
        self.y = np.asarray(y)
        if breakpoints is None or interpolation is None:
            # Single linear-linear interpolation region by default
            self.breakpoints = np.array([len(self.x)])
            self.interpolation = np.array([2])
        else:
            self.breakpoints = np.asarray(breakpoints, dtype=int)
            self.interpolation = np.asarray(interpolation, dtype=int)

    def __repr__(self):
        return f"<Tabulated1D: {self.x.size} points, {self.breakpoints.size} regions>"

    def __call__(self, x):
        """Evaluate the tabulated function.

        Parameters
        ----------
        x : float or Iterable of float
            Point(s) at which to evaluate the function

        Returns
        -------
        float or numpy.ndarray
            Interpolated value(s). For a scalar argument, values outside the
            tabulated range are clamped to the first/last ordinate. For an
            iterable argument, values outside the tabulated range evaluate
            to zero unless they are within floating-point tolerance of an
            end point, in which case the end-point ordinate is used.

        """
        # Check if input is scalar
        if not isinstance(x, Iterable):
            return self._interpolate_scalar(x)

        # Force a floating-point array: with integer input, an output array
        # created via zeros_like would silently truncate interpolated values
        x = np.asarray(x, dtype=float)
        y = np.zeros_like(x)

        # Index of the tabulated bin (low edge) that each x value falls in
        idx = np.searchsorted(self.x, x, side='right') - 1

        # Loop over interpolation regions
        i_begin = 0
        for bp, scheme in zip(self.breakpoints, self.interpolation):
            i_end = bp - 1

            # Figure out which idx values lie within this region
            contained = (idx >= i_begin) & (idx < i_end)
            low = idx[contained]

            values = self._interpolate_bin(
                scheme, x[contained],
                self.x[low], self.x[low + 1],
                self.y[low], self.y[low + 1],
                np.log, np.exp)
            # Unrecognized schemes leave the zero-initialized output untouched
            if values is not None:
                y[contained] = values

            i_begin = i_end

        # In some cases, x values might be outside the tabulated region due only
        # to precision, so we check if they're close and set them equal if so.
        # The comparison is restricted to points at or beyond the end points so
        # that interior points keep their interpolated value.
        at_low = np.isclose(x, self.x[0], atol=1e-14) & (x <= self.x[0])
        at_high = np.isclose(x, self.x[-1], atol=1e-14) & (x >= self.x[-1])
        y[at_low] = self.y[0]
        y[at_high] = self.y[-1]

        return y

    @staticmethod
    def _interpolate_bin(scheme, x, x0, x1, y0, y1, log_fn, exp_fn):
        """Interpolate within bins bounded by (x0, y0) and (x1, y1).

        ``log_fn`` and ``exp_fn`` supply the logarithm and exponential so the
        same formulas serve both the scalar (``math``) and array (``numpy``)
        code paths. Returns None if ``scheme`` is not one of 1-5.
        """
        if scheme == 1:
            # Histogram
            return y0
        if scheme == 2:
            # Linear-linear
            return y0 + (x - x0)/(x1 - x0)*(y1 - y0)
        if scheme == 3:
            # Linear-log
            return y0 + log_fn(x/x0)/log_fn(x1/x0)*(y1 - y0)
        if scheme == 4:
            # Log-linear
            return y0*exp_fn((x - x0)/(x1 - x0)*log_fn(y1/y0))
        if scheme == 5:
            # Log-log
            return y0*exp_fn(log_fn(x/x0)/log_fn(x1/x0)*log_fn(y1/y0))
        return None

    def _interpolate_scalar(self, x):
        """Evaluate the function at a single point.

        Values at or beyond the tabulated range are clamped to the first/last
        ordinate. Returns None if the applicable interpolation scheme is not
        one of 1-5.
        """
        if x <= self._x[0]:
            return self._y[0]
        elif x >= self._x[-1]:
            return self._y[-1]

        # Get the index for interpolation
        idx = np.searchsorted(self._x, x, side='right') - 1

        # Find the interpolation region containing this bin. If no breakpoint
        # lies above idx, 'scheme' is left at the last region's value.
        for bp, scheme in zip(self.breakpoints, self.interpolation):
            if idx < bp - 1:
                break

        return self._interpolate_bin(
            scheme, x,
            self._x[idx], self._x[idx + 1],
            self._y[idx], self._y[idx + 1],
            log, exp)

    def __len__(self):
        return len(self.x)

    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, x):
        self._x = x

    @property
    def y(self):
        return self._y

    @y.setter
    def y(self, y):
        self._y = y

    @property
    def breakpoints(self):
        return self._breakpoints

    @breakpoints.setter
    def breakpoints(self, breakpoints):
        self._breakpoints = breakpoints

    @property
    def interpolation(self):
        return self._interpolation

    @interpolation.setter
    def interpolation(self, interpolation):
        self._interpolation = interpolation

    @property
    def n_pairs(self):
        return len(self.x)

    @property
    def n_regions(self):
        return len(self.breakpoints)

    def integral(self):
        """Integral of the tabulated function over its tabulated range.

        Returns
        -------
        numpy.ndarray
            Array of same length as the tabulated data that represents partial
            integrals from the bottom of the range to each tabulated point.

        """
        # One entry per bin between consecutive tabulated points
        partial_sum = np.zeros(len(self.x) - 1)

        i_low = 0
        for bp, scheme in zip(self.breakpoints, self.interpolation):
            # Determine which x values are within this interpolation range
            i_high = bp - 1

            # Get x values and bounding (x,y) pairs
            x0 = self.x[i_low:i_high]
            x1 = self.x[i_low + 1:i_high + 1]
            y0 = self.y[i_low:i_high]
            y1 = self.y[i_low + 1:i_high + 1]
            dx = x1 - x0

            if scheme == 1:
                # Histogram
                partial_sum[i_low:i_high] = y0*dx

            elif scheme == 2:
                # Linear-linear: trapezoid rule, which is exact here and stays
                # finite for zero-width bins (duplicate x at a discontinuity)
                partial_sum[i_low:i_high] = 0.5*(y0 + y1)*dx

            elif scheme == 3:
                # Linear-log: y = y0 + m*ln(x/x0), so the integral over the bin
                # is y0*(x1 - x0) + m*(x1*(ln(x1/x0) - 1) + x0)
                logx = np.log(x1/x0)
                m = (y1 - y0)/logx
                partial_sum[i_low:i_high] = y0*dx + m*(x1*(logx - 1) + x0)

            elif scheme == 4:
                # Log-linear: y = y0*exp(m*(x - x0)). When y1 == y0 the slope
                # is zero and the closed form degenerates to y0*(x1 - x0).
                m = np.log(y1/y0)/dx
                flat = (m == 0)
                safe_m = np.where(flat, 1.0, m)
                partial_sum[i_low:i_high] = np.where(
                    flat, y0*dx, y0/safe_m*(np.exp(m*dx) - 1))

            elif scheme == 5:
                # Log-log: y = y0*(x/x0)**m. When m == -1 the antiderivative
                # is logarithmic: y0*x0*ln(x1/x0).
                logx = np.log(x1/x0)
                m = np.log(y1/y0)/logx
                mp1 = m + 1
                reciprocal = (mp1 == 0)
                safe_mp1 = np.where(reciprocal, 1.0, mp1)
                partial_sum[i_low:i_high] = np.where(
                    reciprocal,
                    y0*x0*logx,
                    y0/(safe_mp1*x0**m)*(x1**mp1 - x0**mp1))

            i_low = i_high

        return np.concatenate(([0.], np.cumsum(partial_sum)))

    @classmethod
    def from_ace(cls, ace, idx=0, convert_units=True):
        """Create a Tabulated1D object from an ACE table.

        Parameters
        ----------
        ace : openmc.data.ace.Table
            An ACE table
        idx : int
            Offset to read from in XSS array (default of zero)
        convert_units : bool
            If the abscissa represents energy, indicate whether to convert MeV
            to eV.

        Returns
        -------
        openmc.data.Tabulated1D
            Tabulated data object

        """
        # Get number of regions and pairs
        n_regions = int(ace.xss[idx])
        n_pairs = int(ace.xss[idx + 1 + 2*n_regions])

        # Get interpolation information
        idx += 1
        if n_regions > 0:
            breakpoints = ace.xss[idx:idx + n_regions].astype(int)
            interpolation = ace.xss[idx + n_regions:idx + 2*n_regions].astype(int)
        else:
            # 0 regions implies linear-linear interpolation by default
            breakpoints = np.array([n_pairs])
            interpolation = np.array([2])

        # Get (x,y) pairs; copy so that unit conversion does not modify the
        # ACE table's XSS array in place
        idx += 2*n_regions + 1
        x = ace.xss[idx:idx + n_pairs].copy()
        y = ace.xss[idx + n_pairs:idx + 2*n_pairs].copy()

        if convert_units:
            x *= EV_PER_MEV

        # Use cls so that subclasses get an instance of their own type
        return cls(x, y, breakpoints, interpolation)
class Tabulated2D:
    """Interpolation metadata for a two-dimensional function.

    This placeholder class only stores the interpolation information for a
    two-dimensional function; it performs no evaluation. Once we refactor to
    adopt GNDS-like data containers, this will probably be removed or
    extended.

    Parameters
    ----------
    breakpoints : Iterable of int
        Breakpoints for interpolation regions
    interpolation : Iterable of int
        Interpolation scheme identification number, e.g., 3 means y is linear in
        ln(x).

    """

    def __init__(self, breakpoints, interpolation):
        # Stored exactly as given; no conversion or validation is performed
        self.breakpoints, self.interpolation = breakpoints, interpolation