fix(security/crypto): 修复处理 MP Wrapper 时 无法加密的情况 (#4)

This commit is contained in:
cary
2024-06-18 14:14:55 +08:00
committed by GitHub
parent 8d00ae32ce
commit a235a6ea8b
3 changed files with 120 additions and 5 deletions

View File

@@ -24,5 +24,6 @@
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-core</artifactId>
</dependency>
</dependencies>
</project>

View File

@@ -53,7 +53,19 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
* @return 加密参数
*/
public Map<String, FieldEncrypt> getEncryptParams(String mappedStatementId) {
return ENCRYPT_PARAM_CACHE.computeIfAbsent(mappedStatementId, this::getEncryptParamsNoCached);
return getEncryptParams(mappedStatementId, null);
}
/**
* 获取加密参数
*
* @param mappedStatementId 映射语句 ID
* @param parameterIndex 参数数量
* @return 加密参数
*/
public Map<String, FieldEncrypt> getEncryptParams(String mappedStatementId, Integer parameterIndex) {
return ENCRYPT_PARAM_CACHE
.computeIfAbsent(mappedStatementId, it -> getEncryptParamsNoCached(mappedStatementId, parameterIndex));
}
/**
@@ -105,9 +117,10 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
* 获取参数列表(无缓存)
*
* @param mappedStatementId 映射语句 ID
* @param parameterIndex 参数数量
* @return 参数列表
*/
private Map<String, FieldEncrypt> getEncryptParamsNoCached(String mappedStatementId) {
private Map<String, FieldEncrypt> getEncryptParamsNoCached(String mappedStatementId, Integer parameterIndex) {
try {
String className = CharSequenceUtil.subBefore(mappedStatementId, StringConstants.DOT, true);
String wrapperMethodName = CharSequenceUtil.subAfter(mappedStatementId, StringConstants.DOT, true);
@@ -117,8 +130,12 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
.map(suffix -> wrapperMethodName.substring(0, wrapperMethodName.length() - suffix.length()))
.orElse(wrapperMethodName);
// 获取真实方法
Optional<Method> methodOptional = Arrays.stream(ReflectUtil.getMethods(Class
.forName(className), m -> Objects.equals(m.getName(), methodName))).findFirst();
Optional<Method> methodOptional = Arrays.stream(ReflectUtil.getMethods(Class.forName(className), m -> {
if (Objects.nonNull(parameterIndex)) {
return Objects.equals(m.getName(), methodName) && m.getParameterCount() == parameterIndex;
}
return Objects.equals(m.getName(), methodName);
})).findFirst();
if (methodOptional.isEmpty()) {
return Collections.emptyMap();
}
@@ -136,6 +153,8 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
}
} else if (parameterName.startsWith(Constants.ENTITY)) {
map.put(parameterName, null);
} else if (parameterName.startsWith(Constants.WRAPPER)) {
map.put(parameterName, null);
}
}
return map;

View File

@@ -16,9 +16,17 @@
package top.continew.starter.security.crypto.core;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.ReflectUtil;
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
@@ -28,11 +36,14 @@ import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.SimpleTypeRegistry;
import top.continew.starter.core.constant.StringConstants;
import top.continew.starter.security.crypto.annotation.FieldEncrypt;
import top.continew.starter.security.crypto.autoconfigure.CryptoProperties;
import top.continew.starter.security.crypto.encryptor.IEncryptor;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
/**
@@ -99,13 +110,18 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor {
* @throws Exception /
*/
private void encryptMap(HashMap<String, Object> parameterMap, MappedStatement mappedStatement) throws Exception {
Map<String, FieldEncrypt> encryptParamMap = super.getEncryptParams(mappedStatement.getId());
Map<String, FieldEncrypt> encryptParamMap = super.getEncryptParams(mappedStatement.getId(), parameterMap
.isEmpty() ? null : parameterMap.size() / 2);
for (Map.Entry<String, FieldEncrypt> encryptParamEntry : encryptParamMap.entrySet()) {
String parameterName = encryptParamEntry.getKey();
if (parameterName.startsWith(Constants.ENTITY)) {
// 兼容 MyBatis Plus 封装的 update 相关方法updateById、update
Object entity = parameterMap.getOrDefault(parameterName, null);
this.doEncrypt(this.getEncryptFields(entity), entity);
} else if (parameterName.startsWith(Constants.WRAPPER)) {
Wrapper wrapper = (Wrapper)parameterMap.getOrDefault(parameterName, null);
// 处理 wrapper 的情况
handleWrapperEncrypt(wrapper, mappedStatement);
} else {
FieldEncrypt fieldEncrypt = encryptParamEntry.getValue();
parameterMap.put(parameterName, this.doEncrypt(parameterMap.get(parameterName), fieldEncrypt));
@@ -148,4 +164,83 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor {
ReflectUtil.setFieldValue(entity, field, ciphertext);
}
}
/**
* 处理 wrapper 的加密情况
*
* @param wrapper wrapper 对象
* @param mappedStatement 映射语句
* @throws Exception /
*/
private void handleWrapperEncrypt(Wrapper wrapper, MappedStatement mappedStatement) throws Exception {
if (wrapper instanceof AbstractWrapper abstractWrapper) {
String sqlSet = abstractWrapper.getSqlSet();
if (StringUtils.isEmpty(sqlSet)) {
return;
}
String className = CharSequenceUtil.subBefore(mappedStatement.getId(), StringConstants.DOT, true);
Class<?> mapperClass = Class.forName(className);
Optional<Class> baseMapperGenerics = getDoByMapperClass(mapperClass, Optional.empty());
// 获取不到泛型对象 则不进行下面的逻辑
if (baseMapperGenerics.isEmpty()) {
return;
}
TableInfo tableInfo = TableInfoHelper.getTableInfo(baseMapperGenerics.get());
List<TableFieldInfo> fieldList = tableInfo.getFieldList();
// 将 name=#{ew.paramNameValuePairs.xxx},age=#{ew.paramNameValuePairs.xxx} 切出来
for (String sqlFragment : sqlSet.split(Constants.COMMA)) {
String columnName = sqlFragment.split(Constants.EQUALS)[0];
// 截取其中的 xxx 字符 #{ew.paramNameValuePairs.xxx}
String paramNameVal = sqlFragment.split(Constants.EQUALS)[1].substring(25, sqlFragment
.split(Constants.EQUALS)[1].length() - 1);
Optional<TableFieldInfo> fieldInfo = fieldList.stream()
.filter(f -> f.getColumn().equals(columnName))
.findAny();
if (fieldInfo.isPresent()) {
TableFieldInfo tableFieldInfo = fieldInfo.get();
FieldEncrypt fieldEncrypt = tableFieldInfo.getField().getAnnotation(FieldEncrypt.class);
if (fieldEncrypt != null) {
Map<String, Object> paramNameValuePairs = abstractWrapper.getParamNameValuePairs();
Object o = paramNameValuePairs.get(paramNameVal);
paramNameValuePairs.put(paramNameVal, this.doEncrypt(o, fieldEncrypt));
}
}
}
}
}
/**
* 从 Mapper 获取泛型
*
* @param mapperClass mapper class
* @param tempResult 临时存储的泛型对象
* @return domain 对象
*/
private static Optional<Class> getDoByMapperClass(Class<?> mapperClass, Optional<Class> tempResult) {
Type[] genericInterfaces = mapperClass.getGenericInterfaces();
Optional<Class> result = tempResult;
for (Type genericInterface : genericInterfaces) {
if (genericInterface instanceof ParameterizedType parameterizedType) {
Type rawType = parameterizedType.getRawType();
Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
// 如果匹配上 BaseMapper 且泛型参数是 Class 类型,则直接返回
if (rawType.equals(BaseMapper.class)) {
return actualTypeArguments[0] instanceof Class ?
Optional.of((Class) actualTypeArguments[0]) : result;
} else if (rawType instanceof Class interfaceClass) {
// 如果泛型参数是 Class 类型,则传递给递归调用
if (actualTypeArguments[0] instanceof Class tempResultClass) {
result = Optional.of(tempResultClass);
}
// 递归调用,继续查找
Optional<Class> innerResult = getDoByMapperClass(interfaceClass, result);
if (innerResult.isPresent()) {
return innerResult;
}
}
}
}
// 如果没有找到,返回传递进来的 tempResult
return Optional.empty();
}
}