本文整理了Java中org.deeplearning4j.nn.multilayer.MultiLayerNetwork.getUpdater()方法的一些代码示例，展示了MultiLayerNetwork.getUpdater()的具体用法。这些代码示例主要来源于GitHub、Stack Overflow、Maven等平台，是从一些精选项目中提取出来的，具有较强的参考意义，能在一定程度上帮助到你。MultiLayerNetwork.getUpdater()方法的具体详情如下：
包路径:org.deeplearning4j.nn.multilayer.MultiLayerNetwork
类名称:MultiLayerNetwork
方法名:getUpdater
[英]Get the updater for this MultiLayerNetwork
[中]获取此多层网络（MultiLayerNetwork）的更新器（Updater）
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
 * Returns the flattened state view of this network's updater.
 *
 * @return the updater's state view array, or {@code null} when {@link #getUpdater()} returns
 *         {@code null}
 */
@Override
public INDArray updaterState() {
    if (getUpdater() == null) {
        return null;
    }
    return getUpdater().getStateViewArray();
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11
protected ParallelWrapper(Model model, int workers, int prefetchSize) {
this.model = model;
this.workers = workers;
this.prefetchSize = prefetchSize;
if (this.model instanceof MultiLayerNetwork) {
((MultiLayerNetwork) this.model).getUpdater();
} else if (this.model instanceof ComputationGraph) {
((ComputationGraph) this.model).getUpdater();
}
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper
protected ParallelWrapper(Model model, int workers, int prefetchSize) {
this.model = model;
this.workers = workers;
this.prefetchSize = prefetchSize;
if (this.model instanceof MultiLayerNetwork) {
((MultiLayerNetwork) this.model).getUpdater();
} else if (this.model instanceof ComputationGraph) {
((ComputationGraph) this.model).getUpdater();
}
}
代码示例来源:origin: CampagneLaboratory/variationanalysis
/**
 * Saves a network as three separate files.
 * <p>
 * The files are: the layer-wise configuration as UTF-8 JSON, the flattened parameters written
 * with {@code Nd4j.write}, and the updater written with Java object serialization.
 *
 * @param net        the network to save
 * @param confOut    path of the JSON configuration file to write
 * @param paramOut   path of the binary parameter file to write
 * @param updaterOut path of the Java-serialized updater file to write
 * @throws IOException if any of the three files cannot be written
 */
protected static void save(MultiLayerNetwork net, String confOut, String paramOut, String updaterOut) throws IOException {
    String confJSON = net.getLayerWiseConfigurations().toJson();
    INDArray params = net.params();
    Updater updater = net.getUpdater();

    FileUtils.writeStringToFile(new File(confOut), confJSON, "UTF-8");

    try (DataOutputStream dos = new DataOutputStream(
            new BufferedOutputStream(Files.newOutputStream(Paths.get(paramOut))))) {
        Nd4j.write(params, dos);
    }

    // Both binary outputs are opened through java.nio so they share the same open/truncate
    // semantics and the same IOException subtypes on failure.
    // NOTE(review): writeObject requires the updater implementation to be Serializable.
    try (ObjectOutputStream oos = new ObjectOutputStream(
            new BufferedOutputStream(Files.newOutputStream(Paths.get(updaterOut))))) {
        oos.writeObject(updater);
    }
}
代码示例来源:origin: de.datexis/texoo-core
/**
 * Writes the updater state view of {@code net} to {@code modelPath/<name>.bin.gz}.
 * <p>
 * The file is gzip-compressed and written with {@code Nd4j.write}. Nothing is written when
 * {@code net} is neither a {@code MultiLayerNetwork} nor a {@code ComputationGraph}, when it has
 * no updater, or when the updater has no state view. I/O failures are logged, not rethrown.
 *
 * @param modelPath directory resource to write into
 * @param name      base file name; {@code ".bin.gz"} is appended
 */
@Deprecated
public void saveUpdater(Resource modelPath, String name) {
    Resource modelFile = modelPath.resolve(name + ".bin.gz");
    INDArray updaterState = null;
    // Guard against a missing updater, mirroring the null check used by updaterState().
    if (net instanceof MultiLayerNetwork) {
        MultiLayerNetwork mln = (MultiLayerNetwork) net;
        if (mln.getUpdater() != null) {
            updaterState = mln.getUpdater().getStateViewArray();
        }
    } else if (net instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) net;
        if (graph.getUpdater() != null) {
            updaterState = graph.getUpdater().getStateViewArray();
        }
    }
    if (updaterState == null) {
        return;
    }
    // Closing the DataOutputStream flushes it and closes the underlying gzip stream.
    try (DataOutputStream dos = new DataOutputStream(modelFile.getGZIPOutputStream())) {
        Nd4j.write(updaterState, dos);
    } catch (IOException ex) {
        // Pass the exception as the cause so the stack trace is not lost.
        log.error(ex.toString(), ex);
    }
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
network.getUpdater().setStateViewArray(network, updaterState, false);
} else if (gotOldUpdater && updater != null) {
network.setUpdater(updater);
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
boolean paramsEquals = network.params().equals(params());
boolean confEquals = getLayerWiseConfigurations().equals(network.getLayerWiseConfigurations());
boolean updaterEquals = getUpdater().equals(network.getUpdater());
return paramsEquals && confEquals && updaterEquals;
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
INDArray updaterState = null;
if (model instanceof MultiLayerNetwork) {
updaterState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
} else if (model instanceof ComputationGraph) {
updaterState = ((ComputationGraph) model).getUpdater().getStateViewArray();
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper
Updater updater = ((MultiLayerNetwork) model).getUpdater();
int batchSize = 0;
for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
updaters.add(workerModel.getUpdater().getStateViewArray());
batchSize += workerModel.batchSize();
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11
Updater updater = ((MultiLayerNetwork) model).getUpdater();
int batchSize = 0;
for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
updaters.add(workerModel.getUpdater().getStateViewArray());
batchSize += workerModel.batchSize();
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
/**
 * Creates a copy of this network.
 * <p>
 * The configuration is cloned and the clone is initialised with a {@code dup()} of this
 * network's parameters. The updater state view is duplicated into the clone only when this
 * network's solver is non-null. Layers that are {@code FrozenLayer}s in this network are wrapped
 * in {@code FrozenLayer} in the clone as well.
 * NOTE(review): assumes {@code dup()} yields storage independent of the original -- confirm.
 *
 * @return a new {@code MultiLayerNetwork} with cloned configuration, duplicated parameters and,
 *         when a solver exists, duplicated updater state
 */
@Override
public MultiLayerNetwork clone() {
MultiLayerConfiguration conf = this.layerWiseConfigurations.clone();
MultiLayerNetwork ret = new MultiLayerNetwork(conf);
ret.init(this.params().dup(), false);
if (solver != null) {
// Updater state is copied only when a solver exists. If solver is null the updater has not
// been initialised yet, and calling getUpdater() would force that initialisation, so this
// step is skipped (presumably to avoid initialising an updater just for cloning).
Updater u = this.getUpdater();
INDArray updaterState = u.getStateViewArray();
if (updaterState != null) {
// Give the clone its own copy of the state rather than sharing this network's array.
ret.getUpdater().setStateViewArray(ret, updaterState.dup(), false);
}
}
if (hasAFrozenLayer()) {
// Restore FrozenLayer wrappers in the clone at the same indices as in this network.
// Presumably init() builds plain layers from the configuration, so the wrappers are lost.
Layer[] clonedLayers = ret.getLayers();
for (int i = 0; i < layers.length; i++) {
if (layers[i] instanceof FrozenLayer) {
clonedLayers[i] = new FrozenLayer(ret.getLayer(i));
}
}
ret.setLayers(clonedLayers);
}
return ret;
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper
/**
 * Copies parameters, and updater state when present, from {@code model} into the replicated model.
 * <p>
 * Also sets the {@code shouldUpdate} flag. {@code model} is cast to the same concrete type as
 * {@code replicatedModel}, so passing a different network type throws {@code ClassCastException}.
 * Nothing except the flag and the final commit happens if the replica is neither a
 * {@code MultiLayerNetwork} nor a {@code ComputationGraph}.
 *
 * @param model source network whose parameters/updater state are copied
 */
@Override
public void updateModel(@NonNull Model model) {
    this.shouldUpdate.set(true);

    if (replicatedModel instanceof MultiLayerNetwork) {
        MultiLayerNetwork target = (MultiLayerNetwork) replicatedModel;
        target.setParams(model.params().dup());

        INDArray sourceState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            Updater targetUpdater = target.getUpdater();
            INDArray stateCopy = sourceState.dup();
            // NOTE(review): commit() presumably flushes queued ops so the copy is complete.
            Nd4j.getExecutioner().commit();
            targetUpdater.setStateViewArray(target, stateCopy, false);
        }
    } else if (replicatedModel instanceof ComputationGraph) {
        ComputationGraph target = (ComputationGraph) replicatedModel;
        target.setParams(model.params().dup());

        INDArray sourceState = ((ComputationGraph) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            INDArray stateCopy = sourceState.dup();
            Nd4j.getExecutioner().commit();
            ComputationGraphUpdater targetUpdater = target.getUpdater();
            targetUpdater.setStateViewArray(stateCopy);
        }
    }

    Nd4j.getExecutioner().commit();
}
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11
/**
 * Copies parameters, and updater state when present, from {@code model} into the replicated model.
 * <p>
 * Also sets the {@code shouldUpdate} flag. {@code model} is cast to the same concrete type as
 * {@code replicatedModel}, so passing a different network type throws {@code ClassCastException}.
 * Nothing except the flag and the final commit happens if the replica is neither a
 * {@code MultiLayerNetwork} nor a {@code ComputationGraph}.
 *
 * @param model source network whose parameters/updater state are copied
 */
@Override
public void updateModel(@NonNull Model model) {
    this.shouldUpdate.set(true);

    if (replicatedModel instanceof MultiLayerNetwork) {
        MultiLayerNetwork target = (MultiLayerNetwork) replicatedModel;
        target.setParams(model.params().dup());

        INDArray sourceState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            Updater targetUpdater = target.getUpdater();
            INDArray stateCopy = sourceState.dup();
            // NOTE(review): commit() presumably flushes queued ops so the copy is complete.
            Nd4j.getExecutioner().commit();
            targetUpdater.setStateViewArray(target, stateCopy, false);
        }
    } else if (replicatedModel instanceof ComputationGraph) {
        ComputationGraph target = (ComputationGraph) replicatedModel;
        target.setParams(model.params().dup());

        INDArray sourceState = ((ComputationGraph) model).getUpdater().getStateViewArray();
        if (sourceState != null) {
            INDArray stateCopy = sourceState.dup();
            Nd4j.getExecutioner().commit();
            ComputationGraphUpdater targetUpdater = target.getUpdater();
            targetUpdater.setStateViewArray(stateCopy);
        }
    }

    Nd4j.getExecutioner().commit();
}
代码示例来源:origin: CampagneLaboratory/variationanalysis
MultiLayerNetwork savedNet = savedModel instanceof MultiLayerNetwork ?
(MultiLayerNetwork) savedModel : null;
if (savedNet == null || savedNet.getUpdater() == null || savedNet.params() == null) {
System.err.println("Unable to load model or updater from " + args().previousModelPath);
} else {
net.setUpdater(savedNet.getUpdater());
net.setParams(savedNet.params());
代码示例来源:origin: org.deeplearning4j/deeplearning4j-nn
INDArray updaterView = network.getUpdater().getStateViewArray();
if (updaterView != null) {
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper
replicatedModel.setParams(originalModel.params().unsafeDuplication(true));
Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
Updater updaterOrigina = ((MultiLayerNetwork) originalModel).getUpdater();
Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
if (updaterReplica.getStateViewArray() != null)
Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(),
代码示例来源:origin: org.deeplearning4j/deeplearning4j-parallel-wrapper_2.11
replicatedModel.setParams(originalModel.params().unsafeDuplication(true));
Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
Updater updaterOrigina = ((MultiLayerNetwork) originalModel).getUpdater();
Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
if (updaterReplica.getStateViewArray() != null)
Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(),
内容来源于网络,如有侵权,请联系作者删除!