Sorting and Counting Large Amounts of Data

Recently I needed to sort and group-aggregate a multi-column dataset with several hundred million rows. Processing it directly with pandas was already barely workable, so below I build a test dataset and compare how different approaches perform.

Test Environment

!echo "system: $(uname -v)"
!echo "sort: $(sort --version)"
import sys; print(f'python: {sys.version}')
  system: Darwin Kernel Version 21.6.0: Mon Aug 22 20:19:52 PDT 2022; root:xnu-8020.140.49~2/RELEASE_ARM64_T6000
  sort: 2.3-Apple (138.100.3)
  python: 3.10.6 (main, Sep 20 2022, 00:29:57) [Clang 14.0.0 (clang-1400.0.29.102)]

Building a Large Test Dataset

import os
import random

test_dir = os.path.join('/tmp', 'sort-test')
os.makedirs(test_dir, exist_ok=True)
test_file = os.path.join(test_dir, 'data.csv')
size = 1 * 10000 * 10000  # 100 million rows

if not os.path.exists(test_file):
    random.seed(1)  # fix the seed so the generated file is reproducible
    with open(test_file, 'w') as f:
        for i in range(size):
            value = random.randint(0, size)
            f.write(f'{i},{value}\n')

size_gb = os.stat(test_file).st_size / 1024 / 1024 / 1024
print(f'test_file={test_file} size_gb={size_gb:.3f}')
  test_file=/tmp/sort-test/data.csv size_gb=1.656

Sorting and Counting with the Shell

sort, awk, and uniq make it very convenient to sort and count text data on the command line. The sort flags used here are -t "," to set the field delimiter, -n for numeric comparison, -r for descending order, and -k2 to sort on the second field:

/usr/bin/time -l bash -c 'sort -t "," -nrk2 /tmp/sort-test/data.csv | head -n 5'
65899819,100000000
30387596,99999998
15566396,99999998
46605454,99999997
8698303,99999996
      152.10 real       149.61 user         2.20 sys
         19275333632  maximum resident set size
                   0  average shared memory size
                   0  average unshared data size
                   0  average unshared stack size
             1197551  page reclaims
                   0  page faults
                   0  swaps
                   0  block input operations
                   0  block output operations
                   0  messages sent
                   0  messages received
                   0  signals received
                   3  voluntary context switches
                7368  involuntary context switches
             8449127  instructions retired
             3757431  cycles elapsed
             1393344  peak memory footprint

This run peaks at roughly 18 GB of resident memory, and 200 million rows would need even more. Next, grouping the 100 million rows by value and counting gives the following:

/usr/bin/time -l bash -c 'sort -t "," -nrk2 /tmp/sort-test/data.csv | awk -F "," "{print \$2}" | uniq -c | sort -r | head -n 5'
  10 967956
  10 95587138
  10 92848651
  10 89216079
  10 87144205
      281.88 real       370.44 user         5.97 sys
         19275513856  maximum resident set size
                   0  average shared memory size
                   0  average unshared data size
                   0  average unshared stack size
             1898942  page reclaims
                   0  page faults
                   0  swaps
                   0  block input operations
                   0  block output operations
                   0  messages sent
                   0  messages received
                   0  signals received
              146383  voluntary context switches
               54455  involuntary context switches
            10038831  instructions retired
             5224057  cycles elapsed
             1278656  peak memory footprint

Memory usage is essentially unchanged (the initial full sort still dominates the footprint), while the extra awk / uniq / sort stages nearly double the elapsed time.

A single sort command needs more and more memory as the file grows. Instead, we can use split to break the file into several pieces, sort each piece separately, and then merge the sorted pieces with sort -m.

It is worth noting that split -n splits directly into a given number of chunks and is very fast (it divides by byte count, so a record on a chunk boundary may be cut in two), while split -l splits by line count; the end result here is the same but it is far slower.

/usr/bin/time -l split -l 10000000 data.csv data_split_
      258.70 real        13.84 user       240.84 sys
             1671168  maximum resident set size
/usr/bin/time -l split -n 10 data.csv data_split_
        0.72 real         0.00 user         0.60 sys
             2703360  maximum resident set size
# split into 10 files
cd /tmp/sort-test/
/usr/bin/time -l split -n 10 data.csv data_split_

# sort each split file in place
for file in data_split_*; do
    /usr/bin/time -l sort -t "," -nrk2 -o "$file" "$file"
done

Below are the stats for one of the files: each file needs about 1.9 GB of memory, which is in line with expectations.

        8.83 real        15.93 user         0.47 sys
          2044411904  maximum resident set size
                   0  average shared memory size
                   0  average unshared data size
                   0  average unshared stack size
              125381  page reclaims
                   0  page faults
                   0  swaps
                   0  block input operations
                   0  block output operations
                   0  messages sent
                   0  messages received
                   0  signals received
                  68  voluntary context switches
               13442  involuntary context switches
        121344257975  instructions retired
         49746113445  cycles elapsed
          2049299776  peak memory footprint
# merge the already-sorted files with sort -m
time sort -m -t "," -nrk2 data_split_* | head -n 5

The sort -m output is shown below. Because this step only merges files that are already sorted, printing just the first few rows of the final ordering takes almost no time.

65899819,100000000
30387596,99999998
15566396,99999998
46605454,99999997
8698303,99999996
sort -m -t "," -nrk2 data_split_*  0.00s user 0.00s system 46% cpu 0.006 total
head -n 5  0.00s user 0.00s system 33% cpu 0.003 total

On top of this, group by value and count:

# group by value and count
/usr/bin/time -l bash -c 'sort -m -t "," -nrk2 data_split_* | awk -F "," "{print \$2}" | uniq -c | sort -r | head -n 5'

This prints the same result as above; the step needs about 6 GB of memory and roughly two minutes.

  10 967956
  10 95587138
  10 92848651
  10 89216079
  10 87144205
      125.79 real       219.39 user         2.25 sys
          6086688768  maximum resident set size
                   0  average shared memory size
                   0  average unshared data size
                   0  average unshared stack size
              380423  page reclaims
                   0  page faults
                   0  swaps
                   0  block input operations
                   0  block output operations
                   0  messages sent
                   0  messages received
                   0  signals received
              139198  voluntary context switches
               33656  involuntary context switches
            10155254  instructions retired
             5116540  cycles elapsed
             1262272  peak memory footprint

Sorting and Counting with Python

First the brute-force approach: load everything and sort it in memory. It peaks at about 10 GB of memory and takes roughly a minute, which actually beats the sort command on both memory usage and speed, but memory still grows with the data size.

import time
from collections import Counter

time_start = time.time()
items = []
with open(test_file, 'r') as f:
    for line in f:
        items.append(int(line.split(',')[1].strip()))

# materialize the descending sort as a list so it can be both sliced and counted
r_items = sorted(items, reverse=True)
print(r_items[:10])
print(f'took: {time.time() - time_start:.3f}s')

c = Counter(r_items)
print(c.most_common()[:10])
  [100000000, 99999998, 99999998, 99999997, 99999996, 99999995, 99999994, 99999992, 99999992, 99999992]
  took: 55.838s
  [(95587138, 10), (92848651, 10), (89216079, 10), (87144205, 10), (75968895, 10), (61848367, 10), (46303300, 10), (45262839, 10), (45186772, 10), (43677379, 10)]
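
The ~10 GB peak quoted above is not captured in the notebook output. One rough way to check it from inside Python is resource.getrusage; a minimal sketch (note that ru_maxrss is reported in bytes on macOS but in kilobytes on Linux):

import resource
import sys

# peak resident set size of the current process so far
peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
unit = 1 if sys.platform == 'darwin' else 1024  # macOS: bytes, Linux: kilobytes
print(f'peak rss: {peak * unit / 1024 ** 3:.2f} GB')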

The improved strategy is the same idea as the shell version: split the input into chunks, sort each chunk and cache it on disk, then merge the sorted chunks with heapq.merge:

import uuid
import shutil
import heapq


def merge_sort_file(source_file, *, reverse=False, max_lines=1000000, line_key=None, merge_key=None):
    """Split source_file into sorted chunks of at most max_lines lines, then lazily merge them."""
    split_dir = os.path.join(test_dir, f'merge_sort_file_{uuid.uuid4().hex}')
    os.makedirs(split_dir, exist_ok=False)
    chunks = []

    def append_chunk(chunk, items):
        # sort one chunk in memory, write it out and rewind the file for merging
        chunk.writelines(str(i) + '\n' for i in sorted(items, reverse=reverse))
        chunk.flush()
        chunk.seek(0)
        chunks.append(chunk)

    with open(source_file, 'r') as f:
        chunk = None
        items = []
        for index, line in enumerate(f):
            if index % max_lines == 0:
                if chunk:
                    append_chunk(chunk, items)
                    items = []
                chunk = open(os.path.join(split_dir, str(index)), 'w+')

            items.append(line_key(line) if line_key else line.rstrip('\n'))
        if chunk:
            append_chunk(chunk, items)
        items = []

    try:
        # heapq.merge reads the chunk files lazily, one line at a time, so memory
        # stays bounded by max_lines no matter how large the source file is
        yield from heapq.merge(*chunks, reverse=reverse, key=merge_key)
    finally:
        # clean up temporary files even if the caller stops iterating early
        for chunk in chunks:
            try:
                chunk.close()
            except Exception:
                pass
        shutil.rmtree(split_dir)

line_key = lambda line: int(line.split(',')[1].strip())  # pull the value column out of a csv line
merge_key = lambda line: int(line.strip())  # chunk files contain one value per line
for index, item in enumerate(merge_sort_file(test_file, reverse=True, line_key=line_key, merge_key=merge_key)):
    print(index, item.strip())
    if index >= 4:
        break
  0 100000000
  1 99999998
  2 99999998
  3 99999997
  4 99999996
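
For completeness, the group-and-count that uniq -c | sort -r handled in the shell version can be layered on top of merge_sort_file with itertools.groupby. The sketch below reuses merge_sort_file, line_key, merge_key, and test_file from above; it is not run against the full dataset here, and top_value_counts is my own helper name:

from itertools import groupby
import heapq

def top_value_counts(n=5):
    # the merged stream is already sorted by value, so groupby collapses
    # runs of equal values, just like `uniq -c` does on sorted input
    merged = merge_sort_file(test_file, reverse=True, line_key=line_key, merge_key=merge_key)
    counts = ((sum(1 for _ in group), value) for value, group in groupby(merged, key=merge_key))
    # keep only the n largest (count, value) pairs; memory stays small
    return heapq.nlargest(n, counts)

# print(top_value_counts())  # merging 100 million lines in pure Python takes a while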

Sorting and Counting with SQLite

Once the data is imported into a database, sorting and counting can be done within limited memory. The downside is that the import itself takes quite a long time; the upside is that after the import, simple or complex queries and analyses can be written in SQL, and they run fast.

from contextlib import contextmanager

@contextmanager
def timer(name='task'):
    start_time = time.time()
    try:
        yield
    finally:
        print(f'{name} took {time.time() - start_time}s')
import pandas as pd
from sqlalchemy import create_engine

sqlite_file = os.path.join(test_dir, 'data.db')
if os.path.exists(sqlite_file):
    os.unlink(sqlite_file)
disk_engine = create_engine('sqlite:///' + sqlite_file)

with timer('import'):
    # the csv has two columns; since only one name is given, pandas uses the
    # first column as the index and the second column as 'value'
    for df in pd.read_csv(test_file, names=['value'], chunksize=500000):
        df.to_sql('data', disk_engine, if_exists='append')
  import took 992.4786958694458s
with timer('count'):
    df = pd.read_sql_query('select count(*) from data;', disk_engine)
    display(df)

with timer('top'):
    df = pd.read_sql_query('select value from data order by value desc limit 5;', disk_engine)
    display(df)
   count(*)
0  100000000
  count took 0.4194040298461914s
   value
0  100000000
1  99999998
2  99999998
3  99999997
4  99999996
  top took 4.143555164337158s
with timer('group_by'):
    df = pd.read_sql_query("""
        select value, count(*) as count 
        from data 
        group by value 
        order by count desc limit 5
    """, disk_engine)
    display(df)
   value     count
0  95587138  10
1  92848651  10
2  89216079  10
3  87144205  10
4  75968895  10
  group_by took 48.16865396499634s
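
If the group-by query were going to be run repeatedly, adding an index on value before deleting the database below might speed it up. Here is a minimal sketch using the standard sqlite3 module; the index name idx_value is my own choice, and the actual benefit depends on the data and the SQLite version:

import sqlite3

conn = sqlite3.connect(sqlite_file)
# an index on value lets SQLite scan values in sorted order when grouping
conn.execute('create index if not exists idx_value on data(value)')
conn.commit()
conn.close()
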
os.unlink(sqlite_file)