public class SAETrainer extends Trainer

| Modifier and Type | Field and Description |
|---|---|
| protected NeuronList | hiddenLayer: The hidden layer. |
| protected NeuralNet | partNet: The net that is currently trained, i.e. the layer extracted from the whole net. |
| protected Trainer | partTrainer: The trainer used internally to train each layer. |
| protected boolean | print: Flag indicating whether training information is printed to the console. |
| protected int | steps: The number of steps used for training. |
| protected boolean | wasInputLayer: Flag indicating whether the current input layer is the input layer of the whole net. |
| protected boolean | wasOutputLayer: Flag indicating whether the current hidden layer is the output layer of the whole net. |


Fields inherited from class Trainer: currentEpoch, epochs, learnRate, net, shuffle, stepMode, testData, trainingData, trainingSignalGenerator

| Constructor and Description |
|---|
| SAETrainer(): Creates a new SAE trainer using a SquareError. |
| SAETrainer(Trainer trainer, int steps): Creates a new SAE trainer using a SquareError, the given trainer, and the given number of steps. |

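
Both constructors refer to a SquareError. As a hedged illustration only, the sketch below assumes the common definition E = 1/2 * sum_i (t_i - y_i)^2 with gradient dE/dy_i = y_i - t_i; the library's SquareError class may scale the error differently. SquareErrorSketch and its methods are hypothetical names, not part of this API. For an autoencoder layer, the target t is the input pattern itself.

```java
import java.util.Arrays;

// Hypothetical sketch of a squared-error loss; the library's SquareError may differ
// (assumed here: E = 1/2 * sum_i (t_i - y_i)^2).
public final class SquareErrorSketch {

    /** E = 1/2 * sum_i (t_i - y_i)^2 over all output units. */
    static double error(double[] output, double[] target) {
        if (output.length != target.length) {
            throw new IllegalArgumentException("output and target lengths differ");
        }
        double sum = 0.0;
        for (int i = 0; i < output.length; i++) {
            double d = target[i] - output[i];
            sum += d * d;
        }
        return 0.5 * sum;
    }

    /** dE/dy_i = y_i - t_i, the signal that backpropagation starts from. */
    static double[] gradient(double[] output, double[] target) {
        double[] g = new double[output.length];
        for (int i = 0; i < output.length; i++) {
            g[i] = output[i] - target[i];
        }
        return g;
    }

    public static void main(String[] args) {
        double[] y = {0.5, 1.0};
        double[] t = {1.0, 1.0}; // for an autoencoder, t is the input pattern itself
        System.out.println(error(y, t));                     // prints 0.125
        System.out.println(Arrays.toString(gradient(y, t))); // prints [-0.5, 0.0]
    }
}
```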

| Modifier and Type | Method and Description |
|---|---|
| protected double | calcAdaptation(BrainPart part): Not used. |
| protected void | endBatch(): Updates weights and biases using the batch gradient information. |
| protected void | endTrain(): Does some housekeeping after all epochs have been trained. |
| protected void | extractHiddenLayer(NeuralNet partNet, int layer): Extracts the hidden layer (neurons) to set up the hidden layer of the part net. The hidden layer is inserted into the part net by REFERENCE. |
| protected void | extractInputOutputLayer(NeuralNet partNet, int layer): Extracts the input layer (neurons and links) to set up the input and output layers of the part net. The input layer is inserted into the part net by REFERENCE; a COPY of the input layer is also inserted into the part net (as its output layer). |
| protected java.util.List<java.lang.Double> | getHiddenLayerOutput(java.util.List<java.lang.Double> outputs) |
| protected NeuronList | getNeuronsOfLayer(int layer): Extracts all neurons that belong to the same layer. |
| protected PatternSet | getNextSet(NeuralNet partNet, PatternSet patterns): Generates the next pattern set from the current pattern set by extracting the hidden-layer activations. |
| protected int | numberOfLayers(): Returns the number of layers. |
| protected void | reset(): Resets the trainer to its initial values in order to start a new training procedure. |
| protected void | resetLayer(int ioLayer): Resets the STRUCTURE of the net (activation function, bias, input or output...) by resetting the neurons in the part net AFTER training. |
| protected void | resetTrainer(): Hands the trainer a new part net to train. |
| void | setPrint(boolean printing) |
| void | train(): Trains the feed-forward network layer-wise, each layer as an autoencoder. |
| protected void | train(java.util.List<java.lang.Double> inputPattern, java.util.List<java.lang.Double> targetPattern): Trains a single pattern once. |
| protected void | updatePartTrainer(): Passes the configuration of the SAE trainer down to the part trainer for the AE net. |
| protected void | updatePartTrainerData(): Passes the training data of the SAE trainer down to the part trainer for the AE net. |

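
As train() and getNextSet() describe, SAE training proceeds layer by layer: each hidden layer is trained as a small autoencoder (input layer, hidden layer, and an output layer the size of the input), and the hidden activations then become the training patterns for the next layer. The following is a minimal, self-contained sketch of that greedy scheme under stated assumptions (sigmoid units, squared error, full-batch gradient descent with averaged gradients applied at the end of each epoch). It is not this library's implementation; Encoder, trainAutoencoder, nextSet and pretrain are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of greedy layer-wise stacked-autoencoder pretraining.
// Assumptions: sigmoid units, squared error, full-batch gradient descent.
public final class StackedAutoencoderSketch {
    static double sigmoid(double z) { return 1.0 / (1.0 + Math.exp(-z)); }

    /** One dense sigmoid layer: out = sigmoid(w * x + b). */
    static final class Encoder {
        final double[][] w;
        final double[] b;
        Encoder(int in, int out, Random rnd) {
            w = new double[out][in];
            b = new double[out];
            for (double[] row : w)
                for (int k = 0; k < in; k++) row[k] = rnd.nextGaussian() * 0.1;
        }
        double[] encode(double[] x) {
            double[] h = new double[b.length];
            for (int j = 0; j < h.length; j++) {
                double z = b[j];
                for (int k = 0; k < x.length; k++) z += w[j][k] * x[k];
                h[j] = sigmoid(z);
            }
            return h;
        }
    }

    /** Trains one autoencoder n -> hidden -> n with target = input; returns the encoder half. */
    static Encoder trainAutoencoder(List<double[]> set, int hidden, int epochs, double rate, Random rnd) {
        int n = set.get(0).length;
        Encoder enc = new Encoder(n, hidden, rnd);
        Encoder dec = new Encoder(hidden, n, rnd); // output layer has the input layer's size
        for (int e = 0; e < epochs; e++) {
            double[][] gW = new double[hidden][n], gV = new double[n][hidden];
            double[] gB = new double[hidden], gC = new double[n];
            for (double[] x : set) {
                double[] h = enc.encode(x);
                double[] y = dec.encode(h);
                double[] dOut = new double[n];
                for (int i = 0; i < n; i++) {
                    dOut[i] = (y[i] - x[i]) * y[i] * (1 - y[i]); // (y - t) * sigmoid', target t = input x
                    gC[i] += dOut[i];
                    for (int j = 0; j < hidden; j++) gV[i][j] += dOut[i] * h[j];
                }
                for (int j = 0; j < hidden; j++) {
                    double s = 0;
                    for (int i = 0; i < n; i++) s += dOut[i] * dec.w[i][j];
                    double dHid = s * h[j] * (1 - h[j]);
                    gB[j] += dHid;
                    for (int k = 0; k < n; k++) gW[j][k] += dHid * x[k];
                }
            }
            // End of batch: apply the averaged gradients (comparable in spirit to endBatch()).
            double scale = rate / set.size();
            for (int j = 0; j < hidden; j++) {
                enc.b[j] -= scale * gB[j];
                for (int k = 0; k < n; k++) enc.w[j][k] -= scale * gW[j][k];
            }
            for (int i = 0; i < n; i++) {
                dec.b[i] -= scale * gC[i];
                for (int j = 0; j < hidden; j++) dec.w[i][j] -= scale * gV[i][j];
            }
        }
        return enc; // the decoder is discarded once this layer is trained
    }

    /** Next pattern set = hidden activations of every pattern (cf. getNextSet()). */
    static List<double[]> nextSet(Encoder enc, List<double[]> set) {
        List<double[]> out = new ArrayList<>();
        for (double[] x : set) out.add(enc.encode(x));
        return out;
    }

    /** Greedy layer-wise pretraining: layer k is trained on the codes produced by layer k-1. */
    static List<Encoder> pretrain(List<double[]> data, int[] hiddenSizes, int epochs, double rate, long seed) {
        Random rnd = new Random(seed);
        List<Encoder> stack = new ArrayList<>();
        List<double[]> set = data;
        for (int size : hiddenSizes) {
            Encoder enc = trainAutoencoder(set, size, epochs, rate, rnd);
            stack.add(enc);
            set = nextSet(enc, set);
        }
        return stack;
    }

    public static void main(String[] args) {
        List<double[]> data = List.of(new double[] {1, 0, 0, 0}, new double[] {0, 1, 0, 0},
                new double[] {0, 0, 1, 0}, new double[] {0, 0, 0, 1});
        List<Encoder> stack = pretrain(data, new int[] {3, 2}, 500, 2.0, 42L);
        double[] code = data.get(0);
        for (Encoder enc : stack) code = enc.encode(code);
        System.out.println(stack.size() + " layers, code length " + code.length); // prints "2 layers, code length 2"
    }
}
```

Discarding each decoder after its stage is a design choice of this sketch: only the encoders make up the pretrained stack, and each stage sees nothing but the previous stage's hidden activations.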
Methods inherited from class Trainer: clone, fromXML, getBatchSize, getEpochs, getLearnRate, getNetwork, getRank, getTestData, getTrainingData, getTrainingSignalGenerator, getWinningNeuron, setBatchSize, setEpochs, setLearnRate, setNetwork, setShuffle, setStepMode, setTestData, setTrainingData, setTrainingSignalGenerator, test, test, toString, toXML, trainPattern, updateLinks, updateNeurons

protected Trainer partTrainer
protected NeuralNet partNet
protected NeuronList hiddenLayer
protected boolean wasInputLayer
protected boolean wasOutputLayer
protected int steps
protected boolean print

public SAETrainer()
Creates a new SAE trainer using a SquareError.

public SAETrainer(Trainer trainer, int steps)
Creates a new SAE trainer using a SquareError, the given trainer, and the given number of steps.

protected void train(java.util.List<java.lang.Double> inputPattern, java.util.List<java.lang.Double> targetPattern)
Description copied from class: Trainer

protected double calcAdaptation(BrainPart part)
See: calcAdaptation in class Trainer
Parameters: part - the brain part to be adapted

protected void reset()
Description copied from class: Trainer

protected void endBatch()
Description copied from class: Trainer

protected void endTrain()
Description copied from class: Trainer

public void train()

protected void resetLayer(int ioLayer)
Parameters: ioLayer - the actual layer number of the input layer of the part net

protected void resetTrainer()

protected void updatePartTrainer()

protected void updatePartTrainerData()

protected NeuronList getNeuronsOfLayer(int layer)
Parameters: layer - the layer to get the neurons from

protected void extractInputOutputLayer(NeuralNet partNet, int layer)
Parameters: layer - the input layer to extract

protected void extractHiddenLayer(NeuralNet partNet, int layer)
Parameters: layer - the hidden layer to extract

protected int numberOfLayers()

protected PatternSet getNextSet(NeuralNet partNet, PatternSet patterns)
Parameters: partNet - the current part net; patterns - the current patterns

protected java.util.List<java.lang.Double> getHiddenLayerOutput(java.util.List<java.lang.Double> outputs)

public void setPrint(boolean printing)
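
getNeuronsOfLayer(int layer) and numberOfLayers() group the net's neurons by layer. A minimal sketch, assuming each neuron carries an explicit layer index starting at 0 with no gaps (the library may instead derive layers from the link structure). Neuron and both static helpers are hypothetical stand-ins, not the library's NeuronList API.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: neurons tagged with an explicit layer index (an assumption),
// grouped on demand instead of being stored per layer.
public final class LayerGroupingSketch {

    /** Hypothetical stand-in for a neuron; only its layer index matters here. */
    record Neuron(String name, int layer) {}

    /** Number of layers = highest layer index + 1 (assumes indices start at 0 without gaps). */
    static int numberOfLayers(List<Neuron> neurons) {
        int max = -1;
        for (Neuron n : neurons) {
            max = Math.max(max, n.layer());
        }
        return max + 1;
    }

    /** All neurons whose layer index equals {@code layer}, in their original order. */
    static List<Neuron> getNeuronsOfLayer(List<Neuron> neurons, int layer) {
        List<Neuron> result = new ArrayList<>();
        for (Neuron n : neurons) {
            if (n.layer() == layer) {
                result.add(n);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Neuron> net = List.of(
                new Neuron("in0", 0), new Neuron("in1", 0),
                new Neuron("h0", 1), new Neuron("out0", 2));
        System.out.println(numberOfLayers(net));              // prints 3
        System.out.println(getNeuronsOfLayer(net, 0).size()); // prints 2
    }
}
```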