public abstract class Trainer extends java.lang.Object implements java.lang.Cloneable, Storable
| Modifier and Type | Field and Description |
|---|---|
| protected int | currentEpoch: The current epoch during training. |
| protected int | epochs: The number of epochs to be trained. |
| protected double | learnRate: The learn rate of this trainer. |
| protected NeuralNet | net: The network to be trained. |
| protected boolean | shuffle: Indicates random shuffling of the training data. |
| protected boolean | stepMode: Indicates step-wise training. |
| protected PatternSet | testData: The test data set. |
| protected PatternSet | trainingData: The training data set. |
| protected TrainingSignalGenerator | trainingSignalGenerator: The training signal generator. |
| Constructor and Description |
|---|
| Trainer(): Constructs a trainer with a squared error function (learn rate = 0.2, epochs = 1000). |
| Modifier and Type | Method and Description |
|---|---|
| protected abstract double | calcAdaptation(BrainPart part): Calculates the adaptation value based on the gradient. |
| Trainer | clone(): Returns a deep clone of this object. |
| protected void | endBatch(): Updates weights and biases using the batch gradient information. |
| protected abstract void | endTrain(): Does some housekeeping after all epochs have been trained. |
| void | fromXML(org.jdom2.Element element): Reads attributes from XML. |
| int | getBatchSize(): Returns the (mini)batch size. |
| int | getEpochs(): Returns the number of epochs to be trained. |
| double | getLearnRate(): Returns the learning rate of the trainer. |
| NeuralNet | getNetwork(): Returns the network associated with this trainer. |
| int | getRank(java.util.List<java.lang.Double> input, int index): Returns the rank of the given output neuron's activation for the given input pattern. |
| PatternSet | getTestData(): Returns the test data set. |
| PatternSet | getTrainingData(): Returns the training data set. |
| TrainingSignalGenerator | getTrainingSignalGenerator(): Returns the current training signal generator for this trainer. |
| Neuron | getWinningNeuron(java.util.List<java.lang.Double> input): Returns the output neuron with the largest activation upon presentation of 'input'. |
| protected abstract void | reset(): Resets the trainer to initial values in order to start a new training procedure. |
| void | setBatchSize(int batchSize): Sets the (mini)batch size. |
| void | setEpochs(int epochs): Sets the number of epochs. |
| void | setLearnRate(double learnRate): Sets the learn rate for this trainer. |
| void | setNetwork(NeuralNet net): Sets the network for this trainer. |
| void | setShuffle(boolean shuffle): Sets the shuffle flag. |
| void | setStepMode(boolean stepMode): Sets the step mode. |
| void | setTestData(PatternSet testData): Sets the test data set. |
| void | setTrainingData(PatternSet trainingData): Sets the training data set. |
| void | setTrainingSignalGenerator(TrainingSignalGenerator trainingSignalGenerator): Sets the training signal generator for this trainer. |
| double | test(): Returns the error on the test data. |
| double | test(java.util.List<java.lang.Double> input, java.util.List<java.lang.Double> target): Tests a given input pattern and returns the error. |
| java.lang.String | toString() |
| org.jdom2.Element | toXML(org.jdom2.Element element): Writes attributes to XML. |
| void | train(): Performs training for the configured number of epochs; the order of calls is listed in the method detail. |
| protected abstract void | train(java.util.List<java.lang.Double> inputPattern, java.util.List<java.lang.Double> targetPattern): Trains a single pattern once. |
| void | trainPattern(java.util.List<java.lang.Double> inputPattern, java.util.List<java.lang.Double> targetPattern): Trains a single pattern once and updates the weights. |
| protected void | updateLinks(): Adapts neuron link values after a (mini)batch. |
| protected void | updateNeurons(): Adapts neuron bias values after a (mini)batch. |
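The summary gives getWinningNeuron and getRank only one line each. The sketch below shows one plausible way to compute both from the output activations obtained after an input pattern has been propagated through the network. It assumes that the winner is the first output neuron with the largest activation and that the rank counts how many output neurons are strictly more active (so the winner has rank 0); the documentation does not state the library's actual rank convention or tie handling. WinnerRankSketch and its methods are hypothetical and work on a plain List<Double> instead of Neuron objects.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical helpers for winner and rank lookups over output activations. */
final class WinnerRankSketch {

    private WinnerRankSketch() {
    }

    /** Index of the output neuron with the largest activation (the first one on ties). */
    static int winningIndex(List<Double> activations) {
        int best = 0;
        for (int i = 1; i < activations.size(); i++) {
            if (activations.get(i) > activations.get(best)) {
                best = i;
            }
        }
        return best;
    }

    /** Assumed rank: number of output neurons strictly more active than the one at 'index'. */
    static int rank(List<Double> activations, int index) {
        double own = activations.get(index);
        int count = 0;
        for (double a : activations) {
            if (a > own) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        List<Double> out = Arrays.asList(0.1, 0.7, 0.3);
        System.out.println(winningIndex(out));   // prints 1
        System.out.println(rank(out, 0));        // prints 2
    }
}
```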
protected NeuralNet net
protected PatternSet trainingData
protected PatternSet testData
protected TrainingSignalGenerator trainingSignalGenerator
protected boolean stepMode
protected boolean shuffle
protected double learnRate
protected int epochs
protected int currentEpoch
public Trainer()
public Trainer clone()
Overrides:
clone in class java.lang.Object
public PatternSet getTrainingData()
public void setTrainingData(PatternSet trainingData)
Parameters:
trainingData - a pattern set (its size is the default batch size)
public PatternSet getTestData()
public void setTestData(PatternSet testData)
Parameters:
testData - a pattern set
public int getBatchSize()
public void setBatchSize(int batchSize)
Parameters:
batchSize - the size of a (mini)batch
public void setStepMode(boolean stepMode)
Parameters:
stepMode - the mode flag
public void setShuffle(boolean shuffle)
Parameters:
shuffle - if true, shuffle training data
public void train()
- calls reset(), unless in step mode
- runs through the configured number of epochs, presenting each pattern by calling train(java.util.List, java.util.List)
- calls endBatch() after each epoch, then shuffles the training data if shuffling is enabled (see setShuffle(boolean))
- calls endTrain() after all epochs are done
This also works for streamed training patterns.
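The order of calls listed for train() follows the template-method pattern: the base class owns the epoch loop and subclasses supply the abstract hooks. The following is a minimal sketch of that order, not the library's code. TrainerSketch and CountingTrainer are hypothetical; patterns are held in List<List<Double>> instead of a PatternSet, endBatch() is an empty placeholder, and the batch size and streamed patterns are ignored.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical, simplified stand-in for Trainer that reproduces the call order of train(). */
abstract class TrainerSketch {
    protected final List<List<Double>> inputs = new ArrayList<>();
    protected final List<List<Double>> targets = new ArrayList<>();
    protected int epochs = 1000;      // documented default
    protected int currentEpoch;
    protected boolean stepMode;
    protected boolean shuffle;

    protected abstract void reset();

    protected abstract void train(List<Double> inputPattern, List<Double> targetPattern);

    protected abstract void endTrain();

    /** Placeholder: the real endBatch() updates weights and biases. */
    protected void endBatch() {
    }

    public void train() {
        if (!stepMode) {
            reset();                                   // fresh start unless in step mode
        }
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < inputs.size(); i++) {
            order.add(i);
        }
        for (currentEpoch = 0; currentEpoch < epochs; currentEpoch++) {
            for (int i : order) {
                train(inputs.get(i), targets.get(i));  // present each pattern once
            }
            endBatch();                                // one update per epoch
            if (shuffle) {
                Collections.shuffle(order);            // new order for the next epoch
            }
        }
        endTrain();
    }
}

/** Hypothetical subclass that only counts how often each hook runs. */
class CountingTrainer extends TrainerSketch {
    int resets, patterns, batches, ends;

    @Override protected void reset() { resets++; }

    @Override protected void train(List<Double> in, List<Double> target) { patterns++; }

    @Override protected void endBatch() { batches++; }

    @Override protected void endTrain() { ends++; }
}
```

With three epochs and four patterns, CountingTrainer records one reset(), twelve pattern presentations, three endBatch() calls and one endTrain() call; a second train() call with step mode enabled skips reset().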
public void trainPattern(java.util.List<java.lang.Double> inputPattern,
java.util.List<java.lang.Double> targetPattern)
Parameters:
inputPattern - an input pattern
targetPattern - a target pattern (may be null)
public double test(java.util.List<java.lang.Double> input,
java.util.List<java.lang.Double> target)
Parameters:
input - input pattern
target - target pattern
public double test()
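The documentation says the trainer uses a squared error function but does not spell out how test(input, target) and test() compute and aggregate it. As a sketch under two assumptions (the pattern error is the plain sum of squared output differences, without a factor of 1/2, and test() averages that error over the test set), the computation could look like this. SquaredErrorSketch is hypothetical, and the network is modelled as a Function from input to output.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/** Hypothetical sketch of the errors returned by test(input, target) and test(). */
final class SquaredErrorSketch {

    private SquaredErrorSketch() {
    }

    /** Assumed pattern error: sum of squared differences between target and output. */
    static double patternError(List<Double> output, List<Double> target) {
        if (output.size() != target.size()) {
            throw new IllegalArgumentException("output and target sizes differ");
        }
        double sum = 0.0;
        for (int k = 0; k < output.size(); k++) {
            double diff = target.get(k) - output.get(k);
            sum += diff * diff;
        }
        return sum;
    }

    /** Assumed data-set error: mean pattern error; 'net' maps an input to the output. */
    static double dataSetError(Function<List<Double>, List<Double>> net,
                               List<List<Double>> inputs, List<List<Double>> targets) {
        if (inputs.isEmpty()) {
            return 0.0;
        }
        double total = 0.0;
        for (int p = 0; p < inputs.size(); p++) {
            total += patternError(net.apply(inputs.get(p)), targets.get(p));
        }
        return total / inputs.size();
    }

    public static void main(String[] args) {
        Function<List<Double>, List<Double>> identity = in -> in;   // output equals input
        List<List<Double>> inputs = Arrays.asList(Arrays.asList(1.0, 0.0), Arrays.asList(0.0, 1.0));
        List<List<Double>> targets = Arrays.asList(Arrays.asList(1.0, 1.0), Arrays.asList(0.0, 1.0));
        System.out.println(dataSetError(identity, inputs, targets)); // prints 0.5
    }
}
```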
protected void endBatch()
protected void updateNeurons()
protected void updateLinks()
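endBatch(), updateNeurons() and updateLinks() together apply the gradient collected during a (mini)batch. A minimal sketch, assuming that every link weight and neuron bias keeps an accumulated gradient and that calcAdaptation returns a plain gradient-descent step (minus the learn rate times the gradient), is shown below. Part stands in for BrainPart and BatchUpdateSketch for the trainer; both are hypothetical, and the real sign and scaling conventions (for example, averaging over the batch) are not documented.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for BrainPart: a weight or bias plus its accumulated gradient. */
class Part {
    double value;     // link weight or neuron bias
    double gradient;  // gradient accumulated over the current (mini)batch

    Part(double value) {
        this.value = value;
    }
}

/** Hypothetical sketch of the end-of-batch update performed by endBatch(). */
class BatchUpdateSketch {
    final List<Part> links = new ArrayList<>();
    final List<Part> neurons = new ArrayList<>();
    double learnRate = 0.2;   // documented default learn rate

    /** Assumption: plain gradient descent, one step against the accumulated gradient. */
    protected double calcAdaptation(Part part) {
        return -learnRate * part.gradient;
    }

    protected void endBatch() {
        updateNeurons();
        updateLinks();
    }

    /** Adapts the bias values. */
    protected void updateNeurons() {
        adapt(neurons);
    }

    /** Adapts the link weights. */
    protected void updateLinks() {
        adapt(links);
    }

    private void adapt(List<Part> parts) {
        for (Part p : parts) {
            p.value += calcAdaptation(p);
            p.gradient = 0.0;   // the next batch starts from a zero gradient
        }
    }
}
```

With the default learn rate of 0.2, a weight of 1.0 with accumulated gradient 0.5 moves to 0.9, and its gradient is reset to zero for the next batch.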
public NeuralNet getNetwork()
public void setNetwork(NeuralNet net)
Parameters:
net - a neural net
public void setLearnRate(double learnRate)
Parameters:
learnRate - the learning rate
See Also:
learnRate
public double getLearnRate()
See Also:
learnRate
public int getEpochs()
See Also:
epochs
public void setEpochs(int epochs)
Parameters:
epochs - the number of epochs
See Also:
epochs
public TrainingSignalGenerator getTrainingSignalGenerator()
public void setTrainingSignalGenerator(TrainingSignalGenerator trainingSignalGenerator)
public Neuron getWinningNeuron(java.util.List<java.lang.Double> input)
Parameters:
input - an input pattern
public int getRank(java.util.List<java.lang.Double> input,
int index)
Parameters:
input - an input pattern
index - the index of the output neuron
public java.lang.String toString()
Overrides:
toString in class java.lang.Object
public void fromXML(org.jdom2.Element element)
public org.jdom2.Element toXML(org.jdom2.Element element)
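toXML and fromXML store the trainer settings as XML attributes. The sketch below illustrates such a round trip with the JDK's built-in org.w3c.dom API in place of JDOM2, so that it runs without extra dependencies. TrainerSettings and the attribute names learnRate, epochs and shuffle are hypothetical; the attribute names the library actually writes are not documented here.

```java
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

/** Hypothetical settings holder showing an attribute-based toXML/fromXML round trip. */
class TrainerSettings {
    double learnRate = 0.2;   // documented default
    int epochs = 1000;        // documented default
    boolean shuffle;

    /** Writes the settings as attributes of 'element' and returns it. */
    Element toXML(Element element) {
        element.setAttribute("learnRate", Double.toString(learnRate));
        element.setAttribute("epochs", Integer.toString(epochs));
        element.setAttribute("shuffle", Boolean.toString(shuffle));
        return element;
    }

    /** Reads the settings back; a missing attribute keeps the current value. */
    void fromXML(Element element) {
        if (element.hasAttribute("learnRate")) {
            learnRate = Double.parseDouble(element.getAttribute("learnRate"));
        }
        if (element.hasAttribute("epochs")) {
            epochs = Integer.parseInt(element.getAttribute("epochs"));
        }
        if (element.hasAttribute("shuffle")) {
            shuffle = Boolean.parseBoolean(element.getAttribute("shuffle"));
        }
    }

    /** Creates an empty element in a fresh DOM document. */
    static Element newElement(String name) {
        try {
            Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder().newDocument();
            return doc.createElement(name);
        } catch (ParserConfigurationException e) {
            throw new IllegalStateException("no XML parser available", e);
        }
    }
}
```

Keeping the current value when an attribute is missing lets a partial element load with defaults; whether the library's fromXML behaves this way is an assumption.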
protected abstract void reset()
protected abstract void endTrain()
protected abstract void train(java.util.List<java.lang.Double> inputPattern,
java.util.List<java.lang.Double> targetPattern)
Parameters:
inputPattern - an input pattern
targetPattern - a target pattern (may be null)
protected abstract double calcAdaptation(BrainPart part)
Parameters:
part - the brain part to be adapted
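A concrete trainer has to provide the abstract hooks reset(), endTrain(), train(inputPattern, targetPattern) and calcAdaptation(part). As an illustration only, the hypothetical DeltaRuleSketch below trains a single linear neuron y = w * x + b: its train(...) performs the forward pass and accumulates the gradient of the squared error (t - y)^2 without touching the parameters, and the parameters change once per epoch through a calcAdaptation-style step, matching the default where the whole training set forms one batch. Unlike the real method, calcAdaptation here takes a gradient and a batch size instead of a BrainPart; averaging over the batch is an assumption, and endTrain() is omitted because there is nothing to clean up.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical delta-rule trainer for one linear neuron y = w * x + b, mirroring the split
 * between train(input, target) (gradient accumulation) and the per-batch adaptation step.
 */
class DeltaRuleSketch {
    double weight;
    double bias;
    double weightGrad;            // gradients accumulated over the current batch
    double biasGrad;
    double learnRate = 0.2;       // documented default learn rate
    int epochs = 1000;            // documented default number of epochs

    /** Counterpart of reset(): restore initial values before a new training run. */
    protected void reset() {
        weight = 0.0;
        bias = 0.0;
        weightGrad = 0.0;
        biasGrad = 0.0;
    }

    /** Counterpart of train(inputPattern, targetPattern): forward pass and gradient only. */
    protected void train(List<Double> input, List<Double> target) {
        double x = input.get(0);
        double y = weight * x + bias;
        double dEdy = 2.0 * (y - target.get(0));   // derivative of E = (t - y)^2
        weightGrad += dEdy * x;
        biasGrad += dEdy;
    }

    /** Counterpart of calcAdaptation(part); averaging over the batch is an assumption. */
    protected double calcAdaptation(double gradient, int batchSize) {
        return -learnRate * gradient / batchSize;
    }

    /** Full-batch training: the whole data set is one batch, so one update per epoch. */
    void fit(List<Double> xs, List<Double> ts) {
        reset();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i = 0; i < xs.size(); i++) {
                train(Arrays.asList(xs.get(i)), Arrays.asList(ts.get(i)));
            }
            weight += calcAdaptation(weightGrad, xs.size());
            bias += calcAdaptation(biasGrad, xs.size());
            weightGrad = 0.0;
            biasGrad = 0.0;
        }
    }
}
```

Fitting the four points (0, 1), (1, 3), (2, 5), (3, 7) with the documented defaults (learn rate 0.2, 1000 epochs) drives w towards 2 and b towards 1.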