目标是提供一个完整的Java机器学习(Machine Learning/ML)框架,作为人工智能在学术界与工业界的桥梁. 让相关领域的研发人员能够在各种软硬件环境/数据结构/算法/模型之间无缝切换. 涵盖了从数据处理到模型的训练与评估各个环节,支持硬件加速和并行计算,是最快最全的Java机器学习库.
希望路过的同学,顺手给JStarCraft框架点个Star,算是对作者的一种鼓励吧!
JStarCraft AI是一个机器学习的轻量级框架.遵循Apache 2.0协议.
在学术界,绝大多数研究人员使用的编程语言是Python.
在工业界,绝大多数开发人员使用的编程语言是Java.
JStarCraft AI是一个基于Java语言的机器学习工具包,由一系列的数据结构,算法和模型组成.
目标是作为在学术界与工业界从事机器学习研发的相关人员之间的桥梁.普及机器学习在Java领域的应用.
作者 | 洪钊桦 |
---|---|
[email protected], [email protected] |
JStarCraft AI框架各个模块之间的关系:
<dependency>
<groupId>com.jstarcraft</groupId>
<artifactId>ai</artifactId>
<version>1.0</version>
</dependency>
compile group: 'com.jstarcraft', name: 'ai', version: '1.0'
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-9.0-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-9.1-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-9.2-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.0-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1-platform</artifactId>
<version>1.0.0-beta3</version>
</dependency>
// 获取默认环境上下文
EnvironmentContext context = EnvironmentContext.getContext();
// 在环境上下文中执行任务
Future<?> task = context.doTask(() - > {
int dimension = 10;
MathMatrix leftMatrix = getRandomMatrix(dimension);
MathMatrix rightMatrix = getRandomMatrix(dimension);
MathMatrix dataMatrix = getZeroMatrix(dimension);
dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.PARALLEL);
});
用户(User) | 旧手机类型(Item) | 新手机类型(Item) | 评分(Score) |
---|---|---|---|
Google Fan | Android | Android | 3 |
Google Fan | Android | IOS | 1 |
Google Fan | IOS | Android | 5 |
Apple Fan | IOS | IOS | 3 |
Apple Fan | Android | IOS | 5 |
Apple Fan | IOS | Android | 1 |
定性(User) | 定性(Item) | 定性(Item) | 定量(Score) |
---|---|---|---|
0 | 0 | 0 | 3 |
0 | 0 | 1 | 1 |
0 | 1 | 0 | 5 |
1 | 1 | 1 | 3 |
1 | 0 | 1 | 5 |
1 | 1 | 0 | 1 |
数据转换器(DataConverter)负责各种各样的格式转换为JStarCraft AI框架能够处理的数据模块(DataModule).
JStarCraft AI框架各个转换器与其它系统之间的关系:
// 定性属性
Map<String, Class<?>> qualityDifinitions = new HashMap<>();
qualityDifinitions.put("user", String.class);
qualityDifinitions.put("item", String.class);
// 定量属性
Map<String, Class<?>> quantityDifinitions = new HashMap<>();
quantityDifinitions.put("score", float.class);
DataSpace space = new DataSpace(qualityDifinitions, quantityDifinitions);
TreeMap<Integer, String> configuration = new TreeMap<>();
configuration.put(1, "user");
configuration.put(3, "item");
configuration.put(4, "score");
DataModule module = space.makeDenseModule("module", configuration, 1000);
JStarCraft AI框架兼容的格式
// ARFF转换器
ArffConverter converter = new ArffConverter(space.getQualityAttributes(), space.getQuantityAttributes());
// 获取流
File file = new File(this.getClass().getResource("module.arff").toURI());
InputStream stream = new FileInputStream(file);
// 转换数据
int count = converter.convert(module, stream, null, null, null);
// CSV转换器
CsvConverter converter = new CsvConverter(',', space.getQualityAttributes(), space.getQuantityAttributes());
// 获取流
File file = new File(this.getClass().getResource("module.csv").toURI());
InputStream stream = new FileInputStream(file);
// 转换数据
int count = converter.convert(module, stream, null, null, null);
// JSON转换器
JsonConverter converter = new JsonConverter(space.getQualityAttributes(), space.getQuantityAttributes());
// 获取流
File file = new File(this.getClass().getResource("module.json").toURI());
InputStream stream = new FileInputStream(file);
// 转换数据
int count = converter.convert(module, stream, null, null, null);
// HQL转换器
QueryConverter converter = new QueryConverter(space.getQualityAttributes(), space.getQuantityAttributes());
// 获取游标
String selectDataHql = "select data.user, data.leftItem, data.rightItem, data.score from MockData data";
Session session = sessionFactory.openSession();
Query query = session.createQuery(selectDataHql);
ScrollableResults iterator = query.scroll();
// 转换数据
int count = converter.convert(module, iterator, null, null, null);
session.close();
// SQL转换器
QueryConverter converter = new QueryConverter(space.getQualityAttributes(), space.getQuantityAttributes());
// 获取游标
String selectDataSql = "select user, leftItem, rightItem, score from MockData";
Session session = sessionFactory.openSession();
Query query = session.createQuery(selectDataSql);
ScrollableResults iterator = query.scroll();
// 转换数据
int count = converter.convert(module, iterator, null, null, null);
session.close();