org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit()方法的使用及代码示例

x33g5p2x  于2022-01-25 转载在 其他  
字(6.6k)|赞(0)|评价(0)|浏览(129)

本文整理了Java中org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit()方法的一些代码示例,展示了MultiLayerNetwork.fit()的具体用法。这些代码示例主要来源于Github/Stackoverflow/Maven等平台,是从一些精选项目中提取出来的代码,具有较强的参考意义,能在一定程度上帮助到你。MultiLayerNetwork.fit()方法的具体详情如下:
包路径:org.deeplearning4j.nn.multilayer.MultiLayerNetwork
类名称:MultiLayerNetwork
方法名:fit

MultiLayerNetwork.fit介绍

[英]Fit the unsupervised model
[中]拟合无监督模型

代码示例

代码示例来源:origin: deeplearning4j/dl4j-examples

network.fit(trainIterator);

代码示例来源:origin: guoguibing/librec

autoRecModel.fit(trainSet, trainSet);
loss = autoRecModel.score();

代码示例来源:origin: guoguibing/librec

loss = 0.0d;
CDAEModel.fit(trainSet, trainSet);
loss = CDAEModel.score();

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Fit the model
 *
 * @param data   the examples to classify (one example in each row)
 * @param labels the example labels(a binary outcome matrix)
 */
