use core::{
cmp::PartialOrd,
ops::{BitAnd, Shr},
};
use num_traits::{
cast::AsPrimitive,
identities::Zero,
ops::wrapping::{WrappingAdd, WrappingSub},
Signed,
};
use serde::{Deserialize, Serialize};
/// Wrapping subtraction `y - x` that also reports the direction of any wrap.
///
/// Returns `(delta, wrap)` where `delta = y.wrapping_sub(&x)` and
/// `wrap = i32::from(delta >= 0) - i32::from(y >= x)`. For two's-complement
/// integer `T` this means:
/// * `0` when the exact difference fits in `T`,
/// * `-1` when the exact difference exceeded the upper bound of `T`
///   (a positive overflow that wrapped into a negative `delta`),
/// * `1` when the exact difference fell below the lower bound of `T`
///   (for unsigned `T` that is every case with `y < x`).
#[inline(always)]
pub fn overflowing_sub<T>(y: T, x: T) -> (T, i32)
where
    T: WrappingSub + Zero + PartialOrd,
{
    let difference = y.wrapping_sub(&x);
    // Without a wrap, the sign of the result agrees with the ordering of the
    // operands; any disagreement between the two flags is the wrap direction.
    let lands_non_negative = i32::from(difference >= T::zero());
    let was_ordered = i32::from(y >= x);
    (difference, lands_non_negative - was_ordered)
}
/// Combine a high and a low `i32` word into one `i32` scaled down by `2^shift`,
/// saturating when `hi` is out of range.
///
/// While `-2^(shift - 1) < hi < 2^(shift - 1)` the result is
/// `(lo >> shift) + (hi << (32 - shift))`. That is the value `hi * 2^32 + lo`
/// (with `lo` taken as *signed*) divided by `2^shift` and rounded toward
/// negative infinity. Otherwise the result saturates:
/// * `hi <= -2^(shift - 1)` returns `i32::MIN + 2^(shift - 1)`
///   (e.g. `-0x7fff_ff80` for `shift == 8`);
/// * `hi >= 2^(shift - 1)` returns the negation of that value.
///
/// `shift` must be in `1..=32`; this is only checked by `debug_assert!`.
///
/// NOTE(review): the saturation values lie at or beyond the range of
/// non-saturated results only for `shift <= 16`. For larger shifts a saturated
/// result can be closer to zero than a non-saturated one (e.g. `shift == 32`
/// with `hi == i32::MIN` returns `0`). Confirm the intended saturation values
/// for large shifts.
pub fn saturating_scale(lo: i32, hi: i32, shift: u32) -> i32 {
    debug_assert!(shift > 0);
    debug_assert!(shift <= 32);
    // `-2^(shift - 1)`: any `hi` at or beyond `±2^(shift - 1)` saturates.
    let hi_range = -1 << (shift - 1);
    if hi <= hi_range {
        i32::MIN - hi_range
    } else if -hi <= hi_range {
        // `-hi` cannot overflow: `hi == i32::MIN` always takes the first branch.
        hi_range - i32::MIN
    } else {
        // Shift `lo` in two steps so that `shift == 32` is a valid arithmetic
        // shift: `lo >> 32` overflows (panic in debug, `lo >> 0` in release).
        ((lo >> (shift - 1)) >> 1) + (hi << (32 - shift))
    }
}
/// Phase unwrapper: accumulates the wrapping differences of a phase of some
/// integer type `P` into an accumulator of type `Q` (presumably wider than
/// `P`), see [`Unwrapper::update`].
#[derive(Copy, Clone, Debug, Default, Deserialize, Serialize)]
pub struct Unwrapper<Q> {
    /// Unwrapped accumulator: the running (wrapping) sum of all phase steps.
    y: Q,
}
impl<Q> Unwrapper<Q>
where
    Q: 'static + WrappingAdd + Copy,
{
    /// Feed the next wrapping phase sample `x` and return the phase step.
    ///
    /// The step is `x` minus the current phase (the accumulator converted to
    /// `P`), computed with wrapping subtraction. For primitive integers it is
    /// therefore the difference taken modulo the range of `P`. The step is
    /// then converted to `Q` and added (wrapping) to the accumulator, which
    /// lets the accumulated phase extend beyond the range of `P`.
    ///
    /// Returns the step that was added.
    pub fn update<P>(&mut self, x: P) -> P
    where
        P: 'static + WrappingSub + Copy + AsPrimitive<Q>,
        Q: AsPrimitive<P>,
    {
        let dx = x.wrapping_sub(&self.y.as_());
        self.y = self.y.wrapping_add(&dx.as_());
        dx
    }

    /// Number of whole wraps in the accumulator, rounded half up.
    ///
    /// Computes `(y >> S) + ((y >> (S - 1)) & 1)` in `P`. That is the
    /// accumulator scaled down by `2^S`, plus the bit just below that binary
    /// point. For primitive integer `Q` (where `>>` floors), this rounds
    /// half-way values toward positive infinity. For example, with `S == 32`
    /// an `i64` accumulator of `i32` phase steps yields the rounded count of
    /// full `i32` turns. Conversions to `P` use `AsPrimitive::as_`, which
    /// truncates like `as` for primitives.
    ///
    /// `S` must be at least 1. For primitive `Q` it must also be less than
    /// the bit width of `Q`; otherwise the shifts overflow (a panic in debug
    /// builds).
    pub fn wraps<P, const S: u32>(&self) -> P
    where
        Q: AsPrimitive<P> + Shr<u32, Output = Q>,
        // `BitAnd<Output = P>` (masking with `P::one()` below) instead of
        // `BitAnd<u32>`: no signed primitive implements `BitAnd<u32>`, so that
        // bound made this method uncallable for e.g. `P = i32`.
        P: 'static + Copy + WrappingAdd + Signed + BitAnd<Output = P>,
    {
        let whole: P = (self.y >> S).as_();
        let half: P = (self.y >> (S - 1)).as_();
        // `P::one()` is provided through the `Signed: Num: One` supertraits.
        whole.wrapping_add(&(half & P::one()))
    }

    /// Current phase: the accumulator converted to `P` with
    /// `AsPrimitive::as_` (an `as` cast for primitives, i.e. truncating).
    pub fn phase<P>(&self) -> P
    where
        P: 'static + Copy,
        Q: AsPrimitive<P>,
    {
        self.y.as_()
    }

    /// Current value of the unwrapped accumulator.
    pub fn y(&self) -> Q {
        self.y
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn overflowing_sub_correctness() {
        // Each case is `(x0, x1, v)`: `overflowing_sub(x1, x0)` must report wrap `v`.
        for &(x0, x1, v) in [
            (0i32, 0i32, 0i32),
            (0, 1, 0),
            (0, -1, 0),
            (1, 0, 0),
            (-1, 0, 0),
            (0, 0x7fff_ffff, 0),
            (-1, 0x7fff_ffff, -1),
            (-2, 0x7fff_ffff, -1),
            (-1, -0x8000_0000, 0),
            (0, -0x8000_0000, 0),
            (1, -0x8000_0000, 1),
            (-0x6000_0000, 0x6000_0000, -1),
            (0x6000_0000, -0x6000_0000, 1),
            (-0x4000_0000, 0x3fff_ffff, 0),
            (-0x4000_0000, 0x4000_0000, -1),
            (-0x4000_0000, 0x4000_0001, -1),
            (0x4000_0000, -0x3fff_ffff, 0),
            (0x4000_0000, -0x4000_0000, 0),
            (0x4000_0000, -0x4000_0001, 1),
        ]
        .iter()
        {
            let (dx, w) = overflowing_sub(x1, x0);
            // Print the arguments in call order (`x1`, `x0`).
            assert_eq!(
                v, w,
                "{} != {} = overflowing_sub({:#x}, {:#x})",
                v, w, x1, x0
            );
            // Cross-check the difference and overflow flag against std.
            let (dx0, w0) = x1.overflowing_sub(x0);
            assert_eq!(w0, w != 0);
            assert_eq!(dx, dx0);
        }
    }

    #[test]
    fn saturating_scale_correctness() {
        let shift = 8;
        // Each case is `(lo, hi, res)` with `res == saturating_scale(lo, hi, shift)`.
        for &(lo, hi, res) in [
            (0i32, 0i32, 0i32),
            (0, 1, 0x0100_0000),
            (0, -1, -0x0100_0000),
            (0x100, 0, 1),
            (-1 << 31, 0, -1 << 23),
            (0x7fffffff, 0, 0x007f_ffff),
            (0x7fffffff, 1, 0x017f_ffff),
            (-0x7fffffff, -1, -0x0180_0000),
            (0x1234_5600, 0x7f, 0x7f12_3456),
            (0x1234_5600, -0x7f, -0x7f00_0000 + 0x12_3456),
            (0, 0x7f, 0x7f00_0000),
            (0, 0x80, 0x7fff_ff80),
            (0, -0x7f, -0x7f00_0000),
            (0, -0x80, -0x7fff_ff80),
            (0x7fff_ffff, 0x7f, 0x7f7f_ffff),
            (-0x8000_0000, 0x7f, 0x7e80_0000),
            (-0x8000_0000, -0x7f, -0x7f80_0000),
            (0x7fff_ffff, -0x7f, -0x7e80_0001),
            (0x100, 0x7f, 0x7f00_0001),
            (-1 << 31, 0x80, 0x7fff_ff80),
            (-1 << 31, -0x80, -0x7fff_ff80),
        ]
        .iter()
        {
            let s = saturating_scale(lo, hi, shift);
            assert_eq!(
                res, s,
                "{:#x} != {:#x} = saturating_scale({:#x}, {:#x}, {:#x})",
                res, s, lo, hi, shift
            );
        }
    }

    #[test]
    fn unwrapper_update() {
        let mut u = Unwrapper::<i64>::default();
        assert_eq!(u.update(i32::MAX), i32::MAX);
        // Stepping from `i32::MAX` to `i32::MIN` is a +1 step, not a -2^32 + 1 jump.
        assert_eq!(u.update(i32::MIN), 1);
        assert_eq!(u.y(), 1 << 31);
        assert_eq!(u.phase::<i32>(), i32::MIN);
    }
}