public class BoostingStrategy
extends Object
implements scala.Serializable, scala.Product
Configuration options for GradientBoostedTrees.
param: treeStrategy Parameters for the tree algorithm. Regression and binary classification are supported for boosting. The impurity setting is ignored.
param: loss Loss function minimized during gradient boosting.
param: numIterations Number of boosting iterations, i.e. the number of weak hypotheses used in the final model.
param: learningRate Learning rate for shrinking the contribution of each estimator. The learning rate should be in the interval (0, 1].
param: validationTol Condition that decides when to terminate iterations when runWithValidation is used. Termination is decided as follows:
If the current loss on the validation set is greater than 0.01, the difference in validation error is compared to the relative tolerance validationTol * (current loss on the validation set).
If the current loss on the validation set is less than or equal to 0.01, the difference in validation error is compared to the absolute tolerance validationTol * 0.01.
Ignored when org.apache.spark.mllib.tree.GradientBoostedTrees.run() is used.
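The validationTol termination rule can be sketched in plain Java as follows. This is a hypothetical illustration, not Spark's implementation: the names ValidationStopRule, tolerance and shouldStop are made up, and the sketch assumes that "the difference in validation error" means the improvement previousLoss - currentLoss and that iteration stops once that improvement falls below the tolerance.

```java
/**
 * Hypothetical sketch of the validationTol termination rule (not Spark's source).
 * Assumption: "difference in validation error" = previousLoss - currentLoss, and
 * iteration stops once that improvement is smaller than the tolerance.
 */
final class ValidationStopRule {
    // Loss level at which the rule switches from relative to absolute tolerance.
    static final double THRESHOLD = 0.01;

    /** Relative tolerance above the threshold, absolute tolerance at or below it. */
    static double tolerance(double currentLoss, double validationTol) {
        return currentLoss > THRESHOLD
                ? validationTol * currentLoss  // relative tolerance
                : validationTol * THRESHOLD;   // absolute tolerance
    }

    /** Returns true if boosting should stop after this iteration. */
    static boolean shouldStop(double previousLoss, double currentLoss, double validationTol) {
        double improvement = previousLoss - currentLoss; // negative if validation loss got worse
        return improvement < tolerance(currentLoss, validationTol);
    }

    public static void main(String[] args) {
        double validationTol = 0.001;
        // Loss 0.5 > 0.01: tolerance is 0.001 * 0.5 = 0.0005.
        System.out.println(shouldStop(0.6, 0.5, validationTol));     // false: improved by ~0.1
        System.out.println(shouldStop(0.5001, 0.5, validationTol));  // true: improved by only ~0.0001
        // Loss 0.005 <= 0.01: tolerance is 0.001 * 0.01 = 0.00001.
        System.out.println(shouldStop(0.006, 0.005, validationTol)); // false: improved by ~0.001
    }
}
```

Both branches of tolerance collapse to validationTol * max(currentLoss, 0.01), which is a compact way to remember the rule: the tolerance scales with the loss until the loss drops to 0.01, and stays fixed below that.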
Constructor and Description |
---|
BoostingStrategy(Strategy treeStrategy, Loss loss, int numIterations, double learningRate, double validationTol) |
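To make the roles of the loss, numIterations and learningRate arguments concrete, here is a toy gradient-boosting loop in plain Java. It is a sketch under strong simplifying assumptions, not how GradientBoostedTrees works internally: the loss is squared error, the hypothetical ShrinkageSketch.fit method uses a constant (the mean pseudo-residual) as each weak hypothesis instead of a decision tree, and the initial prediction is 0.

```java
/**
 * Toy gradient boosting for squared-error loss, for illustration only.
 * Assumption: each weak hypothesis is a constant (the mean residual), not a decision tree.
 */
final class ShrinkageSketch {

    /** Returns the ensemble's (constant) prediction after numIterations rounds. */
    static double fit(double[] labels, int numIterations, double learningRate) {
        if (labels.length == 0) {
            throw new IllegalArgumentException("labels must not be empty");
        }
        if (!(learningRate > 0.0 && learningRate <= 1.0)) {
            throw new IllegalArgumentException("learningRate must be in (0, 1]: " + learningRate);
        }
        double prediction = 0.0; // F_0 = 0
        for (int m = 0; m < numIterations; m++) {
            // Pseudo-residuals: the negative gradient of (1/2)(y - F)^2 with respect to F is (y - F).
            double meanResidual = 0.0;
            for (double y : labels) {
                meanResidual += y - prediction;
            }
            meanResidual /= labels.length;
            // Weak hypothesis h_m = meanResidual; its contribution is shrunk by learningRate.
            prediction += learningRate * meanResidual;
        }
        return prediction;
    }

    public static void main(String[] args) {
        double[] labels = {2.0, 4.0, 6.0};
        System.out.println(fit(labels, 1, 1.0)); // 4.0: a full step reaches the mean at once
        System.out.println(fit(labels, 3, 0.5)); // 3.5: each step closes half the remaining gap
    }
}
```

With learningRate = 1.0 a single round reaches the label mean; with learningRate = 0.5 each round closes half of the remaining gap, so after numIterations rounds the toy prediction is mean * (1 - (1 - learningRate)^numIterations). In this toy, a smaller learning rate therefore needs more iterations to reach the same fit.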
Modifier and Type | Method and Description |
---|---|
abstract static boolean | canEqual(Object that) |
static BoostingStrategy | defaultParams(scala.Enumeration.Value algo) - Returns default configuration for the boosting algorithm |
static BoostingStrategy | defaultParams(String algo) - Returns default configuration for the boosting algorithm |
abstract static boolean | equals(Object that) |
double | getLearningRate() |
Loss | getLoss() |
int | getNumIterations() |
Strategy | getTreeStrategy() |
double | getValidationTol() |
double | learningRate() |
Loss | loss() |
int | numIterations() |
abstract static int | productArity() |
abstract static Object | productElement(int n) |
static scala.collection.Iterator<Object> | productIterator() |
static String | productPrefix() |
void | setLearningRate(double x$1) |
void | setLoss(Loss x$1) |
void | setNumIterations(int x$1) |
void | setTreeStrategy(Strategy x$1) |
void | setValidationTol(double x$1) |
Strategy | treeStrategy() |
double | validationTol() |
Methods inherited from class java.lang.Object: equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
public static BoostingStrategy defaultParams(String algo)
Parameters:
algo - Learning goal. Supported: "Classification" or "Regression".
public static BoostingStrategy defaultParams(scala.Enumeration.Value algo)
Parameters:
algo - Learning goal. Supported: org.apache.spark.mllib.tree.configuration.Algo.Classification, org.apache.spark.mllib.tree.configuration.Algo.Regression.
public abstract static boolean canEqual(Object that)
public abstract static boolean equals(Object that)
public abstract static Object productElement(int n)
public abstract static int productArity()
public static scala.collection.Iterator<Object> productIterator()
public static String productPrefix()
public Strategy treeStrategy()
public void setTreeStrategy(Strategy x$1)
public Loss loss()
public void setLoss(Loss x$1)
public int numIterations()
public void setNumIterations(int x$1)
public double learningRate()
public void setLearningRate(double x$1)
public double validationTol()
public void setValidationTol(double x$1)
public Strategy getTreeStrategy()
public Loss getLoss()
public int getNumIterations()
public double getLearningRate()
public double getValidationTol()