
Efficient Rolling Trimmed Mean With Python

What's the most efficient way to calculate a rolling (aka moving window) trimmed mean with Python? For example, for a data set of 50K rows and a window size of 50, for each row I n

Solution 1:

One observation that could come in handy is that you do not need to sort all the values at each step. Rather, if you ensure that the window is always sorted, all you need to do is insert the new value at the relevant spot and remove the old one from where it was. Using bisect, finding both spots takes O(log_2(window_size)) comparisons; the insertion and deletion themselves still shift up to window_size list elements, which is a small cost for a window of 50. In practice, this would look something like

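import bisect

import numpy as np

def rolling_mean(data):
    # Sorted copy of the first 49 values; the loop adds the 50th.
    x = sorted(data[:49])
    res = np.repeat(np.nan, len(data))
    for i in range(49, len(data)):
        if i != 49:
            # Remove the value that just left the 50-element window.
            del x[bisect.bisect_left(x, data[i - 50])]
        # Insert the newest value, keeping x sorted.
        bisect.insort_right(x, data[i])
        # Drop the 3 smallest and 3 largest values and average the other 44.
        res[i] = np.mean(x[3:47])
    return res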
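For experimenting with other window sizes, here is a minimal pure-Python sketch of the same sorted-window idea, with the window length and the number of values trimmed from each end as parameters instead of the hard-coded 50 and 3. The name `rolling_trimmed_mean_bisect` and its `window`/`cut` parameters are illustrative, not part of the original answer; it returns a plain list with NaN where the window is not yet full.

```python
import bisect
import math


def rolling_trimmed_mean_bisect(data, window=50, cut=3):
    """Trimmed mean of each length-`window` window, dropping `cut` values at each end.

    Positions before the first full window are NaN, as with pandas' rolling().
    """
    n = len(data)
    res = [math.nan] * n
    if n < window:
        return res
    # Sorted copy of the first window-1 values; the loop adds the last one.
    x = sorted(data[:window - 1])
    for i in range(window - 1, n):
        if i >= window:
            # Drop the value that just left the window (leftmost equal element).
            del x[bisect.bisect_left(x, data[i - window])]
        # Insert the new value so that x stays sorted.
        bisect.insort_right(x, data[i])
        kept = x[cut:window - cut]
        res[i] = sum(kept) / len(kept)
    return res


print(rolling_trimmed_mean_bisect([5, 1, 9, 3, 7, 2, 8], window=5, cut=1))
# [nan, nan, nan, nan, 5.0, 4.0, 6.0]
```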

Now, the benefit in this case turns out to be smaller than what is gained by the vectorization that scipy.stats.trim_mean relies on, so this will still be slower than @ChrisA's solution; it is, however, a useful starting point for further performance optimization.

> import numpy as np
> import pandas as pd
> from scipy.stats import trim_mean
> data = pd.Series(np.random.randint(0, 1000, 50000))
> %timeit data.rolling(50).apply(lambda w: trim_mean(w, 0.06))
727 ms ± 34.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> %timeit rolling_mean(data.values)
812 ms ± 42.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

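Notably, Numba's JIT compiler, which is often useful in situations like these, also provides no benefit here, most likely because it cannot compile the `bisect` calls in nopython mode and falls back to the slower object mode: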

> from numba import jit
> rolling_mean_jit = jit(rolling_mean)
> %timeit rolling_mean_jit(data.values)
1.05 s ± 183 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The following, seemingly far-from-optimal, approach outperforms both of the other approaches considered above:

def rolling_mean_np(data):
    res = np.repeat(np.nan, len(data))
    for i in range(len(data)-49):
        x = np.sort(data[i:i+50])
        res[i+49] = x[3:47].mean()
    return res

Timing:

> %timeit rolling_mean_np(data.values)
564 ms ± 4.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

What is more, this time around, JIT compilation does help:

> rolling_mean_np_jit = jit(rolling_mean_np)
> %timeit rolling_mean_np_jit(data.values)
94.9 ms ± 605 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

While we're at it, let's just quickly verify that this actually does what we expect it to:

> np.all(rolling_mean_np_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True
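The per-window sort can also be done for all windows at once in NumPy, with no Python-level loop, by sorting a sliding-window view row by row. This is a swapped-in variant that was not timed in this answer; it assumes NumPy 1.20 or later (for `numpy.lib.stride_tricks.sliding_window_view`), and the function name is hypothetical. The trade-off is memory: the sorted copy holds `len(data) * window` values, roughly 20 MB of float64 for 50K rows and a window of 50.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_trimmed_mean_vectorized(data, window=50, cut=3):
    """Sort every window at once and average the middle values.

    Positions before the first full window are NaN.
    """
    data = np.asarray(data, dtype=float)
    res = np.full(len(data), np.nan)
    if len(data) < window:
        return res
    # Read-only (n - window + 1, window) view; no data is copied here.
    windows = sliding_window_view(data, window)
    # np.sort returns a sorted copy of each row; then drop `cut` values per end.
    trimmed = np.sort(windows, axis=1)[:, cut:window - cut]
    res[window - 1:] = trimmed.mean(axis=1)
    return res
```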

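In fact, by helping out the sorter a little, we can gain almost another factor of 2, taking the total time down to 57 ms: instead of sorting a fresh copy of each window, we overwrite the outgoing value with the incoming one, so the array handed to the sorter is already sorted except for a single element: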

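def rolling_mean_np_manual(data):
    x = np.sort(data[:50])
    res = np.repeat(np.nan, len(data))
    for i in range(50, len(data)+1):
        # x holds the sorted window data[i-50:i].
        res[i-1] = x[3:47].mean()
        if i != len(data):
            # Overwrite the outgoing value with the incoming one, then re-sort;
            # x is now sorted except for that one element.
            idx_old = np.searchsorted(x, data[i-50])
            x[idx_old] = data[i]
            x.sort()
    return res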

> %timeit rolling_mean_np_manual(data.values)
580 ms ± 23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> rolling_mean_np_manual_jit = jit(rolling_mean_np_manual)
> %timeit rolling_mean_np_manual_jit(data.values)
57 ms ± 5.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> np.all(rolling_mean_np_manual_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True

Now, the "sorting" that is going on in this example of course just boils down to placing the new element in the right place, while shifting everything in between by one. Doing this by hand will make the pure Python code slower, but the jitted version gains another factor of 2, taking us below 30 ms:

def rolling_mean_np_shift(data):
    x = np.sort(data[:50])
    res = np.repeat(np.nan, len(data))
    for i in range(50, len(data)+1):
        res[i-1] = x[3:47].mean()
        if i != len(data):
            idx_old, idx_new = np.searchsorted(x, [data[i-50], data[i]])
            if idx_old < idx_new:
                x[idx_old:idx_new-1] = x[idx_old+1:idx_new]
                x[idx_new-1] = data[i]
            elif idx_new < idx_old:
                x[idx_new+1:idx_old+1] = x[idx_new:idx_old]
                x[idx_new] = data[i]
            else:
                x[idx_new] = data[i]
    return res

> %timeit rolling_mean_np_shift(data.values)
937 ms ± 97.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> rolling_mean_np_shift_jit = jit(rolling_mean_np_shift)
> %timeit rolling_mean_np_shift_jit(data.values)
26.4 ms ± 693 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
> np.all(rolling_mean_np_shift_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True
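The shifting step can also be written as an explicit loop that moves the gap left by the outgoing value one slot at a time until the incoming value fits, which is essentially one pass of insertion sort. `replace_sorted` is a hypothetical helper, not from the answer; it modifies `x` in place and assumes `old` is present in `x`, which holds when `x` is the sorted window.

```python
import numpy as np


def replace_sorted(x, old, new):
    """In place: replace one occurrence of `old` in sorted array `x` by `new`, keeping x sorted.

    Assumes `old` is present in `x`. Only the elements between the old and
    new positions move, each by one slot.
    """
    i = int(np.searchsorted(x, old))  # leftmost index holding `old`: the gap
    # Move the gap right while the next element is still smaller than `new` ...
    while i + 1 < len(x) and x[i + 1] < new:
        x[i] = x[i + 1]
        i += 1
    # ... or left while the previous element is larger than `new`.
    while i > 0 and x[i - 1] > new:
        x[i] = x[i - 1]
        i -= 1
    x[i] = new
    return x
```

At most one of the two loops does any work, and each iteration touches one element, so a step costs O(distance moved) rather than a full sort.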

At this point, most of the time is spent in np.searchsorted, so let us make the search itself JIT-friendly. Adapting the source code of bisect.bisect_left, we define

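from numba import jit
import numpy as np

@jit
def binary_search(a, x):
    # Leftmost insertion point for x in sorted a (same result as bisect.bisect_left).
    lo = 0
    hi = len(a)
    while lo < hi:
        mid = (lo+hi)//2
        if a[mid] < x:
            lo = mid+1
        else:
            hi = mid
    return lo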

@jit
def rolling_mean_np_jitted_search(data):
    x = np.sort(data[:50])
    res = np.repeat(np.nan, len(data))
    for i in range(50, len(data)+1):
        res[i-1] = x[3:47].mean()
        if i != len(data):
            idx_old = binary_search(x, data[i-50])
            idx_new = binary_search(x, data[i])
            if idx_old < idx_new:
                x[idx_old:idx_new-1] = x[idx_old+1:idx_new]
                x[idx_new-1] = data[i]
            elif idx_new < idx_old:
                x[idx_new+1:idx_old+1] = x[idx_new:idx_old]
                x[idx_new] = data[i]
            else:
                x[idx_new] = data[i]
    return res

This takes us down to 12 ms, roughly a 60x improvement over the plain pandas+SciPy approach (727 ms):

> %timeit rolling_mean_np_jitted_search(data.values)
12 ms ± 210 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
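Since the hand-written search replaces `np.searchsorted`, it is worth checking that it returns the same index as `bisect.bisect_left`, including with duplicate values. This matters for the removal step: when the outgoing value is present, the leftmost insertion point is an index that actually holds that value, so the right element gets overwritten. The sketch below is the same function without `@jit`, so it runs in plain Python.

```python
import bisect


def binary_search(a, x):
    """Leftmost insertion point for x in sorted a (same contract as bisect.bisect_left)."""
    lo, hi = 0, len(a)
    while lo < hi:
        mid = (lo + hi) // 2
        if a[mid] < x:
            lo = mid + 1
        else:
            hi = mid
    return lo


a = [1, 2, 2, 2, 5, 7, 7, 9]
# Same index as the standard library for present and absent values.
print(all(binary_search(a, p) == bisect.bisect_left(a, p) for p in range(-1, 11)))
# True
```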

Solution 2:

You might try using scipy.stats.trim_mean:

from scipy.stats import trim_mean

df['value'].rolling(5).apply(lambda x: trim_mean(x, 0.2))

[output]

0          NaN
1          NaN
2          NaN
3          NaN
4    10.000000
5    11.000000
6    13.000000
7    13.333333
8    14.000000
9    15.666667

Note that I used rolling(5) and proportiontocut=0.2 to fit your toy data set; this cuts 1 value from each end of every 5-value window.

For your real data you should use rolling(50) and trim_mean(x, 0.06) to remove the top and bottom 3 values from the rolling window.
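Where the 0.2 and 0.06 come from: `scipy.stats.trim_mean` cuts `int(proportiontocut * n)` values from each end of the sorted data, so 0.2 on a 5-value window cuts 1 and 0.06 on a 50-value window cuts 3. The pure-Python `trimmed_mean` below is a hypothetical helper that mirrors that rule for illustration; SciPy itself uses a partition rather than a full sort, but keeps the same values.

```python
def trimmed_mean(values, proportiontocut):
    """Mean after dropping int(proportiontocut * n) values from each end."""
    n = len(values)
    cut = int(proportiontocut * n)  # number of values dropped at each end
    kept = sorted(values)[cut:n - cut]
    return sum(kept) / len(kept)


print(int(0.2 * 5), int(0.06 * 50))       # 1 3
print(trimmed_mean([5, 1, 9, 3, 7], 0.2))  # 5.0
```

Because `int()` truncates, a proportion whose product with the window size is not a whole number cuts fewer values than it might suggest.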

Solution 3:

I bet slicing and sorting with every move of the window is the slow part. Instead of slicing every time, keep a separate list of 50 (or 5) values. Sort it once at the start; then, when moving the window, remove the outgoing value and add the new value in the correct place so as to preserve the sort order (much like in the insertion sort algorithm). Then calculate the trimmed mean from the middle of that list. You will need a way to keep track of where the window is in relation to the entire data set; I think a single int variable will suffice.
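A minimal sketch of this suggestion, assuming a plain Python list as the sorted buffer: sort once, then for each step remove the outgoing value and insert the incoming one by scanning from the right, as insertion sort does. The function name and its defaults (`window=5`, `cut=1`, matching the toy example; use `window=50, cut=3` for the real data) are illustrative. Each step costs O(window) because of the scan and the list shifts.

```python
import math


def rolling_trimmed_mean_insertion(data, window=5, cut=1):
    """Trimmed mean of every full window, keeping one sorted buffer up to date."""
    res = [math.nan] * len(data)
    if len(data) < window:
        return res
    buf = sorted(data[:window])  # sort once at the start
    # `start` is the single int tracking where the window begins in `data`.
    for start in range(len(data) - window + 1):
        end = start + window  # the window is data[start:end]
        if start > 0:
            # Remove the value that just left the window ...
            buf.remove(data[start - 1])
            # ... and insert the newcomer by scanning from the right, as in insertion sort.
            new = data[end - 1]
            j = len(buf)
            while j > 0 and buf[j - 1] > new:
                j -= 1
            buf.insert(j, new)
        kept = buf[cut:window - cut]
        res[end - 1] = sum(kept) / len(kept)
    return res


print(rolling_trimmed_mean_insertion([5, 1, 9, 3, 7, 2, 8]))
# [nan, nan, nan, nan, 5.0, 4.0, 6.0]
```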
