#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" Returns the JVM view associated with SparkContext. Must be called after SparkContext is initialized. """ else: raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
""" Object with a unique ID. """
#: A unique id for the object.
def _randomUID(cls): """ Generate a unique string id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. """
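# Quick demonstration of the uid scheme above; the 12-char hex suffix varies
# per run, so this only checks the documented shape, not a specific value.
_example_uid = Identifiable._randomUID()
assert _example_uid.startswith("Identifiable_")
assert len(_example_uid) == len("Identifiable_") + 12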
""" Base class for MLWriter and MLReader. Stores information about the SparkContext and SparkSession.
.. versionadded:: 2.3.0 """
""" Sets the Spark Session to use for saving/loading. """ self._sparkSession = sparkSession return self
def sparkSession(self): """ Returns the user-specified Spark Session or the default. """
def sc(self): """ Returns the underlying `SparkContext`. """
""" Utility class that can save ML instances.
.. versionadded:: 2.0.0 """
"""Save the ML instance to the input path."""
""" save() handles overwriting and then calls this method. Subclasses should override this method to implement the actual saving of the instance. """ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
"""Overwrites if the output path already exists."""
""" Adds an option to the underlying MLWriter. See the documentation for the specific model's writer for possible options. The option name (key) is case-insensitive. """ self.optionMap[key.lower()] = str(value) return self
""" Utility class that can save ML instances in different formats.
.. versionadded:: 2.4.0 """
""" Specifies the format of ML export ("pmml", "internal", or the fully qualified class name for export). """ self.source = source return self
""" (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types """
"""Save the ML instance to the input path.""" raise TypeError("path should be a string, got type %s" % type(path))
"""Overwrites if the output path already exists.""" self._jwrite.overwrite() return self
"""Sets the Spark Session to use for saving.""" self._jwrite.session(sparkSession._jsparkSession) return self
""" (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types """
""" Specifies the format of ML export ("pmml", "internal", or the fully qualified class name for export). """
""" Mixin for ML instances that provide :py:class:`MLWriter`.
.. versionadded:: 2.0.0 """
"""Returns an MLWriter instance for this ML instance.""" raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
"""Save this ML instance to the given path, a shortcut of 'write().save(path)'."""
""" (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. """
"""Returns an MLWriter instance for this ML instance."""
""" (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`. """
"""Returns an GeneralMLWriter instance for this ML instance."""
""" Utility class that can load ML instances.
.. versionadded:: 2.0.0 """
"""Load the ML instance from the input path.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
""" (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types """
"""Load the ML instance from the input path.""" raise TypeError("path should be a string, got type %s" % type(path)) raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r" % self._clazz)
"""Sets the Spark Session to use for loading.""" self._jread.session(sparkSession._jsparkSession) return self
def _java_loader_class(cls, clazz): """ Returns the full class name of the Java ML instance. The default implementation replaces "pyspark" by "org.apache.spark" in the Python full class name. """ # Remove the last package name "pipeline" for Pipeline and PipelineModel.
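    # Illustrative mapping performed by _java_loader_class above, e.g.:
    #   pyspark.ml.feature.Binarizer  ->  org.apache.spark.ml.feature.Binarizer
    #   pyspark.ml.pipeline.Pipeline  ->  org.apache.spark.ml.Pipeline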
    @classmethod
    def _load_java_obj(cls, clazz):
        """Load the peer Java object of the ML instance."""
        java_class = cls._java_loader_class(clazz)
        java_obj = _jvm()
        for name in java_class.split("."):
            java_obj = getattr(java_obj, name)
        return java_obj
""" Mixin for instances that provide :py:class:`MLReader`.
.. versionadded:: 2.0.0 """
def read(cls): """Returns an MLReader instance for this class.""" raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)
def load(cls, path): """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
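# Hedged example of the load() shortcut on a concrete readable class; assumes
# a LogisticRegressionModel was previously saved at `path`.
def _example_load(path):
    from pyspark.ml.classification import LogisticRegressionModel
    return LogisticRegressionModel.load(path)  # same as .read().load(path)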
""" (Private) Mixin for instances that provide JavaMLReader. """
def read(cls): """Returns an MLReader instance for this class."""
""" Helper trait for making simple :py:class:`Params` types writable. If a :py:class:`Params` class stores all data as :py:class:`Param` values, then extending this trait will provide a default implementation of writing saved instances of the class. This only handles simple :py:class:`Param` types; e.g., it will not handle :py:class:`pyspark.sql.DataFrame`. See :py:class:`DefaultParamsReadable`, the counterpart to this class.
.. versionadded:: 2.3.0 """
"""Returns a DefaultParamsWriter instance for this class."""
else: raise TypeError("Cannot use DefautParamsWritable with type %s because it does not " + " extend Params.", type(self))
""" Specialization of :py:class:`MLWriter` for :py:class:`Params` types
Class for writing Estimators and Transformers whose parameters are JSON-serializable.
.. versionadded:: 2.3.0 """
def extractJsonParams(instance, skipParams): if param.name not in skipParams}
""" Saves metadata + Params to: path + "/metadata"
- class - timestamp - sparkVersion - uid - paramMap - defaultParamMap (since 2.4.0) - (optionally, extra metadata)
Parameters ---------- extraMetadata : dict, optional Extra metadata to be saved at same level as uid, paramMap, etc. paramMap : dict, optional If given, this is saved in the "paramMap" field. """ sc, extraMetadata, paramMap)
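    # Hedged sketch of how a writer implementation might delegate here from its
    # saveImpl (assumed usage; a real writer may persist extra data files too):
    #
    #     def saveImpl(self, path):
    #         DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)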
""" Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts the JSON to save. This is useful for ensemble models which need to save metadata for many sub-models.
Notes ----- See :py:meth:`DefaultParamsWriter.saveMetadata` for details on what this includes. """
# User-supplied param values else:
# Default param values
"sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams, "defaultParamMap": jsonDefaultParams}
""" Helper trait for making simple :py:class:`Params` types readable. If a :py:class:`Params` class stores all data as :py:class:`Param` values, then extending this trait will provide a default implementation of reading saved instances of the class. This only handles simple :py:class:`Param` types; e.g., it will not handle :py:class:`pyspark.sql.DataFrame`. See :py:class:`DefaultParamsWritable`, the counterpart to this class.
.. versionadded:: 2.3.0 """
def read(cls): """Returns a DefaultParamsReader instance for this class."""
""" Specialization of :py:class:`MLReader` for :py:class:`Params` types
Default :py:class:`MLReader` implementation for transformers and estimators that contain basic (json-serializable) params and no data. This will not handle more complex params or types with data (e.g., models with coefficients).
.. versionadded:: 2.3.0 """
def __get_class(clazz): """ Loads Python class from its name. """
""" Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`
Parameters ---------- path : str sc : :py:class:`pyspark.SparkContext` expectedClassName : str, optional If non empty, this is checked against the loaded metadata. """
""" Parse metadata JSON string produced by :py:meth`DefaultParamsWriter._get_metadata_to_save`. This is a helper function for :py:meth:`DefaultParamsReader.loadMetadata`.
Parameters ---------- metadataStr : str JSON string of metadata expectedClassName : str, optional If non empty, this is checked against the loaded metadata. """ assert className == expectedClassName, "Error loading metadata: Expected " + \ "class name {} but found class name {}".format(expectedClassName, className)
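    # Hedged round-trip illustration for the parser above, on made-up metadata:
    #
    #     s = json.dumps({"class": "pyspark.ml.feature.Binarizer", "timestamp": 0,
    #                     "sparkVersion": "3.1.0", "uid": "x",
    #                     "paramMap": {}, "defaultParamMap": {}})
    #     DefaultParamsReader._parseMetaData(s)["uid"]   # -> "x"
    #     # A mismatched expectedClassName would trip the assert above.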
""" Extract Params from metadata, and set them in the instance. """ # Set user-supplied param values
# Set default param values
# For metadata file prior to Spark 2.4, there is no default section. "`defaultParamMap` section not found"
    @staticmethod
    def isPythonParamsInstance(metadata):
        return metadata['class'].startswith('pyspark.ml.')

    @staticmethod
    def loadParamsInstance(path, sc):
        """
        Load a :py:class:`Params` instance from the given path, and return it.
        This assumes the instance inherits from :py:class:`MLReadable`.
        """
        metadata = DefaultParamsReader.loadMetadata(path, sc)
        if DefaultParamsReader.isPythonParamsInstance(metadata):
            pythonClassName = metadata['class']
        else:
            pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark")
        py_type = DefaultParamsReader.__get_class(pythonClassName)
        instance = py_type.load(path)
        return instance
""" Base class for models that provides Training summary.
.. versionadded:: 3.0.0 """
def hasSummary(self): """ Indicates whether a training summary exists for this model instance. """
def summary(self): """ Gets summary of the model trained on the training set. An exception is thrown if no summary exists. """
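# Hedged usage sketch: guard summary access, since `summary` raises when no
# training summary is available (e.g. on a model loaded back from disk).
def _example_get_summary(model):
    return model.summary if model.hasSummary else None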
class MetaAlgorithmReadWrite:

    @staticmethod
    def isMetaEstimator(pyInstance):
        from pyspark.ml import Estimator, Pipeline
        from pyspark.ml.tuning import _ValidatorParams
        from pyspark.ml.classification import OneVsRest
        return isinstance(pyInstance, Pipeline) \
            or isinstance(pyInstance, OneVsRest) \
            or (isinstance(pyInstance, Estimator) and isinstance(pyInstance, _ValidatorParams))
    @staticmethod
    def getAllNestedStages(pyInstance):
        from pyspark.ml import Pipeline, PipelineModel
        from pyspark.ml.tuning import _ValidatorParams
        from pyspark.ml.classification import OneVsRest, OneVsRestModel

        # TODO: We need to handle `RFormulaModel.pipelineModel` here after Pyspark
        # RFormulaModel support pipelineModel property.
        if isinstance(pyInstance, Pipeline):
            pySubStages = pyInstance.getStages()
        elif isinstance(pyInstance, PipelineModel):
            pySubStages = pyInstance.stages
        elif isinstance(pyInstance, _ValidatorParams):
            raise ValueError('PySpark does not support nested validator.')
        elif isinstance(pyInstance, OneVsRest):
            pySubStages = [pyInstance.getClassifier()]
        elif isinstance(pyInstance, OneVsRestModel):
            pySubStages = [pyInstance.getClassifier()] + pyInstance.models
        else:
            pySubStages = []

        nestedStages = []
        for pySubStage in pySubStages:
            nestedStages.extend(MetaAlgorithmReadWrite.getAllNestedStages(pySubStage))

        return [pyInstance] + nestedStages
    @staticmethod
    def getUidMap(instance):
        nestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance)
        uidMap = {stage.uid: stage for stage in nestedStages}
        if len(nestedStages) != len(uidMap):
            raise RuntimeError(f'{instance.__class__.__module__}.{instance.__class__.__name__}'
                               f'.load found a compound estimator with stages with duplicate '
                               f'UIDs. List of UIDs: {list(uidMap.keys())}.')
        return uidMap
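# Sketch of the invariant getUidMap enforces: every nested stage must carry a
# unique uid, so reusing one stage object twice in a Pipeline breaks persistence.
def _example_check_unique_uids(stages):
    uids = [stage.uid for stage in stages]
    if len(uids) != len(set(uids)):
        raise RuntimeError("duplicate stage UIDs: %s" % uids)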