fix(security/crypto): 修复 updateById 修改未正确加密的问题

This commit is contained in:
2024-08-13 23:59:49 +08:00
parent ea6b316296
commit b0a2a8c927
2 changed files with 96 additions and 64 deletions

View File

@@ -22,7 +22,9 @@ import cn.hutool.core.util.ReflectUtil;
import cn.hutool.extra.spring.SpringUtil; import cn.hutool.extra.spring.SpringUtil;
import com.baomidou.mybatisplus.core.toolkit.Constants; import com.baomidou.mybatisplus.core.toolkit.Constants;
import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.plugin.*; import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import top.continew.starter.core.constant.StringConstants; import top.continew.starter.core.constant.StringConstants;
import top.continew.starter.core.exception.BusinessException; import top.continew.starter.core.exception.BusinessException;
import top.continew.starter.security.crypto.annotation.FieldEncrypt; import top.continew.starter.security.crypto.annotation.FieldEncrypt;
@@ -46,39 +48,6 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
private static final Map<String, Map<String, FieldEncrypt>> ENCRYPT_PARAM_CACHE = new ConcurrentHashMap<>(); private static final Map<String, Map<String, FieldEncrypt>> ENCRYPT_PARAM_CACHE = new ConcurrentHashMap<>();
/**
* 获取加密参数
*
* @param mappedStatementId 映射语句 ID
* @return 加密参数
*/
public Map<String, FieldEncrypt> getEncryptParams(String mappedStatementId) {
return getEncryptParams(mappedStatementId, null);
}
/**
* 获取加密参数
*
* @param mappedStatementId 映射语句 ID
* @param parameterIndex 参数索引
* @return 加密参数
*/
public Map<String, FieldEncrypt> getEncryptParams(String mappedStatementId, Integer parameterIndex) {
return ENCRYPT_PARAM_CACHE
.computeIfAbsent(mappedStatementId, key -> getEncryptParamsNoCached(mappedStatementId, parameterIndex));
}
/**
* 获取参数名称
*
* @param parameter 参数
* @return 参数名称
*/
public String getParameterName(Parameter parameter) {
Param param = parameter.getAnnotation(Param.class);
return null != param ? param.value() : parameter.getName();
}
/** /**
* 获取所有字符串类型、需要加/解密的、有值字段 * 获取所有字符串类型、需要加/解密的、有值字段
* *
@@ -114,13 +83,67 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
} }
/** /**
* 获取参数列表(无缓存) * 获取加密参数
*
* @param mappedStatement 映射语句
* @return 加密参数
*/
public Map<String, FieldEncrypt> getEncryptParams(MappedStatement mappedStatement) {
return getEncryptParams(mappedStatement, null);
}
/**
* 获取加密参数
*
* @param mappedStatement 映射语句
* @param parameterCount 参数数量
* @return 加密参数
*/
public Map<String, FieldEncrypt> getEncryptParams(MappedStatement mappedStatement, Integer parameterCount) {
String mappedStatementId = mappedStatement.getId();
SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
if (SqlCommandType.UPDATE != sqlCommandType) {
return ENCRYPT_PARAM_CACHE.computeIfAbsent(mappedStatementId, key -> this
.getEncryptParams(mappedStatementId, parameterCount));
} else {
return this.getEncryptParams(mappedStatementId, parameterCount);
}
}
/**
* 获取参数名称
*
* @param parameter 参数
* @return 参数名称
*/
public String getParameterName(Parameter parameter) {
Param param = parameter.getAnnotation(Param.class);
return null != param ? param.value() : parameter.getName();
}
/**
* 获取加密参数列表
* *
* @param mappedStatementId 映射语句 ID * @param mappedStatementId 映射语句 ID
* @param parameterIndex 参数数量 * @param parameterCount 参数数量
* @return 参数列表 * @return 加密参数列表
*/ */
private Map<String, FieldEncrypt> getEncryptParamsNoCached(String mappedStatementId, Integer parameterIndex) { private Map<String, FieldEncrypt> getEncryptParams(String mappedStatementId, Integer parameterCount) {
Method method = this.getMethod(mappedStatementId, parameterCount);
if (method == null) {
return Collections.emptyMap();
}
return this.getEncryptParams(method);
}
/**
* 获取映射方法
*
* @param mappedStatementId 映射语句 ID
* @param parameterCount 参数数量
* @return 映射方法
*/
private Method getMethod(String mappedStatementId, Integer parameterCount) {
try { try {
String className = CharSequenceUtil.subBefore(mappedStatementId, StringConstants.DOT, true); String className = CharSequenceUtil.subBefore(mappedStatementId, StringConstants.DOT, true);
String wrapperMethodName = CharSequenceUtil.subAfter(mappedStatementId, StringConstants.DOT, true); String wrapperMethodName = CharSequenceUtil.subAfter(mappedStatementId, StringConstants.DOT, true);
@@ -131,17 +154,27 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
.orElse(wrapperMethodName); .orElse(wrapperMethodName);
// 获取真实方法 // 获取真实方法
Optional<Method> methodOptional = Arrays.stream(ReflectUtil.getMethods(Class.forName(className), m -> { Optional<Method> methodOptional = Arrays.stream(ReflectUtil.getMethods(Class.forName(className), m -> {
if (Objects.nonNull(parameterIndex)) { if (parameterCount != null) {
return Objects.equals(m.getName(), methodName) && m.getParameterCount() == parameterIndex; return Objects.equals(m.getName(), methodName) && m.getParameterCount() == parameterCount;
} }
return Objects.equals(m.getName(), methodName); return Objects.equals(m.getName(), methodName);
})).findFirst(); })).findFirst();
if (methodOptional.isEmpty()) { return methodOptional.orElse(null);
return Collections.emptyMap(); } catch (ClassNotFoundException e) {
throw new BusinessException(e.getMessage());
} }
}
/**
* 获取加密参数列表
*
* @param method 方法
* @return 加密参数列表
*/
private Map<String, FieldEncrypt> getEncryptParams(Method method) {
// 获取方法中的加密参数 // 获取方法中的加密参数
Map<String, FieldEncrypt> map = MapUtil.newHashMap(); Map<String, FieldEncrypt> map = MapUtil.newHashMap();
Parameter[] parameterArr = methodOptional.get().getParameters(); Parameter[] parameterArr = method.getParameters();
for (int i = 0; i < parameterArr.length; i++) { for (int i = 0; i < parameterArr.length; i++) {
Parameter parameter = parameterArr[i]; Parameter parameter = parameterArr[i];
String parameterName = this.getParameterName(parameter); String parameterName = this.getParameterName(parameter);
@@ -158,8 +191,5 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
} }
} }
return map; return map;
} catch (ClassNotFoundException e) {
throw new BusinessException(e.getMessage());
}
} }
} }

View File

@@ -109,8 +109,10 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor {
* @throws Exception / * @throws Exception /
*/ */
private void encryptMap(HashMap<String, Object> parameterMap, MappedStatement mappedStatement) throws Exception { private void encryptMap(HashMap<String, Object> parameterMap, MappedStatement mappedStatement) throws Exception {
Map<String, FieldEncrypt> encryptParamMap = super.getEncryptParams(mappedStatement.getId(), parameterMap Map<String, FieldEncrypt> encryptParamMap = super.getEncryptParams(mappedStatement);
.isEmpty() ? null : parameterMap.size() / 2); if (encryptParamMap.isEmpty() && !parameterMap.isEmpty()) {
encryptParamMap = super.getEncryptParams(mappedStatement, parameterMap.size() / 2);
}
for (Map.Entry<String, FieldEncrypt> encryptParamEntry : encryptParamMap.entrySet()) { for (Map.Entry<String, FieldEncrypt> encryptParamEntry : encryptParamMap.entrySet()) {
String parameterName = encryptParamEntry.getKey(); String parameterName = encryptParamEntry.getKey();
if (parameterName.startsWith(Constants.ENTITY)) { if (parameterName.startsWith(Constants.ENTITY)) {