org.deeplearning4j.nn.multilayer.MultiLayerNetwork.getUpdater()方法的使用及代码示例

x33g5p2x  于2022-01-25 转载在 其他  
字(8.2k)|赞(0)|评价(0)|浏览(114)

本文整理了Java中org.deeplearning4j.nn.multilayer.MultiLayerNetwork.getUpdater()方法的一些代码示例,展示了MultiLayerNetwork.getUpdater()的具体用法。这些代码示例主要来源于GitHub/Stackoverflow/Maven等平台,是从一些精选项目中提取出来的代码,具有较强的参考意义,能在一定程度上帮助到你。MultiLayerNetwork.getUpdater()方法的具体详情如下:
包路径:org.deeplearning4j.nn.multilayer.MultiLayerNetwork
类名称:MultiLayerNetwork
方法名:getUpdater

MultiLayerNetwork.getUpdater介绍

[英]Get the updater for this MultiLayerNetwork
[中]获取此多层网络的更新程序

代码示例

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Returns the updater's state view array.
 *
 * @return the updater state array, or {@code null} when {@code getUpdater()} returns
 *     {@code null}
 */
@Override
public INDArray updaterState() {
  if (getUpdater() == null) {
    // No updater -> no state to expose.
    return null;
  }
  return getUpdater().getStateViewArray();
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11

/**
 * Creates a wrapper around {@code model}.
 *
 * @param model the model to wrap
 * @param workers stored for later use (presumably the number of parallel workers; confirm)
 * @param prefetchSize stored for later use (presumably the data prefetch buffer size; confirm)
 */
protected ParallelWrapper(Model model, int workers, int prefetchSize) {
  this.workers = workers;
  this.prefetchSize = prefetchSize;
  this.model = model;

  // The getUpdater() result is intentionally discarded. The call is presumably made only for
  // its side effect of forcing lazy updater initialization before workers are created.
  // MultiLayerNetwork.getUpdater() forces initialization when no solver exists yet;
  // NOTE(review): confirm the same holds for ComputationGraph.
  if (model instanceof MultiLayerNetwork) {
    ((MultiLayerNetwork) model).getUpdater();
  } else if (model instanceof ComputationGraph) {
    ((ComputationGraph) model).getUpdater();
  }
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper

/**
 * Creates a wrapper around {@code model}.
 *
 * @param model the model to wrap
 * @param workers stored for later use (presumably the number of parallel workers; confirm)
 * @param prefetchSize stored for later use (presumably the data prefetch buffer size; confirm)
 */
protected ParallelWrapper(Model model, int workers, int prefetchSize) {
  this.workers = workers;
  this.prefetchSize = prefetchSize;
  this.model = model;

  // The getUpdater() result is intentionally discarded. The call is presumably made only for
  // its side effect of forcing lazy updater initialization before workers are created.
  // MultiLayerNetwork.getUpdater() forces initialization when no solver exists yet;
  // NOTE(review): confirm the same holds for ComputationGraph.
  if (model instanceof MultiLayerNetwork) {
    ((MultiLayerNetwork) model).getUpdater();
  } else if (model instanceof ComputationGraph) {
    ((ComputationGraph) model).getUpdater();
  }
}

代码示例来源:origin: CampagneLaboratory/variationanalysis

/**
 * Saves a network to three files:
 *
 * <ul>
 *   <li>its configuration as UTF-8 JSON;
 *   <li>its parameters, written with {@code Nd4j.write};
 *   <li>its updater, written with Java object serialization.
 * </ul>
 *
 * @param net the network to save
 * @param confOut path of the JSON configuration file
 * @param paramOut path of the parameter file
 * @param updaterOut path of the serialized updater file
 * @throws IOException if any of the files cannot be written
 */
protected static void save(MultiLayerNetwork net, String confOut, String paramOut, String updaterOut) throws IOException {
  String confJSON = net.getLayerWiseConfigurations().toJson();
  INDArray params = net.params();
  Updater updater = net.getUpdater();
  FileUtils.writeStringToFile(new File(confOut), confJSON, "UTF-8");
  try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(Paths.get(paramOut))))) {
    Nd4j.write(params, dos);
  }
  // Open the updater file the same way as the parameter file (NIO) rather than via
  // java.io.FileOutputStream, so both binary outputs share one code path.
  // NOTE(review): Java serialization requires the concrete Updater to be Serializable; confirm.
  try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(Files.newOutputStream(Paths.get(updaterOut))))) {
    oos.writeObject(updater);
  }
}

