From a235a6ea8b574c3f719857bb99d05e874d4e9bd2 Mon Sep 17 00:00:00 2001
From: cary <95016047+hxlcw@users.noreply.github.com>
Date: Tue, 18 Jun 2024 14:14:55 +0800
Subject: [PATCH] =?UTF-8?q?fix(security/crypto):=20=E4=BF=AE=E5=A4=8D?=
=?UTF-8?q?=E5=A4=84=E7=90=86=20MP=20Wrapper=20=E6=97=B6=20=E6=97=A0?=
=?UTF-8?q?=E6=B3=95=E5=8A=A0=E5=AF=86=E7=9A=84=E6=83=85=E5=86=B5=20(#4)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../continew-starter-security-crypto/pom.xml | 1 +
.../core/AbstractMyBatisInterceptor.java | 27 +++++-
.../core/MyBatisEncryptInterceptor.java | 97 ++++++++++++++++++-
3 files changed, 120 insertions(+), 5 deletions(-)
diff --git a/continew-starter-security/continew-starter-security-crypto/pom.xml b/continew-starter-security/continew-starter-security-crypto/pom.xml
index b08549a4..2ac25442 100644
--- a/continew-starter-security/continew-starter-security-crypto/pom.xml
+++ b/continew-starter-security/continew-starter-security-crypto/pom.xml
@@ -24,5 +24,6 @@
com.baomidou
mybatis-plus-core
+
\ No newline at end of file
diff --git a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java
index e2f97ce0..e51eb845 100644
--- a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java
+++ b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java
@@ -53,7 +53,19 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
* @return 加密参数
*/
public Map getEncryptParams(String mappedStatementId) {
- return ENCRYPT_PARAM_CACHE.computeIfAbsent(mappedStatementId, this::getEncryptParamsNoCached);
+ return getEncryptParams(mappedStatementId, null);
+ }
+
+ /**
+ * 获取加密参数
+ *
+ * @param mappedStatementId 映射语句 ID
+ * @param parameterIndex 参数数量
+ * @return 加密参数
+ */
+ public Map getEncryptParams(String mappedStatementId, Integer parameterIndex) {
+ return ENCRYPT_PARAM_CACHE
+ .computeIfAbsent(mappedStatementId, it -> getEncryptParamsNoCached(mappedStatementId, parameterIndex));
}
/**
@@ -105,9 +117,10 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
* 获取参数列表(无缓存)
*
* @param mappedStatementId 映射语句 ID
+ * @param parameterIndex 参数数量
* @return 参数列表
*/
- private Map getEncryptParamsNoCached(String mappedStatementId) {
+ private Map getEncryptParamsNoCached(String mappedStatementId, Integer parameterIndex) {
try {
String className = CharSequenceUtil.subBefore(mappedStatementId, StringConstants.DOT, true);
String wrapperMethodName = CharSequenceUtil.subAfter(mappedStatementId, StringConstants.DOT, true);
@@ -117,8 +130,12 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
.map(suffix -> wrapperMethodName.substring(0, wrapperMethodName.length() - suffix.length()))
.orElse(wrapperMethodName);
// 获取真实方法
- Optional methodOptional = Arrays.stream(ReflectUtil.getMethods(Class
- .forName(className), m -> Objects.equals(m.getName(), methodName))).findFirst();
+ Optional methodOptional = Arrays.stream(ReflectUtil.getMethods(Class.forName(className), m -> {
+ if (Objects.nonNull(parameterIndex)) {
+ return Objects.equals(m.getName(), methodName) && m.getParameterCount() == parameterIndex;
+ }
+ return Objects.equals(m.getName(), methodName);
+ })).findFirst();
if (methodOptional.isEmpty()) {
return Collections.emptyMap();
}
@@ -136,6 +153,8 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor {
}
} else if (parameterName.startsWith(Constants.ENTITY)) {
map.put(parameterName, null);
+ } else if (parameterName.startsWith(Constants.WRAPPER)) {
+ map.put(parameterName, null);
}
}
return map;
diff --git a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java
index f906c7ae..9b99a360 100644
--- a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java
+++ b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java
@@ -16,9 +16,17 @@
package top.continew.starter.security.crypto.core;
+import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.ReflectUtil;
+import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
+import com.baomidou.mybatisplus.core.conditions.Wrapper;
+import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
+import com.baomidou.mybatisplus.core.metadata.TableInfo;
+import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.Constants;
+import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
@@ -28,11 +36,14 @@ import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.SimpleTypeRegistry;
+import top.continew.starter.core.constant.StringConstants;
import top.continew.starter.security.crypto.annotation.FieldEncrypt;
import top.continew.starter.security.crypto.autoconfigure.CryptoProperties;
import top.continew.starter.security.crypto.encryptor.IEncryptor;
import java.lang.reflect.Field;
+import java.lang.reflect.ParameterizedType;
+import java.lang.reflect.Type;
import java.util.*;
/**
@@ -99,13 +110,18 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor {
* @throws Exception /
*/
private void encryptMap(HashMap parameterMap, MappedStatement mappedStatement) throws Exception {
- Map encryptParamMap = super.getEncryptParams(mappedStatement.getId());
+ Map encryptParamMap = super.getEncryptParams(mappedStatement.getId(), parameterMap
+ .isEmpty() ? null : parameterMap.size() / 2);
for (Map.Entry encryptParamEntry : encryptParamMap.entrySet()) {
String parameterName = encryptParamEntry.getKey();
if (parameterName.startsWith(Constants.ENTITY)) {
// 兼容 MyBatis Plus 封装的 update 相关方法,updateById、update
Object entity = parameterMap.getOrDefault(parameterName, null);
this.doEncrypt(this.getEncryptFields(entity), entity);
+ } else if (parameterName.startsWith(Constants.WRAPPER)) {
+ Wrapper wrapper = (Wrapper)parameterMap.getOrDefault(parameterName, null);
+ // 处理 wrapper 的情况
+ handleWrapperEncrypt(wrapper, mappedStatement);
} else {
FieldEncrypt fieldEncrypt = encryptParamEntry.getValue();
parameterMap.put(parameterName, this.doEncrypt(parameterMap.get(parameterName), fieldEncrypt));
@@ -148,4 +164,83 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor {
ReflectUtil.setFieldValue(entity, field, ciphertext);
}
}
+
+ /**
+ * 处理 wrapper 的加密情况
+ *
+ * @param wrapper wrapper 对象
+ * @param mappedStatement 映射语句
+ * @throws Exception /
+ */
+ private void handleWrapperEncrypt(Wrapper wrapper, MappedStatement mappedStatement) throws Exception {
+ if (wrapper instanceof AbstractWrapper abstractWrapper) {
+ String sqlSet = abstractWrapper.getSqlSet();
+ if (StringUtils.isEmpty(sqlSet)) {
+ return;
+ }
+ String className = CharSequenceUtil.subBefore(mappedStatement.getId(), StringConstants.DOT, true);
+ Class> mapperClass = Class.forName(className);
+ Optional baseMapperGenerics = getDoByMapperClass(mapperClass, Optional.empty());
+ // 获取不到泛型对象 则不进行下面的逻辑
+ if (baseMapperGenerics.isEmpty()) {
+ return;
+ }
+ TableInfo tableInfo = TableInfoHelper.getTableInfo(baseMapperGenerics.get());
+ List fieldList = tableInfo.getFieldList();
+ // 将 name=#{ew.paramNameValuePairs.xxx},age=#{ew.paramNameValuePairs.xxx} 切出来
+ for (String sqlFragment : sqlSet.split(Constants.COMMA)) {
+ String columnName = sqlFragment.split(Constants.EQUALS)[0];
+ // 截取其中的 xxx 字符 :#{ew.paramNameValuePairs.xxx}
+ String paramNameVal = sqlFragment.split(Constants.EQUALS)[1].substring(25, sqlFragment
+ .split(Constants.EQUALS)[1].length() - 1);
+ Optional fieldInfo = fieldList.stream()
+ .filter(f -> f.getColumn().equals(columnName))
+ .findAny();
+ if (fieldInfo.isPresent()) {
+ TableFieldInfo tableFieldInfo = fieldInfo.get();
+ FieldEncrypt fieldEncrypt = tableFieldInfo.getField().getAnnotation(FieldEncrypt.class);
+ if (fieldEncrypt != null) {
+ Map paramNameValuePairs = abstractWrapper.getParamNameValuePairs();
+ Object o = paramNameValuePairs.get(paramNameVal);
+ paramNameValuePairs.put(paramNameVal, this.doEncrypt(o, fieldEncrypt));
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * 从 Mapper 获取泛型
+ *
+ * @param mapperClass mapper class
+ * @param tempResult 临时存储的泛型对象
+ * @return domain 对象
+ */
+ private static Optional getDoByMapperClass(Class> mapperClass, Optional tempResult) {
+ Type[] genericInterfaces = mapperClass.getGenericInterfaces();
+ Optional result = tempResult;
+ for (Type genericInterface : genericInterfaces) {
+ if (genericInterface instanceof ParameterizedType parameterizedType) {
+ Type rawType = parameterizedType.getRawType();
+ Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
+ // 如果匹配上 BaseMapper 且泛型参数是 Class 类型,则直接返回
+ if (rawType.equals(BaseMapper.class)) {
+ return actualTypeArguments[0] instanceof Class ?
+ Optional.of((Class) actualTypeArguments[0]) : result;
+ } else if (rawType instanceof Class interfaceClass) {
+ // 如果泛型参数是 Class 类型,则传递给递归调用
+ if (actualTypeArguments[0] instanceof Class tempResultClass) {
+ result = Optional.of(tempResultClass);
+ }
+ // 递归调用,继续查找
+ Optional innerResult = getDoByMapperClass(interfaceClass, result);
+ if (innerResult.isPresent()) {
+ return innerResult;
+ }
+ }
+ }
+ }
+ // 如果没有找到,返回传递进来的 tempResult
+ return Optional.empty();
+ }
}