Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,38 @@
* 用于在测试用例前执行初始化 Sql 语句。
*
* @author 易文渊
* @author 季聿阶
* @since 2024-07-21
*/
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface Sql {
/**
* 获取 sql 脚本文件路径。
* 获取前置 SQL 脚本文件路径。
*
* @return 表示 sql 脚本文件路径集合的 {@code String[]}。
* @return 表示前置 SQL 脚本文件路径集合的 {@link String}{@code []}。
*/
String[] scripts();
String[] before() default {};

/**
* 获取后置 SQL 脚本文件路径。
*
* @return 表示后置 SQL 脚本文件路径集合的 {@link String}{@code []}。
*/
String[] after() default {};

/**
* 获取 SQL 脚本执行位置。
*/
enum Position {
/**
* 在前置执行。
*/
BEFORE,

/**
* 在后置执行。
*/
AFTER
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -64,7 +66,7 @@ public TestPlugin(FitRuntime runtime, TestContextConfiguration configuration) {
Validation.notNull(configuration, "The configuration to create test plugin cannot be null.");
this.packageScanner = this.scanner((packageScanner, clazz) -> this.onClassDetected(packageScanner, clazz,
// 包含的类已经提前注册,因此需要将包含的和排除的类进行合并。
Stream.concat(Arrays.stream(this.configuration.includeClasses()),
Stream.concat(this.configuration.includeClasses().keySet().stream(),
Arrays.stream(this.configuration.excludeClasses())).collect(Collectors.toSet())));
}

Expand Down Expand Up @@ -93,6 +95,7 @@ protected void registerSystemBeans() {
@Override
protected void scanBeans() {
this.registerBeans(this.configuration.includeClasses());
this.configuration.actions().forEach(action -> action.accept(this));
this.scan(this.configuration.scannedPackages());
this.registerMockedBeans(this.configuration.mockedBeanFields());
}
Expand All @@ -111,10 +114,22 @@ private void onClassDetected(PackageScanner scanner, Class<?> clazz, Set<Class<?
}
}

private void registerBeans(Class<?>[] classArray) {
Arrays.stream(classArray)
.filter(clazz -> !this.container().lookup(clazz).isPresent())
.forEach(clazz -> this.container().registry().register(clazz));
private void registerBeans(Map<Class<?>, Supplier<Object>> classes) {
classes.entrySet()
.stream()
.filter(entry -> this.container().lookup(entry.getKey()).isEmpty())
.forEach(entry -> {
if (entry.getValue() == null) {
this.container().registry().register(entry.getKey());
} else {
Object bean = entry.getValue().get();
if (bean == null) {
this.container().registry().register(entry.getKey());
} else {
this.container().registry().register(bean);
}
}
});
}

private void scan(Set<String> basePackages) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,39 @@

package modelengine.fitframework.test.domain.listener;

import modelengine.fitframework.ioc.BeanContainer;
import modelengine.fitframework.ioc.BeanNotFoundException;
import modelengine.fitframework.test.annotation.EnableDataSource;
import modelengine.fitframework.test.domain.TestContext;
import modelengine.fitframework.test.domain.resolver.TestContextConfiguration;
import modelengine.fitframework.test.domain.util.AnnotationUtils;
import modelengine.fitframework.util.MapBuilder;

import org.h2.jdbcx.JdbcConnectionPool;

import java.util.Optional;
import java.util.function.Supplier;

import javax.sql.DataSource;

/**
* 用于注入 dataSource 的监听器。
*
* @author 易文渊
* @author 季聿阶
* @since 2024-07-21
*/
public class DataSourceListener implements TestListener {
@Override
public void beforeTestClass(TestContext context) {
Class<?> clazz = context.testClass();
public Optional<TestContextConfiguration> config(Class<?> clazz) {
Optional<EnableDataSource> annotationOption = AnnotationUtils.getAnnotation(clazz, EnableDataSource.class);
if (!annotationOption.isPresent()) {
return;
}
BeanContainer beanContainer = context.plugin().container();
try {
beanContainer.beans().get(DataSource.class);
} catch (BeanNotFoundException e) {
EnableDataSource enableDataSource = annotationOption.get();
DataSource dataSource = JdbcConnectionPool.create(enableDataSource.model().getUrl(), "sa", "sa");
beanContainer.registry().register(dataSource);
if (annotationOption.isEmpty()) {
return Optional.empty();
}
TestContextConfiguration customConfig = TestContextConfiguration.custom()
.testClass(clazz)
.includeClasses(MapBuilder.<Class<?>, Supplier<Object>>get().put(DataSource.class, () -> {
EnableDataSource enableDataSource = annotationOption.get();
return JdbcConnectionPool.create(enableDataSource.model().getUrl(), "sa", "sa");
}).build())
.build();
return Optional.of(customConfig);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import modelengine.fitframework.test.domain.mvc.request.MockRequestBuilder;
import modelengine.fitframework.test.domain.resolver.TestContextConfiguration;
import modelengine.fitframework.test.domain.util.AnnotationUtils;
import modelengine.fitframework.util.MapBuilder;
import modelengine.fitframework.util.StringUtils;
import modelengine.fitframework.util.ThreadUtils;

Expand All @@ -26,6 +27,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;

/**
* 用于注入 mockMvc 的监听器。
Expand All @@ -50,12 +52,12 @@ public MockMvcListener(int port) {

@Override
public Optional<TestContextConfiguration> config(Class<?> clazz) {
if (!AnnotationUtils.getAnnotation(clazz, EnableMockMvc.class).isPresent()) {
if (AnnotationUtils.getAnnotation(clazz, EnableMockMvc.class).isEmpty()) {
return Optional.empty();
}
TestContextConfiguration configuration = TestContextConfiguration.custom()
.testClass(clazz)
.includeClasses(new Class[] {MockController.class})
.includeClasses(MapBuilder.<Class<?>, Supplier<Object>>get().put(MockController.class, null).build())
.scannedPackages(DEFAULT_SCAN_PACKAGES)
.build();
return Optional.of(configuration);
Expand All @@ -64,7 +66,7 @@ public Optional<TestContextConfiguration> config(Class<?> clazz) {
@Override
public void beforeTestClass(TestContext context) {
Class<?> testClass = context.testClass();
if (!AnnotationUtils.getAnnotation(testClass, EnableMockMvc.class).isPresent()) {
if (AnnotationUtils.getAnnotation(testClass, EnableMockMvc.class).isEmpty()) {
return;
}
MockMvc mockMvc = new MockMvc(this.port);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,26 @@

package modelengine.fitframework.test.domain.listener;

import modelengine.fitframework.plugin.Plugin;
import modelengine.fitframework.test.annotation.Sql;
import modelengine.fitframework.test.domain.TestContext;
import modelengine.fitframework.test.domain.resolver.TestContextConfiguration;
import modelengine.fitframework.util.IoUtils;

import java.io.IOException;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import java.util.Optional;

import javax.sql.DataSource;

/**
* 用于执行 SQL 脚本。
*
* @author 易文渊
* @author 季聿阶
* @since 2024-07-21
*/
public class SqlExecuteListener implements TestListener {
Expand All @@ -29,34 +34,66 @@ public class SqlExecuteListener implements TestListener {
private Sql globalSql;

@Override
public void beforeTestClass(TestContext context) {
Class<?> testClass = context.testClass();
this.globalSql = testClass.getAnnotation(Sql.class);
public Optional<TestContextConfiguration> config(Class<?> clazz) {
this.globalSql = clazz.getAnnotation(Sql.class);
if (this.globalSql == null) {
return Optional.empty();
}
TestContextConfiguration configuration =
TestContextConfiguration.custom().testClass(clazz).actions(List.of(this::executeAction)).build();
return Optional.of(configuration);
}

private void executeAction(Plugin plugin) {
if (this.globalSql == null) {
return;
}
executeSql(plugin, this.globalSql, Sql.Position.BEFORE);
}

@Override
public void beforeTestMethod(TestContext context) {
execSql(globalSql, context);
execMethodSql(context);
execMethodSql(context, Sql.Position.BEFORE);
}

@Override
public void afterTestMethod(TestContext context) {
execMethodSql(context, Sql.Position.AFTER);
}

private static void execMethodSql(TestContext context) {
private static void execMethodSql(TestContext context, Sql.Position position) {
Method method = context.testMethod();
Sql sql = method.getAnnotation(Sql.class);
execSql(sql, context);
executeSql(context.plugin(), sql, position);
}

@Override
public void afterTestClass(TestContext context) {
Class<?> testClass = context.testClass();
Sql sql = testClass.getAnnotation(Sql.class);
executeSql(context.plugin(), sql, Sql.Position.AFTER);
}

private static void execSql(Sql sql, TestContext context) {
private static void executeSql(Plugin plugin, Sql sql, Sql.Position position) {
if (sql == null) {
return;
}
DataSource dataSource = context.plugin().container().beans().get(DataSource.class);
DataSource dataSource = plugin.container().beans().get(DataSource.class);
try (Connection connection = dataSource.getConnection()) {
for (String script : sql.scripts()) {
String[] scripts = getScripts(sql, position);
for (String script : scripts) {
connection.createStatement().execute(IoUtils.content(CLASS_LOADER, script));
}
} catch (SQLException | IOException e) {
throw new IllegalStateException("Fail to execute sql.", e);
throw new IllegalStateException("Failed to execute sql.", e);
}
}

private static String[] getScripts(Sql sql, Sql.Position position) {
if (position == Sql.Position.BEFORE) {
return sql.before();
} else {
return sql.after();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* 默认的单测类解析器。
Expand All @@ -30,8 +32,7 @@
* @since 2023-01-17
*/
public class DefaultTestClassResolver implements TestClassResolver {
private static final Set<String> DEFAULT_SCAN_PACKAGES = new HashSet<>(Arrays.asList(
"modelengine.fit.value",
private static final Set<String> DEFAULT_SCAN_PACKAGES = new HashSet<>(Arrays.asList("modelengine.fit.value",
"modelengine.fit.serialization",
"modelengine.fitframework.validation"));

Expand All @@ -41,7 +42,8 @@ public TestContextConfiguration resolve(Class<?> clazz) {
Class<?>[] includeClasses = this.resolveIncludeClasses(testConfigurationClass);
return TestContextConfiguration.custom()
.testClass(clazz)
.includeClasses(includeClasses)
.includeClasses(Stream.of(includeClasses)
.collect(Collectors.toMap(Function.identity(), key -> () -> null)))
.excludeClasses(this.resolveExcludeClasses(clazz))
.scannedPackages(this.scanBeans(includeClasses))
.mockedBeanFields(this.scanMockBeansFieldSet(clazz))
Expand Down Expand Up @@ -86,7 +88,7 @@ private Set<String> scanBeans(Class<?>[] classes) {

private Set<String> getBasePackages(Class<?> clazz) {
Optional<ScanPackages> opScanPackagesAnnotation = AnnotationUtils.getAnnotation(clazz, ScanPackages.class);
if (!opScanPackagesAnnotation.isPresent()) {
if (opScanPackagesAnnotation.isEmpty()) {
return new HashSet<>();
}
Set<String> basePackages = new HashSet<>(Arrays.asList(opScanPackagesAnnotation.get().value()));
Expand Down
Loading