# 

# Licensed to the Apache Software Foundation (ASF) under one or more 

# contributor license agreements. See the NOTICE file distributed with 

# this work for additional information regarding copyright ownership. 

# The ASF licenses this file to You under the Apache License, Version 2.0 

# (the "License"); you may not use this file except in compliance with 

# the License. You may obtain a copy of the License at 

# 

# http://www.apache.org/licenses/LICENSE-2.0 

# 

# Unless required by applicable law or agreed to in writing, software 

# distributed under the License is distributed on an "AS IS" BASIS, 

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

# See the License for the specific language governing permissions and 

# limitations under the License. 

# 

 

import os 

import platform 

import shutil 

import warnings 

import gc 

import itertools 

import operator 

import random 

import sys 

 

import heapq 

from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ 

CompressedSerializer, AutoBatchedSerializer 

from pyspark.util import fail_on_stopiteration # type: ignore 

 

 

try: 

import psutil 

 

process = None 

 

def get_used_memory(): 

""" Return the used memory in MiB """ 

global process 

if process is None or process._pid != os.getpid(): 

process = psutil.Process(os.getpid()) 

if hasattr(process, "memory_info"): 

info = process.memory_info() 

else: 

info = process.get_memory_info() 

return info.rss >> 20 

 

except ImportError: 

 

def get_used_memory(): 

""" Return the used memory in MiB """ 

if platform.system() == 'Linux': 

for line in open('/proc/self/status'): 

if line.startswith('VmRSS:'): 

return int(line.split()[1]) >> 10 

 

else: 

warnings.warn("Please install psutil to have better " 

"support with spilling") 

if platform.system() == "Darwin": 

import resource 

rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss 

return rss >> 20 

# TODO: support windows 

 

return 0 

 

 

def _get_local_dirs(sub): 

""" Get all the directories """ 

path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") 

dirs = path.split(",") 

if len(dirs) > 1: 

# different order in different processes and instances 

rnd = random.Random(os.getpid() + id(dirs)) 

random.shuffle(dirs, rnd.random) 

return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs] 
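# Illustrative example (editor addition, not from the upstream source): with
# SPARK_LOCAL_DIRS="/disk1,/disk2" and sub="merge", this returns paths like
# ["/disk1/python/<pid>/merge", "/disk2/python/<pid>/merge"], in a
# per-process random order.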

 

 

# global stats 

MemoryBytesSpilled = 0 

DiskBytesSpilled = 0 
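# Note (editor addition): both counters above are in bytes. DiskBytesSpilled
# accumulates the sizes of the spill files, while MemoryBytesSpilled is
# estimated from the drop in resident memory (MiB) shifted left by 20 bits.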

 

 

class Aggregator(object): 

 

""" 

Aggregator has three functions to merge values into a combiner. 

 

createCombiner: (value) -> combiner 

mergeValue: (combiner, value) -> combiner 

mergeCombiners: (combiner, combiner) -> combiner 

""" 

 

def __init__(self, createCombiner, mergeValue, mergeCombiners): 

self.createCombiner = fail_on_stopiteration(createCombiner) 

self.mergeValue = fail_on_stopiteration(mergeValue) 

self.mergeCombiners = fail_on_stopiteration(mergeCombiners) 

 

 

class SimpleAggregator(Aggregator): 

 

""" 

SimpleAggregator is useful for the cases where the combiner has the 

same type as the values 

""" 

 

def __init__(self, combiner): 

Aggregator.__init__(self, lambda x: x, combiner, combiner) 
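
# Illustrative sketch (editor addition, not part of the upstream module): a
# sum-style aggregator can be written either way, e.g.
#
#     sum_agg = SimpleAggregator(operator.add)
#     # ...which is equivalent to...
#     sum_agg = Aggregator(lambda v: v, operator.add, operator.add)
#
# because here the combiner has the same type as the values.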

 

 

class Merger(object): 

 

""" 

Merge shuffled data together by aggregator 

""" 

 

def __init__(self, aggregator): 

self.agg = aggregator 

 

def mergeValues(self, iterator): 

""" Combine the items by creator and combiner """ 

raise NotImplementedError 

 

def mergeCombiners(self, iterator): 

""" Merge the combined items by mergeCombiner """ 

raise NotImplementedError 

 

def items(self): 

""" Return the merged items ad iterator """ 

raise NotImplementedError 

 

 

def _compressed_serializer(self, serializer=None): 

# always use PickleSerializer to simplify implementation 

ser = PickleSerializer() 

return AutoBatchedSerializer(CompressedSerializer(ser)) 

 

 

class ExternalMerger(Merger): 

 

""" 

ExternalMerger will dump the aggregated data to disk when memory 

usage goes above the limit, then merge it back together. 

 

This class works as follows: 

 

- It repeatedly combines the items and saves them in one dict in 

memory. 

 

- When the used memory goes above the memory limit, it will split 

the combined data into partitions by hash code and dump them 

to disk, one file per partition. 

 

- Then it goes through the rest of the iterator, combining items 

into different dicts by hash. When the used memory goes over the 

memory limit again, it dumps all the dicts to disk, one file per 

dict, and repeats this until all the items have been combined. 

 

- Before returning any items, it will load each partition and 

combine them separately, yielding them before loading the next 

partition. 

 

- While loading a partition, if the memory goes over the limit, 

it will partition the loaded data again, dump it to disk, 

and load it back partition by partition. 

 

`data` and `pdata` are used to hold the merged items in memory. 

At first, all the data are merged into `data`. Once the used 

memory goes over the limit, the items in `data` are dumped to 

disk and `data` is cleared; the rest of the items are then merged 

into `pdata` and dumped to disk as well. Before returning, all 

the items remaining in `pdata` are dumped to disk. 

 

Finally, if any items were spilled to disk, each partition 

is merged into `data`, yielded, and then cleared. 

 

Examples 

-------- 

>>> agg = SimpleAggregator(lambda x, y: x + y) 

>>> merger = ExternalMerger(agg, 10) 

>>> N = 10000 

>>> merger.mergeValues(zip(range(N), range(N))) 

>>> assert merger.spills > 0 

>>> sum(v for k,v in merger.items()) 

49995000 

 

>>> merger = ExternalMerger(agg, 10) 

>>> merger.mergeCombiners(zip(range(N), range(N))) 

>>> assert merger.spills > 0 

>>> sum(v for k,v in merger.items()) 

49995000 

""" 

 

# the max total partitions created recursively 

MAX_TOTAL_PARTITIONS = 4096 
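# Note (editor addition): `scale` is multiplied by `partitions` at each level
# of recursive merging, and _merged_items() only recurses while
# scale * partitions stays below this cap, which bounds the recursion depth.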

 

def __init__(self, aggregator, memory_limit=512, serializer=None, 

localdirs=None, scale=1, partitions=59, batch=1000): 

Merger.__init__(self, aggregator) 

self.memory_limit = memory_limit 

self.serializer = _compressed_serializer(serializer) 

self.localdirs = localdirs or _get_local_dirs(str(id(self))) 

# number of partitions when spill data into disks 

self.partitions = partitions 

# check the memory after # of items merged 

self.batch = batch 

# scale is used to scale down the hash of key for recursive hash map 

self.scale = scale 

# un-partitioned merged data 

self.data = {} 

# partitioned merged data, list of dicts 

self.pdata = [] 

# number of chunks dumped into disks 

self.spills = 0 

# randomize the hash of key, id(o) is the address of o (aligned by 8) 

self._seed = id(self) + 7 

 

def _get_spill_dir(self, n): 

""" Choose one directory for spill by number n """ 

return os.path.join(self.localdirs[n % len(self.localdirs)], str(n)) 

 

def _next_limit(self): 

""" 

Return the next memory limit. If the memory is not released 

after spilling, it will dump the data only when the used memory 

starts to increase. 

""" 

return max(self.memory_limit, get_used_memory() * 1.05) 
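# Worked example (editor addition): with memory_limit=512 and 600 MiB
# currently in use, the next limit becomes max(512, 600 * 1.05) = 630 MiB.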

 

def mergeValues(self, iterator): 

""" Combine the items by creator and combiner """ 

# speedup attribute lookup 

creator, comb = self.agg.createCombiner, self.agg.mergeValue 

c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch 

limit = self.memory_limit 

 

for k, v in iterator: 

d = pdata[hfun(k)] if pdata else data 

d[k] = comb(d[k], v) if k in d else creator(v) 

 

c += 1 

if c >= batch: 

if get_used_memory() >= limit: 

self._spill() 

limit = self._next_limit() 

batch /= 2 

c = 0 

else: 

batch *= 1.5 

 

if get_used_memory() >= limit: 

self._spill() 

 

def _partition(self, key): 

""" Return the partition for key """ 

return hash((key, self._seed)) % self.partitions 

 

def _object_size(self, obj): 

""" How much of memory for this obj, assume that all the objects 

consume similar bytes of memory 

""" 

return 1 

 

def mergeCombiners(self, iterator, limit=None): 

""" Merge (K,V) pair by mergeCombiner """ 

if limit is None: 

limit = self.memory_limit 

# speedup attribute lookup 

comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size 

c, data, pdata, batch = 0, self.data, self.pdata, self.batch 

for k, v in iterator: 

d = pdata[hfun(k)] if pdata else data 

d[k] = comb(d[k], v) if k in d else v 

if not limit: 

continue 

 

c += objsize(v) 

if c > batch: 

if get_used_memory() > limit: 

self._spill() 

limit = self._next_limit() 

batch /= 2 

c = 0 

else: 

batch *= 1.5 

 

if limit and get_used_memory() >= limit: 

self._spill() 

 

def _spill(self): 

""" 

Dump the already partitioned data to disk. 

It dumps the data in batches for better performance. 

""" 

global MemoryBytesSpilled, DiskBytesSpilled 

path = self._get_spill_dir(self.spills) 

if not os.path.exists(path): 

os.makedirs(path) 

 

used_memory = get_used_memory() 

if not self.pdata: 

# The data has not been partitioned yet: iterate over the 

# dataset once and write the items into different files, using no 

# additional memory. This branch is only taken the first time the 

# memory goes above the limit. 

 

# open all the files for writing 

streams = [open(os.path.join(path, str(i)), 'wb') 

for i in range(self.partitions)] 

 

for k, v in self.data.items(): 

h = self._partition(k) 

# put one item per batch to stay compatible with load_stream; 

# dumping them in larger batches would increase memory usage 

self.serializer.dump_stream([(k, v)], streams[h]) 

 

for s in streams: 

DiskBytesSpilled += s.tell() 

s.close() 

 

self.data.clear() 

self.pdata.extend([{} for i in range(self.partitions)]) 

 

else: 

for i in range(self.partitions): 

p = os.path.join(path, str(i)) 

with open(p, "wb") as f: 

# dump items in batch 

self.serializer.dump_stream(iter(self.pdata[i].items()), f) 

self.pdata[i].clear() 

DiskBytesSpilled += os.path.getsize(p) 

 

self.spills += 1 

gc.collect() # release the memory as much as possible 

MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 

 

def items(self): 

""" Return all merged items as iterator """ 

if not self.pdata and not self.spills: 

return iter(self.data.items()) 

return self._external_items() 

 

def _external_items(self): 

""" Return all partitioned items as iterator """ 

assert not self.data 

if any(self.pdata): 

self._spill() 

# disable partitioning and spilling when merging combiners from disk 

self.pdata = [] 

 

try: 

for i in range(self.partitions): 

for v in self._merged_items(i): 

yield v 

self.data.clear() 

 

# remove the merged partition 

for j in range(self.spills): 

path = self._get_spill_dir(j) 

os.remove(os.path.join(path, str(i))) 

finally: 

self._cleanup() 

 

def _merged_items(self, index): 

self.data = {} 

limit = self._next_limit() 

for j in range(self.spills): 

path = self._get_spill_dir(j) 

p = os.path.join(path, str(index)) 

# do not check memory during merging 

with open(p, "rb") as f: 

self.mergeCombiners(self.serializer.load_stream(f), 0) 

 

# limit the total partitions 

if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS 

and j < self.spills - 1 

and get_used_memory() > limit): 

self.data.clear() # will read from disk again 

gc.collect() # release the memory as much as possible 

return self._recursive_merged_items(index) 

 

return self.data.items() 

 

def _recursive_merged_items(self, index): 

""" 

Merge the partitioned items and return them as an iterator. 

 

If one partition cannot fit in memory, it will be 

partitioned and merged recursively. 

""" 

subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs] 

m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs, 

self.scale * self.partitions, self.partitions, self.batch) 

m.pdata = [{} for _ in range(self.partitions)] 

limit = self._next_limit() 

 

for j in range(self.spills): 

path = self._get_spill_dir(j) 

p = os.path.join(path, str(index)) 

with open(p, 'rb') as f: 

m.mergeCombiners(self.serializer.load_stream(f), 0) 

 

if get_used_memory() > limit: 

m._spill() 

limit = self._next_limit() 

 

return m._external_items() 

 

def _cleanup(self): 

""" Clean up all the files in disks """ 

for d in self.localdirs: 

shutil.rmtree(d, True) 

 

 

class ExternalSorter(object): 

""" 

ExternalSorter will divide the elements into chunks, sort them in 

memory, dump them to disk, and finally merge them back. 

 

The spilling will only happen when the used memory goes above 

the limit. 

 

Examples 

-------- 

>>> sorter = ExternalSorter(1) # 1M 

>>> import random 

>>> l = list(range(1024)) 

>>> random.shuffle(l) 

>>> sorted(l) == list(sorter.sorted(l)) 

True 

>>> sorted(l) == list(sorter.sorted(l, key=lambda x: -x, reverse=True)) 

True 

""" 

def __init__(self, memory_limit, serializer=None): 

self.memory_limit = memory_limit 

self.local_dirs = _get_local_dirs("sort") 

self.serializer = _compressed_serializer(serializer) 

 

def _get_path(self, n): 

""" Choose one directory for spill by number n """ 

d = self.local_dirs[n % len(self.local_dirs)] 

if not os.path.exists(d): 

os.makedirs(d) 

return os.path.join(d, str(n)) 

 

def _next_limit(self): 

""" 

Return the next memory limit. If the memory is not released 

after spilling, it will dump the data only when the used memory 

starts to increase. 

""" 

return max(self.memory_limit, get_used_memory() * 1.05) 

 

def sorted(self, iterator, key=None, reverse=False): 

""" 

Sort the elements in the iterator, doing an external sort when 

the memory goes above the limit. 

""" 

global MemoryBytesSpilled, DiskBytesSpilled 

batch, limit = 100, self._next_limit() 

chunks, current_chunk = [], [] 

iterator = iter(iterator) 

while True: 

# pick elements in batch 

chunk = list(itertools.islice(iterator, batch)) 

current_chunk.extend(chunk) 

if len(chunk) < batch: 

break 

 

used_memory = get_used_memory() 

if used_memory > limit: 

# sorting them in place saves memory 

current_chunk.sort(key=key, reverse=reverse) 

path = self._get_path(len(chunks)) 

with open(path, 'wb') as f: 

self.serializer.dump_stream(current_chunk, f) 

 

def load(f): 

for v in self.serializer.load_stream(f): 

yield v 

# close the file explicitly once we have consumed all the items 

# to avoid a ResourceWarning in Python 3 

f.close() 

chunks.append(load(open(path, 'rb'))) 

current_chunk = [] 

MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 

DiskBytesSpilled += os.path.getsize(path) 

os.unlink(path) # data will be deleted after close 

 

elif not chunks: 

batch = min(int(batch * 1.5), 10000) 

 

current_chunk.sort(key=key, reverse=reverse) 

if not chunks: 

return current_chunk 

 

if current_chunk: 

chunks.append(iter(current_chunk)) 

 

return heapq.merge(*chunks, key=key, reverse=reverse) 
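# Note (editor addition): heapq.merge() is lazy, so the already-sorted chunks
# (including those streamed back from disk) are merged on demand and the full
# sorted result never has to be held in memory at once.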

 

 

class ExternalList(object): 

""" 

ExternalList can hold many items which cannot all be held in memory 

at the same time. 

 

Examples 

-------- 

>>> l = ExternalList(list(range(100))) 

>>> len(l) 

100 

>>> l.append(10) 

>>> len(l) 

101 

>>> for i in range(20240): 

... l.append(i) 

>>> len(l) 

20341 

>>> import pickle 

>>> l2 = pickle.loads(pickle.dumps(l)) 

>>> len(l2) 

20341 

>>> list(l2)[100] 

10 

""" 

LIMIT = 10240 

 

def __init__(self, values): 

self.values = values 

self.count = len(values) 

self._file = None 

self._ser = None 

 

def __getstate__(self): 

if self._file is not None: 

self._file.flush() 

with os.fdopen(os.dup(self._file.fileno()), "rb") as f: 

f.seek(0) 

serialized = f.read() 

else: 

serialized = b'' 

return self.values, self.count, serialized 

 

def __setstate__(self, item): 

self.values, self.count, serialized = item 

if serialized: 

self._open_file() 

self._file.write(serialized) 

else: 

self._file = None 

self._ser = None 

 

def __iter__(self): 

if self._file is not None: 

self._file.flush() 

# read all items from disks first 

with os.fdopen(os.dup(self._file.fileno()), 'rb') as f: 

f.seek(0) 

for v in self._ser.load_stream(f): 

yield v 

 

for v in self.values: 

yield v 

 

def __len__(self): 

return self.count 

 

def append(self, value): 

self.values.append(value) 

self.count += 1 

# dump the buffered values to disk once the list grows past LIMIT 

if len(self.values) >= self.LIMIT: 

self._spill() 

 

def _open_file(self): 

dirs = _get_local_dirs("objects") 

d = dirs[id(self) % len(dirs)] 

if not os.path.exists(d): 

os.makedirs(d) 

p = os.path.join(d, str(id(self))) 

self._file = open(p, "w+b", 65536) 

self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) 

os.unlink(p) 
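# Note (editor addition): unlinking the path right away removes its directory
# entry, but the data stays readable and writable through the open handle, and
# the OS reclaims the space automatically once the handle is closed.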

 

def __del__(self): 

if self._file: 

self._file.close() 

self._file = None 

 

def _spill(self): 

""" dump the values into disk """ 

global MemoryBytesSpilled, DiskBytesSpilled 

if self._file is None: 

self._open_file() 

 

used_memory = get_used_memory() 

pos = self._file.tell() 

self._ser.dump_stream(self.values, self._file) 

self.values = [] 

gc.collect() 

DiskBytesSpilled += self._file.tell() - pos 

MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 

 

 

class ExternalListOfList(ExternalList): 

""" 

An ExternalList whose elements are themselves lists. 

 

Examples 

-------- 

>>> l = ExternalListOfList([[i, i] for i in range(100)]) 

>>> len(l) 

200 

>>> l.append(range(10)) 

>>> len(l) 

210 

>>> len(list(l)) 

210 

""" 

 

def __init__(self, values): 

ExternalList.__init__(self, values) 

self.count = sum(len(i) for i in values) 

 

def append(self, value): 

ExternalList.append(self, value) 

# already counted 1 in ExternalList.append 

self.count += len(value) - 1 

 

def __iter__(self): 

for values in ExternalList.__iter__(self): 

for v in values: 

yield v 

 

 

class GroupByKey(object): 

""" 

Group a sorted iterator as [(k1, it1), (k2, it2), ...] 

 

Examples 

-------- 

>>> k = [i // 3 for i in range(6)] 

>>> v = [[i] for i in range(6)] 

>>> g = GroupByKey(zip(k, v)) 

>>> [(k, list(it)) for k, it in g] 

[(0, [0, 1, 2]), (1, [3, 4, 5])] 

""" 

 

def __init__(self, iterator): 

self.iterator = iterator 

 

def __iter__(self): 

key, values = None, None 

for k, v in self.iterator: 

if values is not None and k == key: 

values.append(v) 

else: 

if values is not None: 

yield (key, values) 

key = k 

values = ExternalListOfList([v]) 

if values is not None: 

yield (key, values) 

 

 

class ExternalGroupBy(ExternalMerger): 

 

""" 

Group the items by key. If any partition of them cannot be 

held in memory, it will do a sort-based group-by. 

 

This class works as follows: 

 

- It repeatedly groups the items by key and saves them in one dict in 

memory. 

 

- When the used memory goes above the memory limit, it will split 

the grouped data into partitions by hash code and dump them 

to disk, one file per partition. If the number of keys 

in memory is smaller than 1000, it will sort them 

by key before dumping to disk. 

 

- Then it goes through the rest of the iterator, grouping items 

by key into different dicts by hash. When the used memory goes over the 

memory limit again, it dumps all the dicts to disk, one file per 

dict, and repeats this until all the items have been grouped. It 

will also try to sort the items by key in each partition 

before dumping to disk. 

 

- It will yield the grouped items partition by partition. 

If the data in one partition can be held in memory, it 

will load and combine them in memory and yield them. 

 

- If the dataset in one partition cannot be held in memory, 

it will be sorted first. If all the spilled files are already sorted, 

they are merged with heapq.merge(); otherwise an external sort 

is done over all the files. 

 

- After sorting, the `GroupByKey` class puts all the consecutive 

items with the same key into a group and yields the values as 

an iterator. 

""" 

SORT_KEY_LIMIT = 1000 

 

def flattened_serializer(self): 

assert isinstance(self.serializer, BatchedSerializer) 

ser = self.serializer 

return FlattenedValuesSerializer(ser, 20) 
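# Note (editor addition): FlattenedValuesSerializer writes large value lists in
# smaller batches (20 values at a time here), so a single huge group does not
# have to be serialized in one piece.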

 

def _object_size(self, obj): 

return len(obj) 

 

def _spill(self): 

""" 

Dump the already partitioned data to disk. 

""" 

global MemoryBytesSpilled, DiskBytesSpilled 

path = self._get_spill_dir(self.spills) 

if not os.path.exists(path): 

os.makedirs(path) 

 

used_memory = get_used_memory() 

if not self.pdata: 

# The data has not been partitioned yet: iterate over the 

# data once and write the items into different files, using no 

# additional memory. This branch is only taken the first time the 

# memory goes above the limit. 

 

# open all the files for writing 

streams = [open(os.path.join(path, str(i)), 'wb') 

for i in range(self.partitions)] 

 

# If the number of keys is small, the overhead of sorting is small; 

# sort them before dumping to disk 

self._sorted = len(self.data) < self.SORT_KEY_LIMIT 

if self._sorted: 

self.serializer = self.flattened_serializer() 

for k in sorted(self.data.keys()): 

h = self._partition(k) 

self.serializer.dump_stream([(k, self.data[k])], streams[h]) 

else: 

for k, v in self.data.items(): 

h = self._partition(k) 

self.serializer.dump_stream([(k, v)], streams[h]) 

 

for s in streams: 

DiskBytesSpilled += s.tell() 

s.close() 

 

self.data.clear() 

# self.pdata is cached in `mergeValues` and `mergeCombiners` 

self.pdata.extend([{} for i in range(self.partitions)]) 

 

else: 

for i in range(self.partitions): 

p = os.path.join(path, str(i)) 

with open(p, "wb") as f: 

# dump items in batch 

if self._sorted: 

# sort by key only (stable) 

sorted_items = sorted(self.pdata[i].items(), key=operator.itemgetter(0)) 

self.serializer.dump_stream(sorted_items, f) 

else: 

self.serializer.dump_stream(self.pdata[i].items(), f) 

self.pdata[i].clear() 

DiskBytesSpilled += os.path.getsize(p) 

 

self.spills += 1 

gc.collect() # release the memory as much as possible 

MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 

 

def _merged_items(self, index): 

size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) 

for j in range(self.spills)) 

# if memory cannot hold the whole partition, use a sort-based 

# merge. Because of compression, the data on disk will be much 

# smaller than the memory that would be needed 
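# Worked example (editor addition): with the default memory_limit of 512 (MB),
# the threshold below is 512 << 17 = 64 MiB of spilled data for this partition.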

if size >= self.memory_limit << 17: # * 1M / 8 

return self._merge_sorted_items(index) 

 

self.data = {} 

for j in range(self.spills): 

path = self._get_spill_dir(j) 

p = os.path.join(path, str(index)) 

# do not check memory during merging 

with open(p, "rb") as f: 

self.mergeCombiners(self.serializer.load_stream(f), 0) 

return self.data.items() 

 

def _merge_sorted_items(self, index): 

""" load a partition from disk, then sort and group by key """ 

def load_partition(j): 

path = self._get_spill_dir(j) 

p = os.path.join(path, str(index)) 

with open(p, 'rb', 65536) as f: 

for v in self.serializer.load_stream(f): 

yield v 

 

disk_items = [load_partition(j) for j in range(self.spills)] 

 

if self._sorted: 

# all the partitions are already sorted 

sorted_items = heapq.merge(*disk_items, key=operator.itemgetter(0)) 

 

else: 

# Flatten the combined values, so it will not consume huge 

# memory during merging sort. 

ser = self.flattened_serializer() 

sorter = ExternalSorter(self.memory_limit, ser) 

sorted_items = sorter.sorted(itertools.chain(*disk_items), 

key=operator.itemgetter(0)) 

return ((k, vs) for k, vs in GroupByKey(sorted_items)) 
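

# Illustrative sketch (editor addition, not part of the upstream module): one
# way to drive ExternalGroupBy directly, using list-building combiners similar
# to the ones PySpark's groupByKey passes in. The helper below is hypothetical
# and never called by Spark itself.
def _example_group_by(pairs, memory_limit=512):
    agg = Aggregator(lambda v: [v],
                     lambda vs, v: vs.append(v) or vs,
                     lambda a, b: a.extend(b) or a)
    merger = ExternalGroupBy(agg, memory_limit)
    merger.mergeValues(pairs)
    # items() yields (key, values) pairs; values may be a plain list or an
    # ExternalList-backed iterable, so materialize it with list() here.
    return [(k, list(vs)) for k, vs in merger.items()]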

 

 

if __name__ == "__main__": 

import doctest 

(failure_count, test_count) = doctest.testmod() 

if failure_count: 

sys.exit(-1)