From a235a6ea8b574c3f719857bb99d05e874d4e9bd2 Mon Sep 17 00:00:00 2001 From: cary <95016047+hxlcw@users.noreply.github.com> Date: Tue, 18 Jun 2024 14:14:55 +0800 Subject: [PATCH] =?UTF-8?q?fix(security/crypto):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=A4=84=E7=90=86=20MP=20Wrapper=20=E6=97=B6=20=E6=97=A0?= =?UTF-8?q?=E6=B3=95=E5=8A=A0=E5=AF=86=E7=9A=84=E6=83=85=E5=86=B5=20(#4)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../continew-starter-security-crypto/pom.xml | 1 + .../core/AbstractMyBatisInterceptor.java | 27 +++++- .../core/MyBatisEncryptInterceptor.java | 97 ++++++++++++++++++- 3 files changed, 120 insertions(+), 5 deletions(-) diff --git a/continew-starter-security/continew-starter-security-crypto/pom.xml b/continew-starter-security/continew-starter-security-crypto/pom.xml index b08549a4..2ac25442 100644 --- a/continew-starter-security/continew-starter-security-crypto/pom.xml +++ b/continew-starter-security/continew-starter-security-crypto/pom.xml @@ -24,5 +24,6 @@ com.baomidou mybatis-plus-core + \ No newline at end of file diff --git a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java index e2f97ce0..e51eb845 100644 --- a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java +++ b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/AbstractMyBatisInterceptor.java @@ -53,7 +53,19 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor { * @return 加密参数 */ public Map getEncryptParams(String mappedStatementId) { - return ENCRYPT_PARAM_CACHE.computeIfAbsent(mappedStatementId, this::getEncryptParamsNoCached); + return getEncryptParams(mappedStatementId, null); + } + + /** + * 获取加密参数 + * + * @param mappedStatementId 映射语句 ID + * @param parameterIndex 参数数量 + * @return 加密参数 + */ + public Map getEncryptParams(String mappedStatementId, Integer parameterIndex) { + return ENCRYPT_PARAM_CACHE + .computeIfAbsent(mappedStatementId, it -> getEncryptParamsNoCached(mappedStatementId, parameterIndex)); } /** @@ -105,9 +117,10 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor { * 获取参数列表(无缓存) * * @param mappedStatementId 映射语句 ID + * @param parameterIndex 参数数量 * @return 参数列表 */ - private Map getEncryptParamsNoCached(String mappedStatementId) { + private Map getEncryptParamsNoCached(String mappedStatementId, Integer parameterIndex) { try { String className = CharSequenceUtil.subBefore(mappedStatementId, StringConstants.DOT, true); String wrapperMethodName = CharSequenceUtil.subAfter(mappedStatementId, StringConstants.DOT, true); @@ -117,8 +130,12 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor { .map(suffix -> wrapperMethodName.substring(0, wrapperMethodName.length() - suffix.length())) .orElse(wrapperMethodName); // 获取真实方法 - Optional methodOptional = Arrays.stream(ReflectUtil.getMethods(Class - .forName(className), m -> Objects.equals(m.getName(), methodName))).findFirst(); + Optional methodOptional = Arrays.stream(ReflectUtil.getMethods(Class.forName(className), m -> { + if (Objects.nonNull(parameterIndex)) { + return Objects.equals(m.getName(), methodName) && m.getParameterCount() == parameterIndex; + } + return Objects.equals(m.getName(), methodName); + })).findFirst(); if (methodOptional.isEmpty()) { return Collections.emptyMap(); } @@ -136,6 +153,8 @@ public abstract class AbstractMyBatisInterceptor implements Interceptor { } } else if (parameterName.startsWith(Constants.ENTITY)) { map.put(parameterName, null); + } else if (parameterName.startsWith(Constants.WRAPPER)) { + map.put(parameterName, null); } } return map; diff --git a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java index f906c7ae..9b99a360 100644 --- a/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java +++ b/continew-starter-security/continew-starter-security-crypto/src/main/java/top/continew/starter/security/crypto/core/MyBatisEncryptInterceptor.java @@ -16,9 +16,17 @@ package top.continew.starter.security.crypto.core; +import cn.hutool.core.text.CharSequenceUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ReflectUtil; +import com.baomidou.mybatisplus.core.conditions.AbstractWrapper; +import com.baomidou.mybatisplus.core.conditions.Wrapper; +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.baomidou.mybatisplus.core.metadata.TableFieldInfo; +import com.baomidou.mybatisplus.core.metadata.TableInfo; +import com.baomidou.mybatisplus.core.metadata.TableInfoHelper; import com.baomidou.mybatisplus.core.toolkit.Constants; +import com.baomidou.mybatisplus.core.toolkit.StringUtils; import org.apache.ibatis.cache.CacheKey; import org.apache.ibatis.executor.Executor; import org.apache.ibatis.mapping.BoundSql; @@ -28,11 +36,14 @@ import org.apache.ibatis.plugin.*; import org.apache.ibatis.session.ResultHandler; import org.apache.ibatis.session.RowBounds; import org.apache.ibatis.type.SimpleTypeRegistry; +import top.continew.starter.core.constant.StringConstants; import top.continew.starter.security.crypto.annotation.FieldEncrypt; import top.continew.starter.security.crypto.autoconfigure.CryptoProperties; import top.continew.starter.security.crypto.encryptor.IEncryptor; import java.lang.reflect.Field; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.util.*; /** @@ -99,13 +110,18 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor { * @throws Exception / */ private void encryptMap(HashMap parameterMap, MappedStatement mappedStatement) throws Exception { - Map encryptParamMap = super.getEncryptParams(mappedStatement.getId()); + Map encryptParamMap = super.getEncryptParams(mappedStatement.getId(), parameterMap + .isEmpty() ? null : parameterMap.size() / 2); for (Map.Entry encryptParamEntry : encryptParamMap.entrySet()) { String parameterName = encryptParamEntry.getKey(); if (parameterName.startsWith(Constants.ENTITY)) { // 兼容 MyBatis Plus 封装的 update 相关方法,updateById、update Object entity = parameterMap.getOrDefault(parameterName, null); this.doEncrypt(this.getEncryptFields(entity), entity); + } else if (parameterName.startsWith(Constants.WRAPPER)) { + Wrapper wrapper = (Wrapper)parameterMap.getOrDefault(parameterName, null); + // 处理 wrapper 的情况 + handleWrapperEncrypt(wrapper, mappedStatement); } else { FieldEncrypt fieldEncrypt = encryptParamEntry.getValue(); parameterMap.put(parameterName, this.doEncrypt(parameterMap.get(parameterName), fieldEncrypt)); @@ -148,4 +164,83 @@ public class MyBatisEncryptInterceptor extends AbstractMyBatisInterceptor { ReflectUtil.setFieldValue(entity, field, ciphertext); } } + + /** + * 处理 wrapper 的加密情况 + * + * @param wrapper wrapper 对象 + * @param mappedStatement 映射语句 + * @throws Exception / + */ + private void handleWrapperEncrypt(Wrapper wrapper, MappedStatement mappedStatement) throws Exception { + if (wrapper instanceof AbstractWrapper abstractWrapper) { + String sqlSet = abstractWrapper.getSqlSet(); + if (StringUtils.isEmpty(sqlSet)) { + return; + } + String className = CharSequenceUtil.subBefore(mappedStatement.getId(), StringConstants.DOT, true); + Class mapperClass = Class.forName(className); + Optional baseMapperGenerics = getDoByMapperClass(mapperClass, Optional.empty()); + // 获取不到泛型对象 则不进行下面的逻辑 + if (baseMapperGenerics.isEmpty()) { + return; + } + TableInfo tableInfo = TableInfoHelper.getTableInfo(baseMapperGenerics.get()); + List fieldList = tableInfo.getFieldList(); + // 将 name=#{ew.paramNameValuePairs.xxx},age=#{ew.paramNameValuePairs.xxx} 切出来 + for (String sqlFragment : sqlSet.split(Constants.COMMA)) { + String columnName = sqlFragment.split(Constants.EQUALS)[0]; + // 截取其中的 xxx 字符 :#{ew.paramNameValuePairs.xxx} + String paramNameVal = sqlFragment.split(Constants.EQUALS)[1].substring(25, sqlFragment + .split(Constants.EQUALS)[1].length() - 1); + Optional fieldInfo = fieldList.stream() + .filter(f -> f.getColumn().equals(columnName)) + .findAny(); + if (fieldInfo.isPresent()) { + TableFieldInfo tableFieldInfo = fieldInfo.get(); + FieldEncrypt fieldEncrypt = tableFieldInfo.getField().getAnnotation(FieldEncrypt.class); + if (fieldEncrypt != null) { + Map paramNameValuePairs = abstractWrapper.getParamNameValuePairs(); + Object o = paramNameValuePairs.get(paramNameVal); + paramNameValuePairs.put(paramNameVal, this.doEncrypt(o, fieldEncrypt)); + } + } + } + } + } + + /** + * 从 Mapper 获取泛型 + * + * @param mapperClass mapper class + * @param tempResult 临时存储的泛型对象 + * @return domain 对象 + */ + private static Optional getDoByMapperClass(Class mapperClass, Optional tempResult) { + Type[] genericInterfaces = mapperClass.getGenericInterfaces(); + Optional result = tempResult; + for (Type genericInterface : genericInterfaces) { + if (genericInterface instanceof ParameterizedType parameterizedType) { + Type rawType = parameterizedType.getRawType(); + Type[] actualTypeArguments = parameterizedType.getActualTypeArguments(); + // 如果匹配上 BaseMapper 且泛型参数是 Class 类型,则直接返回 + if (rawType.equals(BaseMapper.class)) { + return actualTypeArguments[0] instanceof Class ? + Optional.of((Class) actualTypeArguments[0]) : result; + } else if (rawType instanceof Class interfaceClass) { + // 如果泛型参数是 Class 类型,则传递给递归调用 + if (actualTypeArguments[0] instanceof Class tempResultClass) { + result = Optional.of(tempResultClass); + } + // 递归调用,继续查找 + Optional innerResult = getDoByMapperClass(interfaceClass, result); + if (innerResult.isPresent()) { + return innerResult; + } + } + } + } + // 如果没有找到,返回传递进来的 tempResult + return Optional.empty(); + } }