本文整理了 Java 中 org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit() 方法的一些代码示例，展示了 MultiLayerNetwork.fit() 的具体用法。这些代码示例主要来源于 GitHub、Stack Overflow、Maven 等平台，是从一些精选项目中提取出来的代码，具有较强的参考意义，能在一定程度上帮助到你。MultiLayerNetwork.fit() 方法的具体详情如下：
包路径:org.deeplearning4j.nn.multilayer.MultiLayerNetwork
类名称:MultiLayerNetwork
方法名:fit
[英]Fit the unsupervised model
[中]拟合无监督模型
代码示例来源:origin: deeplearning4j/dl4j-examples
// Train `network` on the data supplied by `trainIterator`.
network.fit(trainIterator);
代码示例来源:origin: guoguibing/librec
// trainSet is passed as both the input and the target
// (presumably reconstruction-style training, as AutoRec is an autoencoder model — confirm).
autoRecModel.fit(trainSet, trainSet);
// Record the model's score after fitting as the current loss.
loss = autoRecModel.score();
代码示例来源:origin: guoguibing/librec
// NOTE(review): within this excerpt the reset below is overwritten by score() before it is
// read; it only matters if elided surrounding code reads `loss` in between — confirm.
loss = 0.0d;
// Fit the CDAE model with trainSet as both the input and the target.
CDAEModel.fit(trainSet, trainSet);
// Record the model's score after fitting as the current loss.
loss = CDAEModel.score();
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
* Fit the model on the given features and labels.
*
* Delegates to the four-argument {@code fit} overload, passing {@code null} for both
* trailing arguments (presumably the feature and label mask arrays, i.e. no masking —
* confirm against that overload's signature).
*
* @param data the examples to classify (one example in each row)
* @param labels the example labels (a binary outcome matrix)
*/
@Override
public void fit(INDArray data, INDArray labels) {
fit(data, labels, null, null);
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
* Fit the model on the {@code input} and {@code labels} already held by this instance,
* by delegating to {@code fit(input, labels)}.
*/
@Override
public void fit() {
fit(input, labels);
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
* Fit the wrapped network {@code net} on a single data set.
*
* @param ds the data set to train on; passed straight through to {@code net.fit}
*/
@Override
protected void fit(DataSet ds) {
net.fit(ds);
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
 * Fit the network from a {@link MultiDataSetIterator}.
 *
 * The iterator is wrapped in a {@link MultiDataSetWrapperIterator} (presumably an adapter to
 * the single-data-set iterator type) and training is delegated to the iterator-based fit.
 *
 * @param iterator the multi-data-set iterator to train on
 */
@Override
public void fit(MultiDataSetIterator iterator) {
    MultiDataSetWrapperIterator adapted = new MultiDataSetWrapperIterator(iterator);
    fit(adapted);
}
代码示例来源:origin: mccorby/FederatedAndroidTrainer
/**
 * Train the model on the native data set carried by {@code trainingData}.
 *
 * @param trainingData federated wrapper; its native payload is cast to a DL4J {@code DataSet},
 *                     so any other payload type raises a ClassCastException
 */
@Override
public void train(FederatedDataSet trainingData) {
    DataSet nativeData = (DataSet) trainingData.getNativeDataSet();
    model.fit(nativeData);
}
代码示例来源:origin: mccorby/FederatedAndroidTrainer
/**
 * Train the model with one call to {@code fit} on the wrapped DL4J data set.
 *
 * @param trainingData federated wrapper whose native payload must be a DL4J {@code DataSet}
 *                     (the cast below throws ClassCastException otherwise)
 */
@Override
public void train(FederatedDataSet trainingData) {
    DataSet payload = (DataSet) trainingData.getNativeDataSet();
    model.fit(payload);
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper
/**
 * Train the replicated model on one data set.
 *
 * Supports a {@link MultiLayerNetwork} or a {@link ComputationGraph} replica. The replica is
 * first given the current value of {@code lastEtlTime} (created lazily with value 0), then
 * fitted. Any other model type is silently ignored and {@code lastEtlTime} is left untouched.
 *
 * @param dataSet the data set to fit on
 */
protected void fit(DataSet dataSet) {
    boolean isNetwork = replicatedModel instanceof MultiLayerNetwork;
    boolean isGraph = !isNetwork && replicatedModel instanceof ComputationGraph;
    if (!isNetwork && !isGraph) {
        return;
    }

    if (lastEtlTime == null) {
        lastEtlTime = new AtomicLong(0);
    }
    long etlTime = lastEtlTime.get();

    if (isNetwork) {
        MultiLayerNetwork network = (MultiLayerNetwork) replicatedModel;
        network.setLastEtlTime(etlTime);
        network.fit(dataSet);
    } else {
        ComputationGraph graph = (ComputationGraph) replicatedModel;
        graph.setLastEtlTime(etlTime);
        graph.fit(dataSet);
    }
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11
/**
 * Train the replicated model on one data set.
 *
 * Supports a {@link MultiLayerNetwork} or a {@link ComputationGraph} replica. The replica is
 * first given the current value of {@code lastEtlTime} (created lazily with value 0), then
 * fitted. Any other model type is silently ignored and {@code lastEtlTime} is left untouched.
 *
 * @param dataSet the data set to fit on
 */
protected void fit(DataSet dataSet) {
    boolean isNetwork = replicatedModel instanceof MultiLayerNetwork;
    boolean isGraph = !isNetwork && replicatedModel instanceof ComputationGraph;
    if (!isNetwork && !isGraph) {
        return;
    }

    if (lastEtlTime == null) {
        lastEtlTime = new AtomicLong(0);
    }
    long etlTime = lastEtlTime.get();

    if (isNetwork) {
        MultiLayerNetwork network = (MultiLayerNetwork) replicatedModel;
        network.setLastEtlTime(etlTime);
        network.fit(dataSet);
    } else {
        ComputationGraph graph = (ComputationGraph) replicatedModel;
        graph.setLastEtlTime(etlTime);
        graph.fit(dataSet);
    }
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-scaleout-akka
/**
 * Process one job.
 *
 * When the job's work item is a {@link DataSet}, {@code multiLayerNetwork} is fitted on it and
 * the network's parameters after fitting are stored as the job's result. Work items of any
 * other type are ignored and no result is set.
 *
 * @param job the job whose work item is consumed and whose result may be set
 */
@Override
public void perform(Job job) {
    Serializable payload = job.getWork();
    if (!(payload instanceof DataSet)) {
        return;
    }
    multiLayerNetwork.fit((DataSet) payload);
    job.setResult(multiLayerNetwork.params());
}
代码示例来源:origin: mccorby/FederatedAndroidTrainer
/**
 * Train the model for {@code N_EPOCHS} epochs on the DL4J data set wrapped by
 * {@code federatedDataSet}.
 *
 * The data set is split with {@code asList()} and fed through a {@code ListDataSetIterator}
 * (presumably yielding batches of {@code BATCH_SIZE} examples). The iterator is rewound at the
 * start of every epoch, so epochs after the first do not see an already exhausted iterator.
 *
 * @param federatedDataSet federated wrapper whose native payload must be a DL4J {@code DataSet}
 *                         (the cast throws ClassCastException otherwise)
 */
@Override
public void train(FederatedDataSet federatedDataSet) {
    DataSet trainingData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = trainingData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);
    for (int epoch = 0; epoch < N_EPOCHS; epoch++) {
        // Rewind explicitly instead of relying on fit(DataSetIterator) to have reset the
        // iterator at the end of the previous epoch.
        iterator.reset();
        model.fit(iterator);
    }
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
 * Fit the unfrozen subset model on one data set, then copy its parameters back into the
 * original model.
 *
 * Uses the graph subset when {@code isGraph} is set, otherwise the MultiLayerNetwork subset.
 * The input is presumably already featurized (passed through the frozen layers) — confirm
 * against callers.
 *
 * @param input the data set to fit on
 */
public void fitFeaturized(DataSet input) {
    if (!isGraph) {
        unFrozenSubsetMLN.fit(input);
        copyParamsFromSubsetMLNToOrig();
        return;
    }
    unFrozenSubsetGraph.fit(input);
    copyParamsFromSubsetGraphToOrig();
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
 * Fit the unfrozen subset model on everything produced by {@code iter}, then copy its
 * parameters back into the original model.
 *
 * Uses the graph subset when {@code isGraph} is set, otherwise the MultiLayerNetwork subset.
 * The iterator's data is presumably already featurized — confirm against callers.
 *
 * @param iter iterator over the data sets to fit on
 */
public void fitFeaturized(DataSetIterator iter) {
    if (!isGraph) {
        unFrozenSubsetMLN.fit(iter);
        copyParamsFromSubsetMLNToOrig();
        return;
    }
    unFrozenSubsetGraph.fit(iter);
    copyParamsFromSubsetGraphToOrig();
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-keras
/**
 * Read the network described by {@code entryPointFitParameters} and train it, once per epoch,
 * on the HDF5 mini-batches found in the configured training feature and label directories.
 *
 * Any failure is logged and then rethrown unchanged to the caller.
 *
 * @param entryPointFitParameters definition of the model and learning process
 * @throws Exception whatever reading the network or fitting throws, rethrown after logging
 */
public void fit(EntryPointFitParameters entryPointFitParameters) throws Exception {
    try {
        MultiLayerNetwork network = neuralNetworkReader.readNeuralNetwork(entryPointFitParameters);
        DataSetIterator trainingBatches = new HDF5MiniBatchDataSetIterator(
                entryPointFitParameters.getTrainFeaturesDirectory(),
                entryPointFitParameters.getTrainLabelsDirectory());
        int epoch = 0;
        while (epoch < entryPointFitParameters.getNbEpoch()) {
            log.info("Fitting: " + epoch);
            // NOTE(review): the same iterator is reused every epoch without an explicit
            // reset(); confirm that fit(DataSetIterator) rewinds it, or later epochs may be empty.
            network.fit(trainingBatches);
            epoch++;
        }
        log.info("Learning model finished");
    } catch (Throwable failure) {
        log.error("Error while handling request!", failure);
        throw failure;
    }
}
代码示例来源:origin: mccorby/FederatedAndroidTrainer
/**
 * Train {@code mNetwork} for {@code N_EPOCHS} epochs on the DL4J data set wrapped by
 * {@code dataSource}.
 *
 * The data set is split with {@code asList()} into a {@code ListDataSetIterator}, which is
 * rewound before each epoch. No evaluation is performed here.
 *
 * @param dataSource federated wrapper whose native payload must be a DL4J {@code DataSet}
 */
@Override
public void train(FederatedDataSet dataSource) {
    DataSet fullSet = (DataSet) dataSource.getNativeDataSet();
    DataSetIterator batches = new ListDataSetIterator(fullSet.asList(), BATCH_SIZE);
    for (int epoch = 0; epoch < N_EPOCHS; epoch++) {
        batches.reset();
        mNetwork.fit(batches);
    }
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
 * Fit the model on examples with integer labels.
 *
 * The labels are converted with {@code FeatureUtil.toOutcomeMatrix} into an outcome matrix
 * whose width is the output layer configuration's {@code nOut}. Training is then delegated to
 * {@code fit(INDArray, INDArray)}.
 *
 * NOTE(review): the output layer's configuration is cast to
 * {@code org.deeplearning4j.nn.conf.layers.OutputLayer}; any other output-layer configuration
 * type raises a ClassCastException here.
 *
 * @param examples the examples to classify (one example in each row)
 * @param labels one label per example (presumably the number of labels must match the number
 *               of rows in {@code examples})
 */
@Override
public void fit(INDArray examples, int[] labels) {
    org.deeplearning4j.nn.conf.layers.OutputLayer outputConf =
            (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().conf().getLayer();
    INDArray outcomes = FeatureUtil.toOutcomeMatrix(labels, outputConf.getNOut());
    fit(examples, outcomes);
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
 * Fit the network on a {@link MultiDataSet} that holds exactly one features array and one
 * labels array.
 *
 * The single features/labels pair and the matching optional mask arrays are repackaged as a
 * {@link DataSet} and passed to {@code fit(DataSet)}. Each mask array is optional and checked
 * independently.
 *
 * @param dataSet the multi data set to train on
 * @throws DL4JInvalidInputException if the data set has more than one features or labels array
 */
@Override
public void fit(MultiDataSet dataSet) {
    if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) {
        INDArray[] featuresMasks = dataSet.getFeaturesMaskArrays();
        INDArray[] labelsMasks = dataSet.getLabelsMaskArrays();
        // Guard each mask against its own array: the features mask being present says nothing
        // about the labels mask, and vice versa.
        INDArray fMask = featuresMasks != null ? featuresMasks[0] : null;
        INDArray lMask = labelsMasks != null ? labelsMasks[0] : null;
        INDArray features = dataSet.getFeatures()[0];
        INDArray labels = dataSet.getLabels()[0];
        DataSet ds = new DataSet(features, labels, fMask, lMask);
        fit(ds);
        // Done: without this return control would fall through to the unsupported-input error.
        return;
    }
    throw new DL4JInvalidInputException(
            "MultiLayerNetwork can't handle MultiDataSet. Please consider use of ComputationGraph");
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
 * Fit the model on a single data set.
 *
 * With truncated BPTT configured, training goes through {@code doTruncatedBPTT} with the
 * features, labels and both mask arrays. Otherwise any mask arrays are attached to the layers
 * only for the duration of {@code fit(features, labels)`}, and are detached afterwards even if
 * that call throws. Layer state is cleared after training completes.
 *
 * @param data the data to train on
 */
@Override
public void fit(org.nd4j.linalg.dataset.api.DataSet data) {
    if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
        doTruncatedBPTT(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArray(),
                data.getLabelsMaskArray());
    } else {
        //Standard training
        boolean hasMaskArrays = data.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        }
        try {
            fit(data.getFeatures(), data.getLabels());
        } finally {
            // Detach masks even on failure so they cannot leak into a later, unmasked fit call.
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
        }
    }
    clearLayersStates();
}
内容来源于网络,如有侵权,请联系作者删除!