Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

266

267

268

269

270

271

272

273

274

275

276

277

278

279

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

320

321

322

323

324

325

326

327

328

329

330

331

332

333

334

335

336

337

338

339

340

341

342

343

344

345

346

347

348

349

350

351

352

353

354

355

356

357

358

359

360

361

362

363

364

365

366

367

368

369

370

371

372

373

374

375

376

377

378

379

380

381

382

383

384

385

386

387

388

389

390

391

392

393

394

395

396

397

398

399

400

401

402

403

404

405

406

407

408

409

410

411

412

413

414

415

416

417

418

419

420

421

422

423

424

425

426

427

428

429

430

431

432

433

434

435

436

437

438

439

440

441

442

443

444

445

446

447

448

449

450

451

452

453

454

455

456

457

458

459

460

461

462

463

464

465

466

467

468

469

470

471

472

473

474

475

476

477

478

479

480

481

482

483

484

485

486

487

488

489

490

491

492

493

494

495

496

497

498

499

500

501

502

503

504

505

506

507

508

509

510

511

512

513

514

515

516

517

518

519

520

521

522

523

524

525

526

527

528

529

530

531

532

533

534

535

536

537

538

539

540

541

542

543

544

545

546

547

548

549

550

551

552

553

554

555

556

557

558

559

560

# 

# Licensed to the Apache Software Foundation (ASF) under one or more 

# contributor license agreements. See the NOTICE file distributed with 

# this work for additional information regarding copyright ownership. 

# The ASF licenses this file to You under the Apache License, Version 2.0 

# (the "License"); you may not use this file except in compliance with 

# the License. You may obtain a copy of the License at 

# 

# http://www.apache.org/licenses/LICENSE-2.0 

# 

# Unless required by applicable law or agreed to in writing, software 

# distributed under the License is distributed on an "AS IS" BASIS, 

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

# See the License for the specific language governing permissions and 

# limitations under the License. 

# 

import array 

from abc import ABCMeta 

import copy 

 

import numpy as np 

from py4j.java_gateway import JavaObject 

 

from pyspark.ml.linalg import DenseVector, Vector, Matrix 

from pyspark.ml.util import Identifiable 

 

 

# Names exported by a star-import of this module.
__all__ = ['Param', 'Params', 'TypeConverters']

 

 

class Param(object):
    """
    A param with self-contained documentation.

    A param is identified by the uid of its parent together with its own
    name: ``str()``, hashing and equality are all derived from that pair.

    .. versionadded:: 1.3.0

    Parameters
    ----------
    parent : :py:class:`Identifiable`
        owner of the param; only its ``uid`` is stored
    name : str
        name of the param (stored as ``str(name)``)
    doc : str
        documentation of the param (stored as ``str(doc)``)
    typeConverter : callable, optional
        function applied to values before they are stored; defaults to
        ``TypeConverters.identity``

    Raises
    ------
    TypeError
        if ``parent`` is not an :py:class:`Identifiable`
    """

    def __init__(self, parent, name, doc, typeConverter=None):
        if not isinstance(parent, Identifiable):
            raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
        self.parent = parent.uid
        self.name = str(name)
        self.doc = str(doc)
        self.typeConverter = TypeConverters.identity if typeConverter is None else typeConverter

    def _copy_new_parent(self, parent):
        """
        Copy the current param to a new parent, must be a dummy param.

        Parameters
        ----------
        parent
            object whose ``uid`` becomes the parent of the copy

        Returns
        -------
        :py:class:`Param`
            shallow copy of this param bound to ``parent.uid``

        Raises
        ------
        ValueError
            if this param is already bound to a parent other than
            ``"undefined"``
        """
        if self.parent != "undefined":
            # Report the parent this param is already bound to, not the
            # parent it was asked to move to.
            raise ValueError("Cannot copy from non-dummy parent %s." % self.parent)
        param = copy.copy(self)
        param.parent = parent.uid
        return param

    def __str__(self):
        return str(self.parent) + "__" + self.name

    def __repr__(self):
        return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)

    def __hash__(self):
        # Hash on the same (parent, name) identity that __eq__ compares.
        return hash(str(self))

    def __eq__(self, other):
        if isinstance(other, Param):
            return self.parent == other.parent and self.name == other.name
        else:
            return False

 

 

