Training a YOLOv2 Model with Deeplearning4j
1. Adding the pom.xml dependencies
The GPU dependencies (nd4j-cuda-8.0-platform and deeplearning4j-cuda-8.0) require a CUDA 8.0 installation; to train on the CPU instead, comment them out and uncomment the nd4j-native-platform dependency.
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>cn.dearcloud</groupId>
    <artifactId>train-yolo-for-java</artifactId>
    <version>1.0-SNAPSHOT</version>
    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>1.0.0-beta</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-modelimport</artifactId>
            <version>1.0.0-beta</version>
        </dependency>
        <!--GPU-->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-cuda-8.0-platform</artifactId>
            <version>1.0.0-beta</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-cuda-8.0</artifactId>
            <version>1.0.0-beta</version>
        </dependency>
        <!--CPU-->
        <!--<dependency>-->
        <!--<groupId>org.nd4j</groupId>-->
        <!--<artifactId>nd4j-native-platform</artifactId>-->
        <!--<version>1.0.0-beta</version>-->
        <!--</dependency>-->
        <!--Log-->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.16.22</version>
        </dependency>
        <dependency>
            <groupId>org.apache.logging.log4j</groupId>
            <artifactId>log4j-slf4j-impl</artifactId>
            <version>2.11.0</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.7.0</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                    <encoding>UTF-8</encoding>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
2. Reading the dataset
Assume the dataset folder is at the path below. It contains the images plus, for each image, a .txt file with the same base name that records the annotated objects. Each line describes one annotated object, with the fields in the order: Label,X,Y,Width,Height.
D:\\Project\\AIProject\\train-yolo-for-java\\docs\\pupil-datasets
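For example, if the folder contains an image pupil-0001.jpg, its companion pupil-0001.txt could look like the following (the file name, label, and numbers are invented for illustration; X,Y is the top-left corner in pixels, followed by the box width and height):

pupil,112,84,36,36
pupil,260,90,40,38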
3. Writing the annotation-loading code
package cn.dearcloud.provider;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.datavec.image.recordreader.objdetect.ImageObject;
import org.datavec.image.recordreader.objdetect.ImageObjectLabelProvider;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
/**
 * Loads the bounding-box annotations for an image from the .txt file that sits next to it
 * (same base name), one "label,x,y,width,height" entry per line.
 */
public class CnnLabelProvider implements ImageObjectLabelProvider {

    public CnnLabelProvider() {
    }

    @Override
    public List<ImageObject> getImageObjectsForPath(String path) {
        try {
            List<ImageObject> imageObjects = new ArrayList<>();
            // The annotation file has the same base name as the image, with a .txt extension
            File labelFile = new File(FilenameUtils.getFullPath(path), FilenameUtils.getBaseName(path) + ".txt");
            List<String> lines = FileUtils.readLines(labelFile, "UTF-8");
            for (String line : lines) {
                // label,x,y,w,h
                String[] arr = line.split(",");
                if (arr.length == 5) {
                    String labelName = arr[0];
                    int x = Integer.parseInt(arr[1]);
                    int y = Integer.parseInt(arr[2]);
                    int w = Integer.parseInt(arr[3]);
                    int h = Integer.parseInt(arr[4]);
                    // ImageObject expects the top-left and bottom-right corners
                    imageObjects.add(new ImageObject(x, y, x + w, y + h, labelName));
                }
            }
            return imageObjects;
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    @Override
    public List<ImageObject> getImageObjectsForPath(URI uri) {
        return getImageObjectsForPath(new File(uri).getPath());
    }
}
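Before wiring the provider into a record reader, it can be useful to sanity-check it against a single image. A minimal sketch (the class name and image path below are hypothetical; any image from the dataset with a matching .txt file will do):

package cn.dearcloud.provider;

import org.datavec.image.recordreader.objdetect.ImageObject;

import java.util.List;

public class CnnLabelProviderCheck {
    public static void main(String[] args) {
        CnnLabelProvider provider = new CnnLabelProvider();
        // Hypothetical image path; the provider reads the pupil-0001.txt lying next to it
        List<ImageObject> objects = provider.getImageObjectsForPath(
                "D:\\Project\\AIProject\\train-yolo-for-java\\docs\\pupil-datasets\\pupil-0001.jpg");
        for (ImageObject o : objects) {
            // Print label and the two corners parsed from the annotation line
            System.out.println(o.getLabel() + " (" + o.getX1() + "," + o.getY1() + ") -> (" + o.getX2() + "," + o.getY2() + ")");
        }
    }
}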
4. Writing the YOLOv2 training code
package cn.dearcloud;
import cn.dearcloud.provider.CnnLabelProvider;
import lombok.extern.log4j.Log4j2;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_imgproc;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.api.records.metadata.RecordMetaDataImageURI;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.objdetect.DetectedObject;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.model.YOLO2;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Adam;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Random;
import static org.bytedeco.javacpp.opencv_core.FONT_HERSHEY_DUPLEX;
import static org.bytedeco.javacpp.opencv_imgproc.resize;
import static org.opencv.core.CvType.CV_8U;
@Log4j2
public class Yolo2Trainer {
    // input image, grid, and output-layer parameters (the prior boxes are the TinyYOLO defaults)
    int width = 480;
    int height = 320;
    int nChannels = 3;
    int gridWidth = 15;
    int gridHeight = 10;
    int nClasses = 1;
    int nBoxes = 5;
    double lambdaNoObj = 0.5;
    double lambdaCoord = 5.0;
    double[][] priorBoxes = {{1.08, 1.19}, {3.42, 4.41}, {6.63, 11.38}, {9.42, 5.11}, {16.62, 10.52}};
    double detectionThreshold = 0.3;
    // parameters for the training phase
    int batchSize = 1;
    int nEpochs = 50;
    double learningRate = 1e-3;
    double lrMomentum = 0.9;
    public void read() throws IOException, InterruptedException {
        String datasetsDir = "D:\\Project\\AIProject\\train-yolo-for-java\\docs\\pupil-datasets";
        File imageDir = new File(datasetsDir);
        log.info("Load data...");
        // Split the dataset into training and test sets
        Random rng = new Random();
        FileSplit fileSplit = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng);
        InputSplit[] data = fileSplit.sample(null, 0.8, 0.2);
        InputSplit trainData = data[0];
        InputSplit testData = data[1];
        // Custom ImageObjectLabelProvider implementation (see section 3)
        CnnLabelProvider labelProvider = new CnnLabelProvider();
        ObjectDetectionRecordReader trainRecordReader = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, labelProvider);
        trainRecordReader.initialize(trainData); // returned values: 4d array, with dimensions [minibatch, 4+C, h, w]
        ObjectDetectionRecordReader testRecordReader = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, labelProvider);
        testRecordReader.initialize(testData);
        // ObjectDetectionRecordReader performs regression, so we need to specify it here
        RecordReaderDataSetIterator trainDataSetIterator = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, 1, true);
        trainDataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));
        RecordReaderDataSetIterator testDataSetIterator = new RecordReaderDataSetIterator(testRecordReader, 1, 1, 1, true);
        testDataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));
        ComputationGraph model;
        String modelFilename = "model_surface_YOLO2.zip";
        if (new File(modelFilename).exists()) {
            log.info("Load model...");
            model = ModelSerializer.restoreComputationGraph(modelFilename);
        } else {
            ComputationGraph pretrained = (ComputationGraph) YOLO2.builder().build().initPretrained();
            INDArray priors = org.nd4j.linalg.factory.Nd4j.create(priorBoxes);
            FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                    .seed(1234)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                    .gradientNormalizationThreshold(1.0)
                    .updater(new Adam.Builder().learningRate(1e-3).build())
                    .l2(0.00001)
                    .activation(Activation.IDENTITY)
                    .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                    .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                    .build();
            model = new TransferLearning.GraphBuilder(pretrained)
                    .fineTuneConfiguration(fineTuneConf)
                    .removeVertexKeepConnections("conv2d_23")
                    .addLayer("convolution2d_23",
                            new ConvolutionLayer.Builder(1, 1)
                                    .nIn(1024)
                                    .nOut(nBoxes * (5 + nClasses))
                                    .stride(1, 1)
                                    .convolutionMode(ConvolutionMode.Same)
                                    .weightInit(WeightInit.UNIFORM)
                                    .hasBias(false)
                                    .activation(Activation.IDENTITY)
                                    .build(),
                            "leaky_re_lu_22")
                    .addLayer("outputs",
                            new Yolo2OutputLayer.Builder()
                                    .boundingBoxPriors(priors)
                                    // lambbaNoObj (sic) is how the method is spelled in this DL4J version
                                    .lambbaNoObj(lambdaNoObj).lambdaCoord(lambdaCoord)
                                    .build(),
                            "convolution2d_23")
                    .setOutputs("outputs")
                    .build();
            // InputType.convolutional expects (height, width, channels)
            System.out.println(model.summary(InputType.convolutional(height, width, nChannels)));
            // Log the score after every iteration
            model.setListeners(new org.deeplearning4j.optimize.listeners.ScoreIterationListener(1));
            // Start training
            for (int i = 0; i < nEpochs; i++) {
                trainDataSetIterator.reset();
                while (trainDataSetIterator.hasNext()) {
                    model.fit(trainDataSetIterator.next());
                }
                log.info("*** Completed epoch {} ***", i);
            }
            ModelSerializer.writeModel(model, modelFilename, true);
        }
        // Visualization and testing
        NativeImageLoader imageLoader = new NativeImageLoader();
        CanvasFrame frame = new CanvasFrame("RedBloodCellDetection");
        OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
        org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
        List<String> labels = trainDataSetIterator.getLabels();
        testDataSetIterator.setCollectMetaData(true);
        while (testDataSetIterator.hasNext() && frame.isVisible()) {
            org.nd4j.linalg.dataset.DataSet ds = testDataSetIterator.next();
            RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
            INDArray features = ds.getFeatures();
            INDArray results = model.outputSingle(features);
            List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
            File file = new File(metadata.getURI());
            log.info(file.getName() + ": " + objs);
            opencv_core.Mat mat = imageLoader.asMat(features);
            opencv_core.Mat convertedMat = new opencv_core.Mat();
            mat.convertTo(convertedMat, CV_8U, 255, 0);
            int w = metadata.getOrigW() * 2;
            int h = metadata.getOrigH() * 2;
            opencv_core.Mat image = new opencv_core.Mat();
            resize(convertedMat, image, new opencv_core.Size(w, h));
            for (DetectedObject obj : objs) {
                double[] xy1 = obj.getTopLeftXY();
                double[] xy2 = obj.getBottomRightXY();
                String label = labels.get(obj.getPredictedClass());
                // detected coordinates are in grid units, so scale them to the displayed image size
                int x1 = (int) Math.round(w * xy1[0] / gridWidth);
                int y1 = (int) Math.round(h * xy1[1] / gridHeight);
                int x2 = (int) Math.round(w * xy2[0] / gridWidth);
                int y2 = (int) Math.round(h * xy2[1] / gridHeight);
                opencv_imgproc.rectangle(image, new opencv_core.Point(x1, y1), new opencv_core.Point(x2, y2), opencv_core.Scalar.RED);
                opencv_imgproc.putText(image, label, new opencv_core.Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, opencv_core.Scalar.GREEN);
            }
            frame.setTitle(new File(metadata.getURI()).getName() + " - RedBloodCellDetection");
            frame.setCanvasSize(w, h);
            frame.showImage(converter.convert(image));
            frame.waitKey();
        }
        frame.dispose();
    }
}
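The trainer class above only defines read(), so it still needs an entry point. A minimal sketch (the Yolo2TrainerApp class name is made up; it simply calls the method shown above):

package cn.dearcloud;

public class Yolo2TrainerApp {
    public static void main(String[] args) throws Exception {
        // Trains (or loads model_surface_YOLO2.zip if it already exists) and then runs the visualization loop
        new Yolo2Trainer().read();
    }
}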
5. Bonus: the TinyYOLO training code
package cn.dearcloud;
import lombok.extern.log4j.Log4j2;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_imgproc;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.records.metadata.RecordMetaDataImageURI;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.objdetect.DetectedObject;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.model.TinyYOLO;
import org.deeplearning4j.zoo.model.YOLO2;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Nesterovs;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Random;
import static org.bytedeco.javacpp.opencv_core.FONT_HERSHEY_DUPLEX;
import static org.bytedeco.javacpp.opencv_imgproc.resize;
import static org.opencv.core.CvType.CV_8U;
/**
 * Implemented with reference to: https://blog.csdn.net/u011669700/article/details/79886619
 */
@Log4j2
public class TinyYoloTrainer {
    // parameters matching the pretrained TinyYOLO model
    int width = 416;
    int height = 416;
    int nChannels = 3;
    int gridWidth = 13;
    int gridHeight = 13;
    int numClasses = 1;
    // parameters for the Yolo2OutputLayer
    int nBoxes = 5;
    double lambdaNoObj = 0.5;
    double lambdaCoord = 5.0;
    double[][] priorBoxes = {{2, 2}, {2, 2}, {2, 2}, {2, 2}, {2, 2}};
    double detectionThreshold = 0.3;
    // parameters for the training phase
    int batchSize = 2;
    int nEpochs = 50;
    double learningRate = 1e-3;
    double lrMomentum = 0.9;
    public void read() throws IOException, InterruptedException {
        String dataDir = new ClassPathResource("/datasets").getFile().getPath();
        File imageDir = new File(dataDir, "JPEGImages");
        log.info("Load data...");
        // Split the dataset, keeping only images that have a matching VOC annotation file
        Random rng = new Random();
        FileSplit fileSplit = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng);
        InputSplit[] data = fileSplit.sample(new RandomPathFilter(rng) {
            @Override
            protected boolean accept(String name) {
                boolean isXmlExist = false;
                try {
                    isXmlExist = new File(new URI(name.replace("JPEGImages", "Annotations").replace(".jpg", ".xml"))).exists();
                } catch (URISyntaxException e) {
                    e.printStackTrace();
                }
                return isXmlExist;
            }
        }, 0.8, 0.2);
        InputSplit trainData = data[0];
        InputSplit testData = data[1];
        // VocLabelProvider parses Pascal VOC-style XML annotations (see the sample annotation after this class);
        // you can also implement ImageObjectLabelProvider yourself as in section 3
        VocLabelProvider labelProvider = new VocLabelProvider(dataDir);
        ObjectDetectionRecordReader trainRecordReader = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, labelProvider);
        trainRecordReader.initialize(trainData); // returned values: 4d array, with dimensions [minibatch, 4+C, h, w]
        ObjectDetectionRecordReader testRecordReader = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, labelProvider);
        testRecordReader.initialize(testData);
        // ObjectDetectionRecordReader performs regression, so we need to specify it here
        RecordReaderDataSetIterator trainDataSetIterator = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, 1, true);
        trainDataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1, 8));
        // use batch size 1 here so each test DataSet holds a single example for the visualization loop below
        RecordReaderDataSetIterator testDataSetIterator = new RecordReaderDataSetIterator(testRecordReader, 1, 1, 1, true);
        testDataSetIterator.setPreProcessor(new ImagePreProcessingScaler(0, 1, 8));
        String modelFilename = "model_yolov2.zip";
        ComputationGraph pretrained = (ComputationGraph) TinyYOLO.builder().build().initPretrained();
        INDArray priors = org.nd4j.linalg.factory.Nd4j.create(priorBoxes);
        FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder()
                .seed(100)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(1.0)
                .updater(Nesterovs.builder().learningRate(learningRate).momentum(lrMomentum).build())
                .activation(Activation.IDENTITY)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .build();
        ComputationGraph model = new TransferLearning.GraphBuilder(pretrained)
                .fineTuneConfiguration(fineTuneConfiguration)
                .removeVertexKeepConnections("conv2d_9")
                .addLayer("convolution2d_9",
                        new ConvolutionLayer.Builder(1, 1)
                                .nIn(1024)
                                .nOut(nBoxes * (5 + numClasses))
                                .stride(1, 1)
                                .convolutionMode(ConvolutionMode.Same)
                                .weightInit(WeightInit.UNIFORM)
                                .hasBias(false)
                                .activation(Activation.IDENTITY)
                                .build(),
                        "leaky_re_lu_8")
                .addLayer("outputs", new Yolo2OutputLayer.Builder().lambbaNoObj(lambdaNoObj).lambdaCoord(lambdaCoord).boundingBoxPriors(priors).build(),
                        "convolution2d_9")
                .setOutputs("outputs")
                .build();
        // Log the score after every iteration
        model.setListeners(new org.deeplearning4j.optimize.listeners.ScoreIterationListener(1));
        // Start training
        for (int i = 0; i < nEpochs; i++) {
            trainDataSetIterator.reset();
            while (trainDataSetIterator.hasNext()) {
                model.fit(trainDataSetIterator.next());
            }
            log.info("*** Completed epoch {} ***", i);
        }
        ModelSerializer.writeModel(model, modelFilename, true);
        // Visualization and testing
        NativeImageLoader imageLoader = new NativeImageLoader();
        CanvasFrame frame = new CanvasFrame("RedBloodCellDetection");
        OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
        org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
        List<String> labels = trainDataSetIterator.getLabels();
        testDataSetIterator.setCollectMetaData(true);
        while (testDataSetIterator.hasNext() && frame.isVisible()) {
            org.nd4j.linalg.dataset.DataSet ds = testDataSetIterator.next();
            RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
            INDArray features = ds.getFeatures();
            INDArray results = model.outputSingle(features);
            List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
            File file = new File(metadata.getURI());
            log.info(file.getName() + ": " + objs);
            opencv_core.Mat mat = imageLoader.asMat(features);
            opencv_core.Mat convertedMat = new opencv_core.Mat();
            mat.convertTo(convertedMat, CV_8U, 255, 0);
            int w = metadata.getOrigW() * 2;
            int h = metadata.getOrigH() * 2;
            opencv_core.Mat image = new opencv_core.Mat();
            resize(convertedMat, image, new opencv_core.Size(w, h));
            for (DetectedObject obj : objs) {
                double[] xy1 = obj.getTopLeftXY();
                double[] xy2 = obj.getBottomRightXY();
                String label = labels.get(obj.getPredictedClass());
                // detected coordinates are in grid units, so scale them to the displayed image size
                int x1 = (int) Math.round(w * xy1[0] / gridWidth);
                int y1 = (int) Math.round(h * xy1[1] / gridHeight);
                int x2 = (int) Math.round(w * xy2[0] / gridWidth);
                int y2 = (int) Math.round(h * xy2[1] / gridHeight);
                opencv_imgproc.rectangle(image, new opencv_core.Point(x1, y1), new opencv_core.Point(x2, y2), opencv_core.Scalar.RED);
                opencv_imgproc.putText(image, label, new opencv_core.Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, opencv_core.Scalar.GREEN);
            }
            frame.setTitle(new File(metadata.getURI()).getName() + " - RedBloodCellDetection");
            frame.setCanvasSize(w, h);
            frame.showImage(converter.convert(image));
            frame.waitKey();
        }
        frame.dispose();
    }
}
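For reference, VocLabelProvider expects the Pascal VOC layout used above (images under JPEGImages, XML files under Annotations). A minimal, hand-written Annotations/sample-0001.xml with a single object might look like the following (the file name, sizes, and coordinates are invented for illustration):

<annotation>
    <folder>JPEGImages</folder>
    <filename>sample-0001.jpg</filename>
    <size>
        <width>416</width>
        <height>416</height>
        <depth>3</depth>
    </size>
    <object>
        <name>pupil</name>
        <bndbox>
            <xmin>120</xmin>
            <ymin>84</ymin>
            <xmax>156</xmax>
            <ymax>120</ymax>
        </bndbox>
    </object>
</annotation>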
The training log output is shown in the original post:
https://www.dearcloud.cn/2018/12/19/20200310-cnblogs-old-posts/20181219-使用Deeplearning4j训练YOLOV2模型/