@Override
public void fit(INDArray data, INDArray labels) {
  // Delegates to the 4-arg overload; the two nulls are presumably the
  // features/labels mask arrays — TODO confirm against the 4-arg signature.
  fit(data, labels, null, null);
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Fit the model using the {@code input} and {@code labels} fields already
 * set on this instance (no arguments are taken here).
 */
@Override
public void fit() {
  fit(input, labels);
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Fit on a single dataset by delegating to the wrapped network {@code net}.
 *
 * @param ds the dataset (features plus labels) to train on
 */
@Override
protected void fit(DataSet ds) {
  net.fit(ds);
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Fit from a MultiDataSetIterator by adapting it with
 * {@link MultiDataSetWrapperIterator} and delegating to the
 * DataSetIterator-based {@code fit} overload.
 *
 * @param iterator source of multi datasets to train on
 */
@Override
public void fit(MultiDataSetIterator iterator) {
  fit(new MultiDataSetWrapperIterator(iterator));
}

代码示例来源:origin: mccorby/FederatedAndroidTrainer

/**
 * Train the model on the given federated dataset.
 *
 * @param trainingData wrapper whose native payload is cast to a DL4J DataSet;
 *                     a ClassCastException is thrown if the payload is of a
 *                     different type
 */
@Override
public void train(FederatedDataSet trainingData) {
  model.fit((DataSet) trainingData.getNativeDataSet());
}

代码示例来源:origin: mccorby/FederatedAndroidTrainer

/**
 * Fit the underlying model on the federated data's native payload.
 *
 * @param trainingData federated wrapper; its native dataset must be a DL4J
 *                     DataSet (the hard cast fails otherwise)
 */
@Override
public void train(FederatedDataSet trainingData) {
  model.fit((DataSet) trainingData.getNativeDataSet());
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper

/**
 * Fit the replicated model on one mini-batch. Both MultiLayerNetwork and
 * ComputationGraph replicas are supported; any other model type is silently
 * ignored. Lazily initializes {@code lastEtlTime} and propagates it to the
 * replica before fitting.
 *
 * @param dataSet the mini-batch to train on
 */
protected void fit(DataSet dataSet) {
  if (replicatedModel instanceof MultiLayerNetwork) {
    MultiLayerNetwork network = (MultiLayerNetwork) replicatedModel;
    if (lastEtlTime == null) {
      lastEtlTime = new AtomicLong(0);
    }
    network.setLastEtlTime(lastEtlTime.get());
    network.fit(dataSet);
  } else if (replicatedModel instanceof ComputationGraph) {
    ComputationGraph graph = (ComputationGraph) replicatedModel;
    if (lastEtlTime == null) {
      lastEtlTime = new AtomicLong(0);
    }
    graph.setLastEtlTime(lastEtlTime.get());
    graph.fit(dataSet);
  }
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11

/**
 * Train the replicated model on a single mini-batch, forwarding the shared
 * ETL timestamp first. Handles MultiLayerNetwork and ComputationGraph
 * replicas; other model types are ignored without error.
 *
 * @param dataSet the mini-batch to fit
 */
protected void fit(DataSet dataSet) {
  if (replicatedModel instanceof MultiLayerNetwork) {
    if (lastEtlTime == null) {
      lastEtlTime = new AtomicLong(0); // lazy init of the shared ETL clock
    }
    MultiLayerNetwork mln = (MultiLayerNetwork) replicatedModel;
    mln.setLastEtlTime(lastEtlTime.get());
    mln.fit(dataSet);
  } else if (replicatedModel instanceof ComputationGraph) {
    if (lastEtlTime == null) {
      lastEtlTime = new AtomicLong(0);
    }
    ComputationGraph cg = (ComputationGraph) replicatedModel;
    cg.setLastEtlTime(lastEtlTime.get());
    cg.fit(dataSet);
  }
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-scaleout-akka

/**
 * Execute a training job: when the job's payload is a DataSet, fit the
 * network on it and publish the updated parameters as the job result.
 * Jobs carrying any other payload type are left untouched.
 *
 * @param job the unit of work to perform
 */
@Override
public void perform(Job job) {
  Serializable work = job.getWork();
  if (!(work instanceof DataSet)) {
    return; // nothing to do for non-DataSet payloads
  }
  multiLayerNetwork.fit((DataSet) work);
  job.setResult(multiLayerNetwork.params());
}

代码示例来源:origin: mccorby/FederatedAndroidTrainer

/**
 * Train the model for {@code N_EPOCHS} epochs over the federated data,
 * re-batching it with a ListDataSetIterator of size {@code BATCH_SIZE}.
 *
 * @param federatedDataSet wrapper whose native payload must be a DL4J DataSet
 */
@Override
public void train(FederatedDataSet federatedDataSet) {
  DataSet nativeData = (DataSet) federatedDataSet.getNativeDataSet();
  DataSetIterator batches = new ListDataSetIterator(nativeData.asList(), BATCH_SIZE);
  for (int epoch = 0; epoch < N_EPOCHS; epoch++) {
    model.fit(batches);
  }
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Fit the unfrozen (trainable) subset of the network on pre-featurized
 * input, then copy the updated parameters back into the original model.
 *
 * @param input featurized dataset to train on
 */
public void fitFeaturized(DataSet input) {
  if (!isGraph) {
    unFrozenSubsetMLN.fit(input);
    copyParamsFromSubsetMLNToOrig();
  } else {
    unFrozenSubsetGraph.fit(input);
    copyParamsFromSubsetGraphToOrig();
  }
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Fit the unfrozen subset of the network from an iterator of featurized
 * data, then write the trained parameters back to the original model.
 *
 * @param iter iterator over featurized datasets
 */
public void fitFeaturized(DataSetIterator iter) {
  if (!isGraph) {
    unFrozenSubsetMLN.fit(iter);
    copyParamsFromSubsetMLNToOrig();
  } else {
    unFrozenSubsetGraph.fit(iter);
    copyParamsFromSubsetGraphToOrig();
  }
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-keras

/**
 * Performs fitting of the model which is referenced in the parameters according to learning
 * parameters specified.
 *
 * @param entryPointFitParameters Definition of the model and learning process
 * @throws Exception if reading the model or any training epoch fails; the error is logged
 *                   at this boundary before being rethrown to the caller
 */
public void fit(EntryPointFitParameters entryPointFitParameters) throws Exception {
  try {
    MultiLayerNetwork multiLayerNetwork = neuralNetworkReader.readNeuralNetwork(entryPointFitParameters);
    DataSetIterator dataSetIterator =
            new HDF5MiniBatchDataSetIterator(entryPointFitParameters.getTrainFeaturesDirectory(),
                    entryPointFitParameters.getTrainLabelsDirectory());
    for (int i = 0; i < entryPointFitParameters.getNbEpoch(); i++) {
      // Parameterized logging: avoids string concatenation on every epoch when INFO is disabled
      log.info("Fitting: {}", i);
      multiLayerNetwork.fit(dataSetIterator);
    }
    log.info("Learning model finished");
  } catch (Throwable e) {
    // Log with the full cause at the service boundary, then rethrow unchanged
    log.error("Error while handling request!", e);
    throw e;
  }
}

代码示例来源:origin: mccorby/FederatedAndroidTrainer

/**
 * Train the network for {@code N_EPOCHS} epochs over the supplied federated
 * data, resetting the iterator at the start of every epoch.
 *
 * @param dataSource federated wrapper whose native payload must be a DL4J DataSet
 */
@Override
public void train(FederatedDataSet dataSource) {
  List<DataSet> batches = ((DataSet) dataSource.getNativeDataSet()).asList();
  DataSetIterator trainIterator = new ListDataSetIterator(batches, BATCH_SIZE);
  for (int epoch = 0; epoch < N_EPOCHS; epoch++) {
    trainIterator.reset(); // restart from the first batch each epoch
    mNetwork.fit(trainIterator);
  }
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Fit the model on examples with integer class labels.
 *
 * @param examples the examples to classify (one example in each row)
 * @param labels   the integer class label for each example row; labels are
 *                 one-hot encoded against the output layer's nOut before fitting
 */
@Override
public void fit(INDArray examples, int[] labels) {
  org.deeplearning4j.nn.conf.layers.OutputLayer outputLayerConf =
          (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().conf().getLayer();
  fit(examples, FeatureUtil.toOutcomeMatrix(labels, outputLayerConf.getNOut()));
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Fit the network on a MultiDataSet, provided it carries exactly one features
 * array and one labels array (a MultiLayerNetwork cannot handle multiple
 * inputs or outputs).
 *
 * @param dataSet the multi dataset; must have a single features array and a
 *                single labels array
 * @throws DL4JInvalidInputException if the MultiDataSet has more than one
 *                                   features or labels array
 */
@Override
public void fit(MultiDataSet dataSet) {
  if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) {
    INDArray fMask = null;
    INDArray lMask = null;
    if (dataSet.getFeaturesMaskArrays() != null)
      fMask = dataSet.getFeaturesMaskArrays()[0];
    // Fixed: previously guarded the labels mask read with the FEATURES mask
    // null-check, risking an NPE when only features masks were present
    if (dataSet.getLabelsMaskArrays() != null)
      lMask = dataSet.getLabelsMaskArrays()[0];
    INDArray features = dataSet.getFeatures()[0];
    INDArray labels = dataSet.getLabels()[0];
    fit(new DataSet(features, labels, fMask, lMask));
    // Fixed: previously fell through to the exception below even after a
    // successful fit; return on the supported single-input/output path
    return;
  }
  throw new DL4JInvalidInputException(
          "MultiLayerNetwork can't handle MultiDataSet. Please consider use of ComputationGraph");
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Fit the model on a single dataset. Dispatches to truncated BPTT when the
 * configuration requests it; otherwise runs standard training, temporarily
 * installing any mask arrays for the duration of the fit. Layer state is
 * cleared after either path.
 *
 * @param data the data to train on
 */
@Override
public void fit(org.nd4j.linalg.dataset.api.DataSet data) {
  if (layerWiseConfigurations.getBackpropType() != BackpropType.TruncatedBPTT) {
    // Standard training: apply masks (if any), fit, then remove them again
    boolean hasMaskArrays = data.hasMaskArrays();
    if (hasMaskArrays) {
      setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray());
    }
    fit(data.getFeatures(), data.getLabels());
    if (hasMaskArrays) {
      clearLayerMaskArrays();
    }
  } else {
    // Truncated BPTT path: masks are handed directly to doTruncatedBPTT
    doTruncatedBPTT(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArray(),
            data.getLabelsMaskArray());
  }
  clearLayersStates();
}

相关文章