class TypeConverters(object):
    """
    Factory methods for common type conversion functions for `Param.typeConverter`.

    Every converter either returns the converted value or raises
    :py:class:`TypeError` when the input cannot be converted.

    .. versionadded:: 2.0.0
    """

    @staticmethod
    def _is_numeric(value):
        """Return True if the exact type of ``value`` is a supported number type."""
        vtype = type(value)
        return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long'

    @staticmethod
    def _is_integer(value):
        """Return True if ``value`` is numeric and has no fractional part."""
        return TypeConverters._is_numeric(value) and float(value).is_integer()

    @staticmethod
    def _can_convert_to_list(value):
        """Return True if ``value`` is a sequence type accepted by :py:meth:`toList`."""
        vtype = type(value)
        return vtype in [list, np.ndarray, tuple, range, array.array] or isinstance(value, Vector)

    @staticmethod
    def _can_convert_to_string(value):
        """Return True if ``value`` is a type accepted by :py:meth:`toString`."""
        # np.str_ / np.bytes_ are the canonical names of the scalar types that
        # np.unicode_ / np.string_ used to alias; the aliases were removed in
        # NumPy 2.0 and touching them raises AttributeError there.
        return isinstance(value, str) or type(value) in [np.str_, np.bytes_]

    @staticmethod
    def identity(value):
        """
        Dummy converter that just returns value.
        """
        return value

    @staticmethod
    def toList(value):
        """
        Convert a value to a list, if possible.

        A ``list`` is returned as-is (not copied); other supported
        sequences are materialised into a new list.
        """
        vtype = type(value)
        if vtype is list:
            return value
        if vtype in [np.ndarray, tuple, range, array.array]:
            return list(value)
        if isinstance(value, Vector):
            return list(value.toArray())
        raise TypeError("Could not convert %s to list" % value)

    @staticmethod
    def toListFloat(value):
        """
        Convert a value to list of floats, if possible.
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._is_numeric(v) for v in value):
                return [float(v) for v in value]
        raise TypeError("Could not convert %s to list of floats" % value)

    @staticmethod
    def toListListFloat(value):
        """
        Convert a value to list of list of floats, if possible.
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            return [TypeConverters.toListFloat(v) for v in value]
        raise TypeError("Could not convert %s to list of list of floats" % value)

    @staticmethod
    def toListInt(value):
        """
        Convert a value to list of ints, if possible.
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._is_integer(v) for v in value):
                return [int(v) for v in value]
        raise TypeError("Could not convert %s to list of ints" % value)

    @staticmethod
    def toListString(value):
        """
        Convert a value to list of strings, if possible.
        """
        if TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._can_convert_to_string(v) for v in value):
                return [TypeConverters.toString(v) for v in value]
        raise TypeError("Could not convert %s to list of strings" % value)

    @staticmethod
    def toVector(value):
        """
        Convert a value to a MLlib Vector, if possible.
        """
        if isinstance(value, Vector):
            return value
        elif TypeConverters._can_convert_to_list(value):
            value = TypeConverters.toList(value)
            if all(TypeConverters._is_numeric(v) for v in value):
                return DenseVector(value)
        raise TypeError("Could not convert %s to vector" % value)

    @staticmethod
    def toMatrix(value):
        """
        Convert a value to a MLlib Matrix, if possible.
        """
        if isinstance(value, Matrix):
            return value
        raise TypeError("Could not convert %s to matrix" % value)

    @staticmethod
    def toFloat(value):
        """
        Convert a value to a float, if possible.
        """
        if TypeConverters._is_numeric(value):
            return float(value)
        else:
            raise TypeError("Could not convert %s to float" % value)

    @staticmethod
    def toInt(value):
        """
        Convert a value to an int, if possible.
        """
        if TypeConverters._is_integer(value):
            return int(value)
        else:
            raise TypeError("Could not convert %s to int" % value)

    @staticmethod
    def toString(value):
        """
        Convert a value to a string, if possible.
        """
        if isinstance(value, str):
            return value
        elif type(value) in [np.bytes_, np.str_]:
            # Same scalar types the removed np.string_ / np.unicode_ aliases
            # referred to; see _can_convert_to_string.
            return str(value)
        else:
            raise TypeError("Could not convert %s to string type" % type(value))

    @staticmethod
    def toBoolean(value):
        """
        Convert a value to a boolean, if possible.

        Only exact ``bool`` instances are accepted.
        """
        if type(value) == bool:
            return value
        else:
            raise TypeError("Boolean Param requires value of type bool. Found %s." % type(value))

 

 

class Params(Identifiable, metaclass=ABCMeta):
    """
    Components that take parameters. This also provides an internal
    param map to store parameter values attached to the instance.

    Values live in two dictionaries keyed by :py:class:`Param`:
    ``_paramMap`` for user-supplied values and ``_defaultParamMap`` for
    defaults. User-supplied values take precedence when reading.

    .. versionadded:: 1.3.0
    """

    def __init__(self):
        super(Params, self).__init__()
        #: internal param map for user-supplied values param map
        self._paramMap = {}

        #: internal param map for default values
        self._defaultParamMap = {}

        #: value returned by :py:func:`params`
        self._params = None

        # Copy the params from the class to the object
        self._copy_params()

    def _copy_params(self):
        """
        Copy all params defined on the class to current object.

        Each class-level :py:class:`Param` is re-bound to this instance
        through ``Param._copy_new_parent`` and stored as an instance
        attribute of the same name.
        """
        cls = type(self)
        src_name_attrs = [(x, getattr(cls, x)) for x in dir(cls)]
        src_params = [(name, attr) for name, attr in src_name_attrs if isinstance(attr, Param)]
        for name, param in src_params:
            setattr(self, name, param._copy_new_parent(self))

    @property
    def params(self):
        """
        Returns all params ordered by name. The default implementation
        uses :py:func:`dir` to get all attributes of type
        :py:class:`Param`.
        """
        if self._params is None:
            cls = type(self)
            # Skip "params" itself and any class-level property so that
            # building the list never triggers a property getter.
            candidates = [getattr(self, x) for x in dir(self)
                          if x != "params" and not isinstance(getattr(cls, x, None), property)]
            self._params = [attr for attr in candidates if isinstance(attr, Param)]
        return self._params

    def explainParam(self, param):
        """
        Explains a single param and returns its name, doc, and optional
        default value and user-supplied value in a string.
        """
        param = self._resolveParam(param)
        values = []
        if self.isDefined(param):
            if param in self._defaultParamMap:
                values.append("default: %s" % self._defaultParamMap[param])
            if param in self._paramMap:
                values.append("current: %s" % self._paramMap[param])
        else:
            values.append("undefined")
        valueStr = "(" + ", ".join(values) + ")"
        return "%s: %s %s" % (param.name, param.doc, valueStr)

    def explainParams(self):
        """
        Returns the documentation of all params with their optionally
        default values and user-supplied values.
        """
        return "\n".join([self.explainParam(param) for param in self.params])

    def getParam(self, paramName):
        """
        Gets a param by its name.

        Raises
        ------
        AttributeError
            if this instance has no attribute called ``paramName``
        ValueError
            if the attribute exists but is not a :py:class:`Param`
        """
        param = getattr(self, paramName)
        if isinstance(param, Param):
            return param
        else:
            raise ValueError("Cannot find param with name %s." % paramName)

    def isSet(self, param):
        """
        Checks whether a param is explicitly set by user.
        """
        param = self._resolveParam(param)
        return param in self._paramMap

    def hasDefault(self, param):
        """
        Checks whether a param has a default value.
        """
        param = self._resolveParam(param)
        return param in self._defaultParamMap

    def isDefined(self, param):
        """
        Checks whether a param is explicitly set by user or has
        a default value.
        """
        return self.isSet(param) or self.hasDefault(param)

    def hasParam(self, paramName):
        """
        Tests whether this instance contains a param with a given
        (string) name.

        Raises
        ------
        TypeError
            if ``paramName`` is not a string
        """
        if isinstance(paramName, str):
            p = getattr(self, paramName, None)
            return isinstance(p, Param)
        else:
            raise TypeError("hasParam(): paramName must be a string")

    def getOrDefault(self, param):
        """
        Gets the value of a param in the user-supplied param map or its
        default value. Raises an error if neither is set.

        Raises
        ------
        KeyError
            if the param has neither a user-supplied nor a default value
        """
        param = self._resolveParam(param)
        if param in self._paramMap:
            return self._paramMap[param]
        else:
            return self._defaultParamMap[param]

    def extractParamMap(self, extra=None):
        """
        Extracts the embedded default param values and user-supplied
        values, and then merges them with extra values from input into
        a flat param map, where the latter value is used if there exist
        conflicts, i.e., with ordering: default param values <
        user-supplied values < extra.

        Parameters
        ----------
        extra : dict, optional
            extra param values

        Returns
        -------
        dict
            merged param map
        """
        if extra is None:
            extra = dict()
        paramMap = self._defaultParamMap.copy()
        paramMap.update(self._paramMap)
        paramMap.update(extra)
        return paramMap

    def copy(self, extra=None):
        """
        Creates a copy of this instance with the same uid and some
        extra params. The default implementation creates a
        shallow copy using :py:func:`copy.copy`, and then copies the
        embedded and extra parameters over and returns the copy.
        Subclasses should override this method if the default approach
        is not sufficient.

        Parameters
        ----------
        extra : dict, optional
            Extra parameters to copy to the new instance

        Returns
        -------
        :py:class:`Params`
            Copy of this instance
        """
        if extra is None:
            extra = dict()
        that = copy.copy(self)
        # The shallow copy shares both maps with ``self``; give the copy
        # fresh ones before values are copied over.
        that._paramMap = {}
        that._defaultParamMap = {}
        return self._copyValues(that, extra)

    def set(self, param, value):
        """
        Sets a parameter in the embedded param map.

        Parameters
        ----------
        param : :py:class:`Param`
            param owned by this instance
        value
            value to store, after passing through ``param.typeConverter``

        Raises
        ------
        ValueError
            if ``param`` does not belong to this instance, or the type
            converter raises :py:class:`ValueError`
        TypeError
            if the type converter raises :py:class:`TypeError`
        """
        self._shouldOwn(param)
        try:
            value = param.typeConverter(value)
        except TypeError as e:
            # The converters defined in this module signal bad input with
            # TypeError; name the offending param, as _set does.
            raise TypeError('Invalid param value given for param "%s". %s' % (param.name, e))
        except ValueError as e:
            raise ValueError('Invalid param value given for param "%s". %s' % (param.name, e))
        self._paramMap[param] = value

    def _shouldOwn(self, param):
        """
        Validates that the input param belongs to this Params instance.

        Raises
        ------
        ValueError
            if the param's parent is not this instance's uid or this
            instance has no param with that name
        """
        if not (self.uid == param.parent and self.hasParam(param.name)):
            raise ValueError("Param %r does not belong to %r." % (param, self))

    def _resolveParam(self, param):
        """
        Resolves a param and validates the ownership.

        Parameters
        ----------
        param : str or :py:class:`Param`
            param name or the param instance, which must
            belong to this Params instance

        Returns
        -------
        :py:class:`Param`
            resolved param instance

        Raises
        ------
        TypeError
            if ``param`` is neither a string nor a :py:class:`Param`
        """
        if isinstance(param, Param):
            self._shouldOwn(param)
            return param
        elif isinstance(param, str):
            return self.getParam(param)
        else:
            raise TypeError("Cannot resolve %r as a param." % param)

    def _testOwnParam(self, param_parent, param_name):
        """
        Test the ownership. Return True or False
        """
        return self.uid == param_parent and self.hasParam(param_name)

    @staticmethod
    def _dummy():
        """
        Returns a dummy Params instance used as a placeholder to
        generate docs.
        """
        dummy = Params()
        dummy.uid = "undefined"
        return dummy

    def _set(self, **kwargs):
        """
        Sets user-supplied params.

        ``None`` values are stored without being passed through the
        type converter.

        Raises
        ------
        TypeError
            if a type converter rejects a value
        """
        for param, value in kwargs.items():
            p = getattr(self, param)
            if value is not None:
                try:
                    value = p.typeConverter(value)
                except TypeError as e:
                    raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
            self._paramMap[p] = value
        return self

    def clear(self, param):
        """
        Clears a param from the param map if it has been explicitly set.

        Parameters
        ----------
        param : str or :py:class:`Param`
            param name or param instance belonging to this instance
        """
        # Resolve first so that a param given by name is removed under its
        # Param key rather than under the raw string.
        param = self._resolveParam(param)
        self._paramMap.pop(param, None)

    def _setDefault(self, **kwargs):
        """
        Sets default params.

        ``None`` values and ``JavaObject`` instances are stored without
        being passed through the type converter.

        Raises
        ------
        TypeError
            if a type converter rejects a value
        """
        for param, value in kwargs.items():
            p = getattr(self, param)
            if value is not None and not isinstance(value, JavaObject):
                try:
                    value = p.typeConverter(value)
                except TypeError as e:
                    raise TypeError('Invalid default param value given for param "%s". %s'
                                    % (p.name, e))
            self._defaultParamMap[p] = value
        return self

    def _copyValues(self, to, extra=None):
        """
        Copies param values from this instance to another instance for
        params shared by them.

        Parameters
        ----------
        to : :py:class:`Params`
            the target instance
        extra : dict, optional
            extra params to be copied

        Returns
        -------
        :py:class:`Params`
            the target instance with param values copied

        Raises
        ------
        TypeError
            if ``extra`` is not a dict, or one of its keys is not a
            :py:class:`Param`
        """
        paramMap = self._paramMap.copy()
        if isinstance(extra, dict):
            for param, value in extra.items():
                if isinstance(param, Param):
                    paramMap[param] = value
                else:
                    raise TypeError("Expecting a valid instance of Param, but received: {}"
                                    .format(param))
        elif extra is not None:
            raise TypeError("Expecting a dict, but received an object of type {}."
                            .format(type(extra)))
        for param in self.params:
            # copy default params
            if param in self._defaultParamMap and to.hasParam(param.name):
                to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param]
            # copy explicitly set params
            if param in paramMap and to.hasParam(param.name):
                to._set(**{param.name: paramMap[param]})
        return to

    def _resetUid(self, newUid):
        """
        Changes the uid of this instance. This updates both
        the stored uid and the parent uid of params and param maps.
        This is used by persistence (loading).

        Parameters
        ----------
        newUid
            new uid to use, which is converted to unicode

        Returns
        -------
        :py:class:`Params`
            same instance, but with the uid and Param.parent values
            updated, including within param maps
        """
        newUid = str(newUid)
        self.uid = newUid
        newDefaultParamMap = dict()
        newParamMap = dict()
        for param in self.params:
            # Look values up under the old key before the param's parent
            # (and therefore its hash) is changed below.
            newParam = copy.copy(param)
            newParam.parent = newUid
            if param in self._defaultParamMap:
                newDefaultParamMap[newParam] = self._defaultParamMap[param]
            if param in self._paramMap:
                newParamMap[newParam] = self._paramMap[param]
            param.parent = newUid
        self._defaultParamMap = newDefaultParamMap
        self._paramMap = newParamMap
        return self