代码示例来源:origin: de.datexis/texoo-core

/**
 * Writes the network's updater state array, gzip-compressed and in {@code Nd4j.write} format,
 * to {@code name + ".bin.gz"} resolved against {@code modelPath}.
 *
 * <p>Nothing is written in two cases: {@code net} is neither a {@link MultiLayerNetwork} nor a
 * {@link ComputationGraph}, or the updater state array is {@code null}. I/O failures are logged
 * together with their stack trace and are not rethrown.
 *
 * @param modelPath base resource that the output file name is resolved against
 * @param name file name without the {@code .bin.gz} suffix
 */
// NOTE(review): deprecated without a documented replacement; add an @deprecated Javadoc tag
// naming the preferred API.
@Deprecated
public void saveUpdater(Resource modelPath, String name) {
  Resource modelFile = modelPath.resolve(name + ".bin.gz");
  INDArray updaterState = null;
  if (net instanceof MultiLayerNetwork) {
    updaterState = ((MultiLayerNetwork) net).getUpdater().getStateViewArray();
  } else if (net instanceof ComputationGraph) {
    updaterState = ((ComputationGraph) net).getUpdater().getStateViewArray();
  }
  if (updaterState == null) {
    return;
  }
  try (DataOutputStream dos = new DataOutputStream(modelFile.getGZIPOutputStream())) {
    Nd4j.write(updaterState, dos);
    dos.flush();
  } catch (IOException ex) {
    // Keep the best-effort contract (log, don't throw), but pass the exception so the
    // stack trace is not lost.
    log.error(ex.toString(), ex);
  }
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

network.getUpdater().setStateViewArray(network, updaterState, false);
} else if (gotOldUpdater && updater != null) {
  network.setUpdater(updater);

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

boolean paramsEquals = network.params().equals(params());
boolean confEquals = getLayerWiseConfigurations().equals(network.getLayerWiseConfigurations());
boolean updaterEquals = getUpdater().equals(network.getUpdater());
return paramsEquals && confEquals && updaterEquals;

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

INDArray updaterState = null;
if (model instanceof MultiLayerNetwork) {
  updaterState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
} else if (model instanceof ComputationGraph) {
  updaterState = ((ComputationGraph) model).getUpdater().getStateViewArray();

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper

Updater updater = ((MultiLayerNetwork) model).getUpdater();
int batchSize = 0;
  for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
    updaters.add(workerModel.getUpdater().getStateViewArray());
    batchSize += workerModel.batchSize();

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11

Updater updater = ((MultiLayerNetwork) model).getUpdater();
int batchSize = 0;
  for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
    updaters.add(workerModel.getUpdater().getStateViewArray());
    batchSize += workerModel.batchSize();

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

/**
 * Creates a copy of this network.
 *
 * <p>The copy is built from:
 *
 * <ul>
 *   <li>a clone of the layer-wise configuration;
 *   <li>a duplicate of the parameter array;
 *   <li>a duplicate of the updater state, copied only when this network's solver is non-null.
 * </ul>
 *
 * <p>Layers that are {@link FrozenLayer}s in this network are wrapped in new
 * {@code FrozenLayer} instances in the copy.
 *
 * @return a new {@code MultiLayerNetwork} built from copies of this network's configuration,
 *     parameters and (when available) updater state
 */
@Override
public MultiLayerNetwork clone() {
  MultiLayerConfiguration conf = this.layerWiseConfigurations.clone();
  MultiLayerNetwork ret = new MultiLayerNetwork(conf);
  // The parameters are already dup()'d here. The false flag presumably tells init() not to
  // copy them a second time; confirm against init(INDArray, boolean).
  ret.init(this.params().dup(), false);
  if (solver != null) {
    // Copy updater state only when a solver exists. If the solver is null, this network's
    // updater has not been initialized, and calling getUpdater() would force that
    // initialization just to clone it. In that case the copy is skipped.
    Updater u = this.getUpdater();
    INDArray updaterState = u.getStateViewArray();
    if (updaterState != null) {
      // Give the clone its own copy of the state so the two networks do not share the array.
      ret.getUpdater().setStateViewArray(ret, updaterState.dup(), false);
    }
  }
  if (hasAFrozenLayer()) {
    // Re-wrap the clone's layers at the positions where this network has FrozenLayers.
    // NOTE(review): if init() already produced FrozenLayers at these positions, this
    // double-wraps them; confirm.
    Layer[] clonedLayers = ret.getLayers();
    for (int i = 0; i < layers.length; i++) {
      if (layers[i] instanceof FrozenLayer) {
        clonedLayers[i] = new FrozenLayer(ret.getLayer(i));
      }
    }
    ret.setLayers(clonedLayers);
  }
  return ret;
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper

/**
 * Marks this replica as needing an update, then copies {@code model}'s parameters into the
 * replicated model. The updater state is also copied when the source model has one.
 *
 * <p>Only {@link MultiLayerNetwork} and {@link ComputationGraph} replicas are handled; for any
 * other replica type only the update flag is set. {@code model} is cast to the replica's
 * concrete type, so the two must match.
 *
 * @param model the source model to copy from
 */
@Override
public void updateModel(@NonNull Model model) {
  this.shouldUpdate.set(true);

  if (replicatedModel instanceof MultiLayerNetwork) {
    MultiLayerNetwork replica = (MultiLayerNetwork) replicatedModel;
    replica.setParams(model.params().dup());

    INDArray sourceState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
    if (sourceState != null) {
      Updater replicaUpdater = replica.getUpdater();
      INDArray stateCopy = sourceState.dup();
      // commit() sits between dup() and the hand-off, presumably so the copy has completed
      // before the replica uses it; confirm.
      Nd4j.getExecutioner().commit();
      replicaUpdater.setStateViewArray(replica, stateCopy, false);
    }
  } else if (replicatedModel instanceof ComputationGraph) {
    ComputationGraph replica = (ComputationGraph) replicatedModel;
    replica.setParams(model.params().dup());

    INDArray sourceState = ((ComputationGraph) model).getUpdater().getStateViewArray();
    if (sourceState != null) {
      INDArray stateCopy = sourceState.dup();
      Nd4j.getExecutioner().commit();
      replica.getUpdater().setStateViewArray(stateCopy);
    }
  }

  Nd4j.getExecutioner().commit();
}

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11

/**
 * Marks this replica as needing an update, then copies {@code model}'s parameters into the
 * replicated model. The updater state is also copied when the source model has one.
 *
 * <p>Only {@link MultiLayerNetwork} and {@link ComputationGraph} replicas are handled; for any
 * other replica type only the update flag is set. {@code model} is cast to the replica's
 * concrete type, so the two must match.
 *
 * @param model the source model to copy from
 */
@Override
public void updateModel(@NonNull Model model) {
  this.shouldUpdate.set(true);

  if (replicatedModel instanceof MultiLayerNetwork) {
    MultiLayerNetwork replica = (MultiLayerNetwork) replicatedModel;
    replica.setParams(model.params().dup());

    INDArray sourceState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
    if (sourceState != null) {
      Updater replicaUpdater = replica.getUpdater();
      INDArray stateCopy = sourceState.dup();
      // commit() sits between dup() and the hand-off, presumably so the copy has completed
      // before the replica uses it; confirm.
      Nd4j.getExecutioner().commit();
      replicaUpdater.setStateViewArray(replica, stateCopy, false);
    }
  } else if (replicatedModel instanceof ComputationGraph) {
    ComputationGraph replica = (ComputationGraph) replicatedModel;
    replica.setParams(model.params().dup());

    INDArray sourceState = ((ComputationGraph) model).getUpdater().getStateViewArray();
    if (sourceState != null) {
      INDArray stateCopy = sourceState.dup();
      Nd4j.getExecutioner().commit();
      replica.getUpdater().setStateViewArray(stateCopy);
    }
  }

  Nd4j.getExecutioner().commit();
}

代码示例来源:origin: CampagneLaboratory/variationanalysis

MultiLayerNetwork savedNet = savedModel instanceof MultiLayerNetwork ?
    (MultiLayerNetwork) savedModel : null;
if (savedNet == null || savedNet.getUpdater() == null || savedNet.params() == null) {
  System.err.println("Unable to load model or updater from " + args().previousModelPath);
} else {
  net.setUpdater(savedNet.getUpdater());
  net.setParams(savedNet.params());

代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn

INDArray updaterView = network.getUpdater().getStateViewArray();
if (updaterView != null) {

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper

replicatedModel.setParams(originalModel.params().unsafeDuplication(true));
Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
Updater updaterOrigina = ((MultiLayerNetwork) originalModel).getUpdater();
    Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
    if (updaterReplica.getStateViewArray() != null)
      Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(),

代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11

replicatedModel.setParams(originalModel.params().unsafeDuplication(true));
Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
Updater updaterOrigina = ((MultiLayerNetwork) originalModel).getUpdater();
    Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
    if (updaterReplica.getStateViewArray() != null)
      Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(),

相关文章